feat(naivebayes): implement NaiveBayes algorithm
feat(example): a document classifier using NaiveBayes over reuters data
This commit is contained in:
parent
493a20eb0a
commit
26eb4531fa
29
examples/naivebayes-doc-classifier.hs
Normal file
29
examples/naivebayes-doc-classifier.hs
Normal file
@ -0,0 +1,29 @@
|
||||
module Main
|
||||
where
|
||||
import Sibe
|
||||
import Sibe.NaiveBayes
|
||||
import Text.Printf
|
||||
import Data.List
|
||||
import Data.Maybe
|
||||
import Debug.Trace
|
||||
|
||||
main = do
|
||||
dataset <- readFile "examples/naivebayes-doc-classifier/data-reuters"
|
||||
test <- readFile "examples/naivebayes-doc-classifier/data-reuters-test"
|
||||
|
||||
classes <- map (filter (/= ' ')) . lines <$> readFile "examples/naivebayes-doc-classifier/data-classes"
|
||||
|
||||
let intClasses = [0..length classes - 1]
|
||||
documents = createDocuments classes dataset
|
||||
testDocuments = createDocuments classes test
|
||||
nb = initialize documents
|
||||
|
||||
let testResults (Document text c) =
|
||||
let r = determine text nb intClasses documents
|
||||
in trace (classes !! c ++ " ~ " ++ classes !! r) c == r
|
||||
|
||||
let results = map testResults testDocuments
|
||||
|
||||
putStr "Accuracy: "
|
||||
putStr . show . round $ (genericLength (filter (==True) results) / genericLength results) * 100
|
||||
putStrLn "%"
|
14
sibe.cabal
14
sibe.cabal
@ -15,11 +15,13 @@ cabal-version: >=1.10
|
||||
|
||||
library
|
||||
hs-source-dirs: src
|
||||
exposed-modules: Sibe
|
||||
exposed-modules: Sibe, Sibe.NaiveBayes
|
||||
build-depends: base >= 4.7 && < 5
|
||||
, hmatrix
|
||||
, random
|
||||
, deepseq
|
||||
, containers
|
||||
, split
|
||||
default-language: Haskell2010
|
||||
|
||||
executable sibe-exe
|
||||
@ -40,6 +42,16 @@ executable example-xor
|
||||
, hmatrix
|
||||
default-language: Haskell2010
|
||||
|
||||
executable example-naivebayes-doc-classifier
|
||||
hs-source-dirs: examples
|
||||
main-is: naivebayes-doc-classifier.hs
|
||||
ghc-options: -threaded -rtsopts -with-rtsopts=-N
|
||||
build-depends: base
|
||||
, sibe
|
||||
, hmatrix
|
||||
, containers
|
||||
default-language: Haskell2010
|
||||
|
||||
test-suite sibe-test
|
||||
type: exitcode-stdio-1.0
|
||||
hs-source-dirs: test
|
||||
|
36
src/Sibe.hs
36
src/Sibe.hs
@ -5,7 +5,7 @@
|
||||
|
||||
module Sibe
|
||||
(Network(..),
|
||||
Layer,
|
||||
Layer(..),
|
||||
Input,
|
||||
Output,
|
||||
Activation,
|
||||
@ -35,13 +35,13 @@ module Sibe
|
||||
type Output = Vector Double
|
||||
type Activation = (Vector Double -> Vector Double, Vector Double -> Vector Double)
|
||||
|
||||
data Layer = L { biases :: !(Vector Double)
|
||||
, nodes :: !(Matrix Double)
|
||||
, activation :: Activation
|
||||
}
|
||||
data Layer = Layer { biases :: !(Vector Double)
|
||||
, nodes :: !(Matrix Double)
|
||||
, activation :: Activation
|
||||
}
|
||||
|
||||
instance Show Layer where
|
||||
show (L biases nodes _) = "(" ++ show biases ++ "," ++ show nodes ++ ")"
|
||||
show (Layer biases nodes _) = "(" ++ show biases ++ "," ++ show nodes ++ ")"
|
||||
|
||||
data Network = O Layer
|
||||
| Layer :- Network
|
||||
@ -52,8 +52,8 @@ module Sibe
|
||||
saveNetwork network file =
|
||||
writeFile file ((show . reverse) (gen network []))
|
||||
where
|
||||
gen (O (L biases nodes _)) list = (biases, nodes) : list
|
||||
gen (L biases nodes _ :- n) list = gen n $ (biases, nodes) : list
|
||||
gen (O (Layer biases nodes _)) list = (biases, nodes) : list
|
||||
gen (Layer biases nodes _ :- n) list = gen n $ (biases, nodes) : list
|
||||
|
||||
loadNetwork :: [Activation] -> String -> IO Network
|
||||
loadNetwork activations file = do
|
||||
@ -65,21 +65,21 @@ module Sibe
|
||||
return network
|
||||
|
||||
where
|
||||
gen [(biases, nodes)] [a] = O (L biases nodes a)
|
||||
gen ((biases, nodes):hs) (a:as) = L biases nodes a :- gen hs as
|
||||
gen [(biases, nodes)] [a] = O (Layer biases nodes a)
|
||||
gen ((biases, nodes):hs) (a:as) = Layer biases nodes a :- gen hs as
|
||||
|
||||
runLayer :: Input -> Layer -> Output
|
||||
runLayer input (L !biases !weights _) = input <# weights + biases
|
||||
runLayer input (Layer !biases !weights _) = input <# weights + biases
|
||||
|
||||
forward :: Input -> Network -> Output
|
||||
forward input (O l@(L _ _ (fn, _))) = fn $ runLayer input l
|
||||
forward input (l@(L _ _ (fn, _)) :- n) = forward ((fst . activation $ l) $ runLayer input l) n
|
||||
forward input (O l@(Layer _ _ (fn, _))) = fn $ runLayer input l
|
||||
forward input (l@(Layer _ _ (fn, _)) :- n) = forward ((fst . activation $ l) $ runLayer input l) n
|
||||
|
||||
randomLayer :: Seed -> (Int, Int) -> Activation -> Layer
|
||||
randomLayer seed (wr, wc) =
|
||||
let weights = uniformSample seed wr $ replicate wc (-1, 1)
|
||||
biases = randomVector seed Uniform wc * 2 - 1
|
||||
in L biases weights
|
||||
in Layer biases weights
|
||||
|
||||
randomNetwork :: Seed -> Int -> [(Int, Activation)] -> (Int, Activation) -> Network
|
||||
randomNetwork seed input [] (output, a) =
|
||||
@ -110,7 +110,7 @@ module Sibe
|
||||
train input network target alpha = fst $ run input network
|
||||
where
|
||||
run :: Input -> Network -> (Network, Vector Double)
|
||||
run input (O l@(L biases weights (fn, fn'))) =
|
||||
run input (O l@(Layer biases weights (fn, fn'))) =
|
||||
let y = runLayer input l
|
||||
o = fn y
|
||||
delta = o - target
|
||||
@ -119,13 +119,13 @@ module Sibe
|
||||
|
||||
biases' = biases - scale alpha de
|
||||
weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly
|
||||
layer = L biases' weights' (fn, fn') -- updated layer
|
||||
layer = Layer biases' weights' (fn, fn') -- updated layer
|
||||
|
||||
pass = weights #> de
|
||||
-- pass = weights #> de
|
||||
|
||||
in (O layer, pass)
|
||||
run input (l@(L biases weights (fn, fn')) :- n) =
|
||||
run input (l@(Layer biases weights (fn, fn')) :- n) =
|
||||
let y = runLayer input l
|
||||
o = fn y
|
||||
(n', delta) = run o n
|
||||
@ -134,7 +134,7 @@ module Sibe
|
||||
|
||||
biases' = biases - scale alpha de
|
||||
weights' = weights - scale alpha (input `outer` de)
|
||||
layer = L biases' weights' (fn, fn')
|
||||
layer = Layer biases' weights' (fn, fn')
|
||||
|
||||
pass = weights #> de
|
||||
-- pass = weights #> de
|
||||
|
65
src/Sibe/NaiveBayes.hs
Normal file
65
src/Sibe/NaiveBayes.hs
Normal file
@ -0,0 +1,65 @@
|
||||
module Sibe.NaiveBayes
|
||||
(Document(..),
|
||||
NB(..),
|
||||
createDocuments,
|
||||
initialize,
|
||||
calculate,
|
||||
determine
|
||||
)
|
||||
where
|
||||
import Data.List
|
||||
import Debug.Trace
|
||||
import qualified Data.Set as Set
|
||||
import Data.List.Split
|
||||
import Data.Maybe
|
||||
type Class = Int
|
||||
|
||||
data Document = Document { text :: String
|
||||
, c :: Class
|
||||
} deriving (Eq, Show, Read)
|
||||
|
||||
data NB = NB { vocabulary :: Double
|
||||
, megadoc :: String
|
||||
}
|
||||
|
||||
initialize :: [Document] -> NB
|
||||
initialize documents =
|
||||
let megadoc = concatMap (\(Document text _) -> text ++ " ") documents
|
||||
vocabulary = genericLength ((ordNub . words) megadoc)
|
||||
in NB vocabulary megadoc
|
||||
|
||||
determine :: String -> NB -> [Class] -> [Document] -> Class
|
||||
determine text nb classes documents =
|
||||
let scores = zip [0..] (map (\cls -> calculate text nb cls documents) classes)
|
||||
m = maximumBy (\(i0, c0) (i1, c1) -> c0 `compare` c1) scores
|
||||
in fst m
|
||||
|
||||
calculate :: String -> NB -> Class -> [Document] -> Double
|
||||
calculate text (NB vocabulary megadoc) cls documents =
|
||||
let docs = filter (\(Document text c) -> c == cls) documents
|
||||
texts = map (\(Document text _) -> text ++ " ") docs
|
||||
classText = concat texts
|
||||
classWords = words classText
|
||||
c = genericLength classWords
|
||||
pc = genericLength docs / genericLength documents
|
||||
in pc * product (map (cword classWords c) (words text))
|
||||
where
|
||||
cword classWords c word =
|
||||
let wc = genericLength (filter (==word) classWords)
|
||||
in (wc + 1) / (c + vocabulary)
|
||||
|
||||
createDocuments classes content =
|
||||
let splitted = splitOn (replicate 10 '-' ++ "\n") content
|
||||
pairs = map (\a -> ((head . lines) a, (concat . tail . lines) a)) splitted
|
||||
documents = map (\(topic, text) -> Document text (fromJust $ elemIndex topic classes)) pairs
|
||||
in documents
|
||||
|
||||
l :: (Show a) => a -> a
|
||||
l a = trace (show a) a
|
||||
|
||||
ordNub :: (Ord a) => [a] -> [a]
|
||||
ordNub = go Set.empty
|
||||
where
|
||||
go _ [] = []
|
||||
go s (x:xs) = if x `Set.member` s then go s xs
|
||||
else x : go (Set.insert x s) xs
|
Loading…
Reference in New Issue
Block a user