feat(naivebayes): implement NaiveBayes algorithm
feat(example): a document classifier using NaiveBayes over reuters data
This commit is contained in:
parent
493a20eb0a
commit
26eb4531fa
29
examples/naivebayes-doc-classifier.hs
Normal file
29
examples/naivebayes-doc-classifier.hs
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
module Main
|
||||||
|
where
|
||||||
|
import Sibe
|
||||||
|
import Sibe.NaiveBayes
|
||||||
|
import Text.Printf
|
||||||
|
import Data.List
|
||||||
|
import Data.Maybe
|
||||||
|
import Debug.Trace
|
||||||
|
|
||||||
|
-- | Train a naive Bayes classifier on the Reuters sample data, classify the
-- held-out test documents and print the percentage classified correctly.
main :: IO ()
main = do
  dataset <- readFile "examples/naivebayes-doc-classifier/data-reuters"
  test <- readFile "examples/naivebayes-doc-classifier/data-reuters-test"

  -- One class name per line; spaces are stripped from each name.
  classes <- map (filter (/= ' ')) . lines <$> readFile "examples/naivebayes-doc-classifier/data-classes"

  let intClasses = [0 .. length classes - 1]
      documents = createDocuments classes dataset
      testDocuments = createDocuments classes test
      nb = initialize documents

  -- True when the predicted class matches the labelled one. The
  -- "expected ~ predicted" pair is emitted through 'trace' as a side effect
  -- of evaluating the comparison.
  let testResults (Document text c) =
        let r = determine text nb intClasses documents
        in trace (classes !! c ++ " ~ " ++ classes !! r) c == r

  let results = map testResults testDocuments
      correct = length (filter id results)
      total = length results
      -- Guard the empty test set: 0 / 0 is NaN, and rounding NaN prints a
      -- meaningless number.
      percent :: Integer
      percent
        | total == 0 = 0
        | otherwise = round (fromIntegral correct / fromIntegral total * 100 :: Double)

  putStr "Accuracy: "
  putStr (show percent)
  putStrLn "%"
|
14
sibe.cabal
14
sibe.cabal
@ -15,11 +15,13 @@ cabal-version: >=1.10
|
|||||||
|
|
||||||
library
|
library
|
||||||
hs-source-dirs: src
|
hs-source-dirs: src
|
||||||
exposed-modules: Sibe
|
exposed-modules: Sibe, Sibe.NaiveBayes
|
||||||
build-depends: base >= 4.7 && < 5
|
build-depends: base >= 4.7 && < 5
|
||||||
, hmatrix
|
, hmatrix
|
||||||
, random
|
, random
|
||||||
, deepseq
|
, deepseq
|
||||||
|
, containers
|
||||||
|
, split
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
|
|
||||||
executable sibe-exe
|
executable sibe-exe
|
||||||
@ -40,6 +42,16 @@ executable example-xor
|
|||||||
, hmatrix
|
, hmatrix
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
executable example-naivebayes-doc-classifier
|
||||||
|
hs-source-dirs: examples
|
||||||
|
main-is: naivebayes-doc-classifier.hs
|
||||||
|
ghc-options: -threaded -rtsopts -with-rtsopts=-N
|
||||||
|
build-depends: base
|
||||||
|
, sibe
|
||||||
|
, hmatrix
|
||||||
|
, containers
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
test-suite sibe-test
|
test-suite sibe-test
|
||||||
type: exitcode-stdio-1.0
|
type: exitcode-stdio-1.0
|
||||||
hs-source-dirs: test
|
hs-source-dirs: test
|
||||||
|
30
src/Sibe.hs
30
src/Sibe.hs
@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
module Sibe
|
module Sibe
|
||||||
(Network(..),
|
(Network(..),
|
||||||
Layer,
|
Layer(..),
|
||||||
Input,
|
Input,
|
||||||
Output,
|
Output,
|
||||||
Activation,
|
Activation,
|
||||||
@ -35,13 +35,13 @@ module Sibe
|
|||||||
type Output = Vector Double
|
type Output = Vector Double
|
||||||
type Activation = (Vector Double -> Vector Double, Vector Double -> Vector Double)
|
type Activation = (Vector Double -> Vector Double, Vector Double -> Vector Double)
|
||||||
|
|
||||||
data Layer = L { biases :: !(Vector Double)
|
data Layer = Layer { biases :: !(Vector Double)
|
||||||
, nodes :: !(Matrix Double)
|
, nodes :: !(Matrix Double)
|
||||||
, activation :: Activation
|
, activation :: Activation
|
||||||
}
|
}
|
||||||
|
|
||||||
instance Show Layer where
|
instance Show Layer where
|
||||||
show (L biases nodes _) = "(" ++ show biases ++ "," ++ show nodes ++ ")"
|
show (Layer biases nodes _) = "(" ++ show biases ++ "," ++ show nodes ++ ")"
|
||||||
|
|
||||||
data Network = O Layer
|
data Network = O Layer
|
||||||
| Layer :- Network
|
| Layer :- Network
|
||||||
@ -52,8 +52,8 @@ module Sibe
|
|||||||
saveNetwork network file =
|
saveNetwork network file =
|
||||||
writeFile file ((show . reverse) (gen network []))
|
writeFile file ((show . reverse) (gen network []))
|
||||||
where
|
where
|
||||||
gen (O (L biases nodes _)) list = (biases, nodes) : list
|
gen (O (Layer biases nodes _)) list = (biases, nodes) : list
|
||||||
gen (L biases nodes _ :- n) list = gen n $ (biases, nodes) : list
|
gen (Layer biases nodes _ :- n) list = gen n $ (biases, nodes) : list
|
||||||
|
|
||||||
loadNetwork :: [Activation] -> String -> IO Network
|
loadNetwork :: [Activation] -> String -> IO Network
|
||||||
loadNetwork activations file = do
|
loadNetwork activations file = do
|
||||||
@ -65,21 +65,21 @@ module Sibe
|
|||||||
return network
|
return network
|
||||||
|
|
||||||
where
|
where
|
||||||
gen [(biases, nodes)] [a] = O (L biases nodes a)
|
gen [(biases, nodes)] [a] = O (Layer biases nodes a)
|
||||||
gen ((biases, nodes):hs) (a:as) = L biases nodes a :- gen hs as
|
gen ((biases, nodes):hs) (a:as) = Layer biases nodes a :- gen hs as
|
||||||
|
|
||||||
runLayer :: Input -> Layer -> Output
|
runLayer :: Input -> Layer -> Output
|
||||||
runLayer input (L !biases !weights _) = input <# weights + biases
|
runLayer input (Layer !biases !weights _) = input <# weights + biases
|
||||||
|
|
||||||
forward :: Input -> Network -> Output
|
forward :: Input -> Network -> Output
|
||||||
forward input (O l@(L _ _ (fn, _))) = fn $ runLayer input l
|
forward input (O l@(Layer _ _ (fn, _))) = fn $ runLayer input l
|
||||||
forward input (l@(L _ _ (fn, _)) :- n) = forward ((fst . activation $ l) $ runLayer input l) n
|
forward input (l@(Layer _ _ (fn, _)) :- n) = forward ((fst . activation $ l) $ runLayer input l) n
|
||||||
|
|
||||||
randomLayer :: Seed -> (Int, Int) -> Activation -> Layer
|
randomLayer :: Seed -> (Int, Int) -> Activation -> Layer
|
||||||
randomLayer seed (wr, wc) =
|
randomLayer seed (wr, wc) =
|
||||||
let weights = uniformSample seed wr $ replicate wc (-1, 1)
|
let weights = uniformSample seed wr $ replicate wc (-1, 1)
|
||||||
biases = randomVector seed Uniform wc * 2 - 1
|
biases = randomVector seed Uniform wc * 2 - 1
|
||||||
in L biases weights
|
in Layer biases weights
|
||||||
|
|
||||||
randomNetwork :: Seed -> Int -> [(Int, Activation)] -> (Int, Activation) -> Network
|
randomNetwork :: Seed -> Int -> [(Int, Activation)] -> (Int, Activation) -> Network
|
||||||
randomNetwork seed input [] (output, a) =
|
randomNetwork seed input [] (output, a) =
|
||||||
@ -110,7 +110,7 @@ module Sibe
|
|||||||
train input network target alpha = fst $ run input network
|
train input network target alpha = fst $ run input network
|
||||||
where
|
where
|
||||||
run :: Input -> Network -> (Network, Vector Double)
|
run :: Input -> Network -> (Network, Vector Double)
|
||||||
run input (O l@(L biases weights (fn, fn'))) =
|
run input (O l@(Layer biases weights (fn, fn'))) =
|
||||||
let y = runLayer input l
|
let y = runLayer input l
|
||||||
o = fn y
|
o = fn y
|
||||||
delta = o - target
|
delta = o - target
|
||||||
@ -119,13 +119,13 @@ module Sibe
|
|||||||
|
|
||||||
biases' = biases - scale alpha de
|
biases' = biases - scale alpha de
|
||||||
weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly
|
weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly
|
||||||
layer = L biases' weights' (fn, fn') -- updated layer
|
layer = Layer biases' weights' (fn, fn') -- updated layer
|
||||||
|
|
||||||
pass = weights #> de
|
pass = weights #> de
|
||||||
-- pass = weights #> de
|
-- pass = weights #> de
|
||||||
|
|
||||||
in (O layer, pass)
|
in (O layer, pass)
|
||||||
run input (l@(L biases weights (fn, fn')) :- n) =
|
run input (l@(Layer biases weights (fn, fn')) :- n) =
|
||||||
let y = runLayer input l
|
let y = runLayer input l
|
||||||
o = fn y
|
o = fn y
|
||||||
(n', delta) = run o n
|
(n', delta) = run o n
|
||||||
@ -134,7 +134,7 @@ module Sibe
|
|||||||
|
|
||||||
biases' = biases - scale alpha de
|
biases' = biases - scale alpha de
|
||||||
weights' = weights - scale alpha (input `outer` de)
|
weights' = weights - scale alpha (input `outer` de)
|
||||||
layer = L biases' weights' (fn, fn')
|
layer = Layer biases' weights' (fn, fn')
|
||||||
|
|
||||||
pass = weights #> de
|
pass = weights #> de
|
||||||
-- pass = weights #> de
|
-- pass = weights #> de
|
||||||
|
65
src/Sibe/NaiveBayes.hs
Normal file
65
src/Sibe/NaiveBayes.hs
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
module Sibe.NaiveBayes
|
||||||
|
(Document(..),
|
||||||
|
NB(..),
|
||||||
|
createDocuments,
|
||||||
|
initialize,
|
||||||
|
calculate,
|
||||||
|
determine
|
||||||
|
)
|
||||||
|
where
|
||||||
|
import Data.List
|
||||||
|
import Debug.Trace
|
||||||
|
import qualified Data.Set as Set
|
||||||
|
import Data.List.Split
|
||||||
|
import Data.Maybe
|
||||||
|
-- | A class (topic) identifier; 'createDocuments' uses the position of the
-- topic name in its list of class names.
type Class = Int

-- | A piece of text together with the class it is labelled with.
data Document = Document
  { text :: String -- ^ document body
  , c :: Class -- ^ class label
  }
  deriving (Eq, Show, Read)
|
||||||
|
|
||||||
|
-- | Corpus-wide statistics, as built by 'initialize'.
data NB = NB
  { vocabulary :: Double -- ^ number of distinct words in 'megadoc'
  , megadoc :: String -- ^ all document texts, each followed by a space
  }
|
||||||
|
|
||||||
|
-- | Build the corpus statistics for a set of training documents.
--
-- The texts of all documents are joined into one string (each followed by a
-- single space) and the number of distinct words in it is recorded as the
-- vocabulary size.
initialize :: [Document] -> NB
initialize documents = NB vocabularySize corpus
  where
    corpus = concat [body ++ " " | Document body _ <- documents]
    vocabularySize = genericLength (ordNub (words corpus))
|
||||||
|
|
||||||
|
-- | Pick the best-scoring class for a piece of text.
--
-- Every candidate in @classes@ is scored with 'calculate' against the
-- training @documents@; the candidate with the highest score is returned.
--
-- Calls 'error' when @classes@ is empty, because there is no class to return.
determine :: String -> NB -> [Class] -> [Document] -> Class
determine _ _ [] _ = error "Sibe.NaiveBayes.determine: empty list of classes"
determine text nb classes documents = fst best
  where
    -- Pair each score with the class it belongs to rather than with its
    -- position in the list, so the result is also right when @classes@ is
    -- not exactly @[0 ..]@.
    scored = [(cls, calculate text nb cls documents) | cls <- classes]
    best = maximumBy (\(_, a) (_, b) -> a `compare` b) scored
|
||||||
|
|
||||||
|
-- | Score a text against one class.
--
-- The score is the share of training documents labelled @cls@ multiplied by,
-- for every word of @text@, @(occurrences of the word in the class + 1) /
-- (words in the class + vocabulary size)@.
--
-- NOTE(review): this is a plain product of values that are usually well
-- below 1, so long texts may underflow to 0. Summing logarithms would avoid
-- that but would change the returned values.
calculate :: String -> NB -> Class -> [Document] -> Double
calculate text (NB vocabulary _) cls documents =
  prior * product [likelihood w | w <- words text]
  where
    -- Training documents labelled with the class being scored.
    inClass = [d | d@(Document _ label) <- documents, label == cls]
    -- Every word occurrence across those documents.
    classWords = concatMap (\(Document body _) -> words body) inClass
    total = genericLength classWords
    prior = genericLength inClass / genericLength documents
    -- Add-one smoothed relative frequency of a word within the class.
    likelihood w =
      let occurrences = genericLength (filter (== w) classWords)
      in (occurrences + 1) / (total + vocabulary)
|
||||||
|
|
||||||
|
-- | Parse a data file into labelled documents.
--
-- @content@ is split on lines consisting of ten dashes. In each chunk the
-- first line is the topic name and the remaining lines are the text. The
-- topic is translated to its position in @classes@.
--
-- Chunks without any line (for example the empty piece after a trailing
-- separator) are skipped. A topic that is not in @classes@ is a fatal
-- error that names the offending topic.
createDocuments :: [String] -> String -> [Document]
createDocuments classes content =
  [ toDocument topic body
  | chunk <- splitOn separator content
  , (topic : body) <- [lines chunk]
  ]
  where
    separator = replicate 10 '-' ++ "\n"
    toDocument topic body =
      case elemIndex topic classes of
        -- Join lines with a space so the last word of one line and the
        -- first word of the next stay separate words.
        Just cls -> Document (unwords body) cls
        Nothing ->
          error ("Sibe.NaiveBayes.createDocuments: unknown class " ++ show topic)
|
||||||
|
|
||||||
|
-- | Debugging aid: returns its argument unchanged and emits its 'show'
-- rendering through 'trace' when the result is evaluated.
l :: (Show a) => a -> a
l value = show value `trace` value
|
||||||
|
|
||||||
|
-- | Remove duplicates from a list, keeping the first occurrence of each
-- element and the original order. Already-seen elements are tracked in a
-- 'Set.Set'.
ordNub :: (Ord a) => [a] -> [a]
ordNub = dedupe Set.empty
  where
    dedupe _ [] = []
    dedupe seen (y : ys)
      | y `Set.member` seen = dedupe seen ys
      | otherwise = y : dedupe (Set.insert y seen) ys
|
Loading…
Reference in New Issue
Block a user