feat(naivebayes): implement NaiveBayes algorithm
feat(example): a document classifier using NaiveBayes over reuters data
This commit is contained in:
36
src/Sibe.hs
36
src/Sibe.hs
@@ -5,7 +5,7 @@
|
||||
|
||||
module Sibe
|
||||
(Network(..),
|
||||
Layer,
|
||||
Layer(..),
|
||||
Input,
|
||||
Output,
|
||||
Activation,
|
||||
@@ -35,13 +35,13 @@ module Sibe
|
||||
type Output = Vector Double
|
||||
type Activation = (Vector Double -> Vector Double, Vector Double -> Vector Double)
|
||||
|
||||
data Layer = L { biases :: !(Vector Double)
|
||||
, nodes :: !(Matrix Double)
|
||||
, activation :: Activation
|
||||
}
|
||||
data Layer = Layer { biases :: !(Vector Double)
|
||||
, nodes :: !(Matrix Double)
|
||||
, activation :: Activation
|
||||
}
|
||||
|
||||
instance Show Layer where
|
||||
show (L biases nodes _) = "(" ++ show biases ++ "," ++ show nodes ++ ")"
|
||||
show (Layer biases nodes _) = "(" ++ show biases ++ "," ++ show nodes ++ ")"
|
||||
|
||||
data Network = O Layer
|
||||
| Layer :- Network
|
||||
@@ -52,8 +52,8 @@ module Sibe
|
||||
saveNetwork network file =
|
||||
writeFile file ((show . reverse) (gen network []))
|
||||
where
|
||||
gen (O (L biases nodes _)) list = (biases, nodes) : list
|
||||
gen (L biases nodes _ :- n) list = gen n $ (biases, nodes) : list
|
||||
gen (O (Layer biases nodes _)) list = (biases, nodes) : list
|
||||
gen (Layer biases nodes _ :- n) list = gen n $ (biases, nodes) : list
|
||||
|
||||
loadNetwork :: [Activation] -> String -> IO Network
|
||||
loadNetwork activations file = do
|
||||
@@ -65,21 +65,21 @@ module Sibe
|
||||
return network
|
||||
|
||||
where
|
||||
gen [(biases, nodes)] [a] = O (L biases nodes a)
|
||||
gen ((biases, nodes):hs) (a:as) = L biases nodes a :- gen hs as
|
||||
gen [(biases, nodes)] [a] = O (Layer biases nodes a)
|
||||
gen ((biases, nodes):hs) (a:as) = Layer biases nodes a :- gen hs as
|
||||
|
||||
runLayer :: Input -> Layer -> Output
|
||||
runLayer input (L !biases !weights _) = input <# weights + biases
|
||||
runLayer input (Layer !biases !weights _) = input <# weights + biases
|
||||
|
||||
forward :: Input -> Network -> Output
|
||||
forward input (O l@(L _ _ (fn, _))) = fn $ runLayer input l
|
||||
forward input (l@(L _ _ (fn, _)) :- n) = forward ((fst . activation $ l) $ runLayer input l) n
|
||||
forward input (O l@(Layer _ _ (fn, _))) = fn $ runLayer input l
|
||||
forward input (l@(Layer _ _ (fn, _)) :- n) = forward ((fst . activation $ l) $ runLayer input l) n
|
||||
|
||||
randomLayer :: Seed -> (Int, Int) -> Activation -> Layer
|
||||
randomLayer seed (wr, wc) =
|
||||
let weights = uniformSample seed wr $ replicate wc (-1, 1)
|
||||
biases = randomVector seed Uniform wc * 2 - 1
|
||||
in L biases weights
|
||||
in Layer biases weights
|
||||
|
||||
randomNetwork :: Seed -> Int -> [(Int, Activation)] -> (Int, Activation) -> Network
|
||||
randomNetwork seed input [] (output, a) =
|
||||
@@ -110,7 +110,7 @@ module Sibe
|
||||
train input network target alpha = fst $ run input network
|
||||
where
|
||||
run :: Input -> Network -> (Network, Vector Double)
|
||||
run input (O l@(L biases weights (fn, fn'))) =
|
||||
run input (O l@(Layer biases weights (fn, fn'))) =
|
||||
let y = runLayer input l
|
||||
o = fn y
|
||||
delta = o - target
|
||||
@@ -119,13 +119,13 @@ module Sibe
|
||||
|
||||
biases' = biases - scale alpha de
|
||||
weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly
|
||||
layer = L biases' weights' (fn, fn') -- updated layer
|
||||
layer = Layer biases' weights' (fn, fn') -- updated layer
|
||||
|
||||
pass = weights #> de
|
||||
-- pass = weights #> de
|
||||
|
||||
in (O layer, pass)
|
||||
run input (l@(L biases weights (fn, fn')) :- n) =
|
||||
run input (l@(Layer biases weights (fn, fn')) :- n) =
|
||||
let y = runLayer input l
|
||||
o = fn y
|
||||
(n', delta) = run o n
|
||||
@@ -134,7 +134,7 @@ module Sibe
|
||||
|
||||
biases' = biases - scale alpha de
|
||||
weights' = weights - scale alpha (input `outer` de)
|
||||
layer = L biases' weights' (fn, fn')
|
||||
layer = Layer biases' weights' (fn, fn')
|
||||
|
||||
pass = weights #> de
|
||||
-- pass = weights #> de
|
||||
|
||||
Reference in New Issue
Block a user