feat(notmnist): notmnist example using SGD + learning rate decay
This commit is contained in:
0
src/Sibe/LogisticRegression.hs
Normal file
0
src/Sibe/LogisticRegression.hs
Normal file
129
src/Sibe/NLP.hs
Normal file
129
src/Sibe/NLP.hs
Normal file
@ -0,0 +1,129 @@
|
||||
module Sibe.NLP
|
||||
(Class,
|
||||
Document(..),
|
||||
ordNub,
|
||||
accuracy,
|
||||
recall,
|
||||
precision,
|
||||
fmeasure,
|
||||
cleanText,
|
||||
cleanDocuments,
|
||||
removeWords,
|
||||
removeStopwords,
|
||||
ngram,
|
||||
ngramText,
|
||||
)
|
||||
where
|
||||
import Data.List
|
||||
import Debug.Trace
|
||||
import qualified Data.Set as Set
|
||||
import Data.List.Split
|
||||
import Data.Maybe
|
||||
import Control.Arrow ((&&&))
|
||||
import Text.Regex.PCRE
|
||||
import Data.Char (isSpace, isNumber, toLower)
|
||||
import NLP.Stemmer
|
||||
|
||||
type Class = Int;
|
||||
|
||||
data Document = Document { text :: String
|
||||
, c :: Class
|
||||
} deriving (Eq, Show, Read)
|
||||
|
||||
|
||||
cleanText :: String -> String
|
||||
cleanText string =
|
||||
let puncs = filter (`notElem` ['!', '"', '#', '$', '%', '(', ')', '.', '?']) (trim string)
|
||||
spacify = foldl (\acc x -> replace x ' ' acc) puncs [',', '/', '-', '\n', '\r']
|
||||
stemmed = unwords $ map (stem Porter) (words spacify)
|
||||
nonumber = filter (not . isNumber) stemmed
|
||||
lower = map toLower nonumber
|
||||
in (unwords . words) lower -- remove unnecessary spaces
|
||||
where
|
||||
trim = f . f
|
||||
where
|
||||
f = reverse . dropWhile isSpace
|
||||
replace needle replacement =
|
||||
map (\c -> if c == needle then replacement else c)
|
||||
|
||||
cleanDocuments :: [Document] -> [Document]
|
||||
cleanDocuments documents =
|
||||
let cleaned = map (\(Document text c) -> Document (cleanText text) c) documents
|
||||
in cleaned
|
||||
|
||||
removeWords :: [String] -> [Document] -> [Document]
|
||||
removeWords ws documents =
|
||||
map (\(Document text c) -> Document (rm ws text) c) documents
|
||||
where
|
||||
rm list text =
|
||||
unwords $ filter (`notElem` list) (words text)
|
||||
|
||||
removeStopwords :: Int -> [Document] -> [Document]
|
||||
removeStopwords i documents =
|
||||
let wc = wordCounts (concatDocs documents)
|
||||
wlist = sortBy (\(_, a) (_, b) -> b `compare` a) wc
|
||||
stopwords = map fst (take i wlist)
|
||||
in removeWords stopwords documents
|
||||
where
|
||||
vocabulary x = ordNub (words x)
|
||||
countWordInDoc d w = genericLength (filter (==w) d)
|
||||
wordCounts x =
|
||||
let voc = vocabulary x
|
||||
in zip voc $ map (countWordInDoc (words x)) voc
|
||||
|
||||
concatDocs = concatMap (\(Document text _) -> text ++ " ")
|
||||
|
||||
ordNub :: (Ord a) => [a] -> [a]
|
||||
ordNub = go Set.empty
|
||||
where
|
||||
go _ [] = []
|
||||
go s (x:xs) = if x `Set.member` s then go s xs
|
||||
else x : go (Set.insert x s) xs
|
||||
|
||||
accuracy :: [(Int, (Int, Double))] -> Double
|
||||
accuracy results =
|
||||
let pairs = map (\(a, b) -> (a, fst b)) results
|
||||
correct = filter (uncurry (==)) pairs
|
||||
in genericLength correct / genericLength results
|
||||
|
||||
recall :: [(Int, (Int, Double))] -> Double
|
||||
recall results =
|
||||
let classes = ordNub (map fst results)
|
||||
s = sum (map rec classes) / genericLength classes
|
||||
in s
|
||||
where
|
||||
rec a =
|
||||
let t = genericLength $ filter (\(c, (r, _)) -> c == r && c == a) results
|
||||
y = genericLength $ filter (\(c, (r, _)) -> c == a) results
|
||||
in t / y
|
||||
|
||||
precision :: [(Int, (Int, Double))] -> Double
|
||||
precision results =
|
||||
let classes = ordNub (map fst results)
|
||||
s = sum (map prec classes) / genericLength classes
|
||||
in s
|
||||
where
|
||||
prec a =
|
||||
let t = genericLength $ filter (\(c, (r, _)) -> c == r && c == a) results
|
||||
y = genericLength $ filter (\(c, (r, _)) -> r == a) results
|
||||
in
|
||||
if y == 0
|
||||
then 0
|
||||
else t / y
|
||||
|
||||
fmeasure :: [(Int, (Int, Double))] -> Double
|
||||
fmeasure results =
|
||||
let r = recall results
|
||||
p = precision results
|
||||
in (2 * p * r) / (p + r)
|
||||
|
||||
ngram :: Int -> [Document] -> [Document]
|
||||
ngram n documents =
|
||||
map (\(Document text c) -> Document (ngramText n text) c) documents
|
||||
|
||||
ngramText :: Int -> String -> String
|
||||
ngramText n text =
|
||||
let ws = words text
|
||||
pairs = zip [0..] ws
|
||||
grams = map (\(i, w) -> concat . intersperse "_" $ w:((take (n - 1) . drop (i+1)) ws)) pairs
|
||||
in unwords ("<b>_":grams)
|
@ -1,7 +1,7 @@
|
||||
module Sibe.NaiveBayes
|
||||
(Document(..),
|
||||
NB(..),
|
||||
train,
|
||||
initialize,
|
||||
run,
|
||||
session,
|
||||
ordNub,
|
||||
@ -19,21 +19,13 @@ module Sibe.NaiveBayes
|
||||
removeStopwords,
|
||||
)
|
||||
where
|
||||
import Sibe.NLP
|
||||
import Data.List
|
||||
import Debug.Trace
|
||||
import qualified Data.Set as Set
|
||||
import Data.List.Split
|
||||
import Data.Maybe
|
||||
import Control.Arrow ((&&&))
|
||||
import Text.Regex.PCRE
|
||||
import Data.Char (isSpace, isNumber, toLower)
|
||||
import NLP.Stemmer
|
||||
|
||||
type Class = Int;
|
||||
|
||||
data Document = Document { text :: String
|
||||
, c :: Class
|
||||
} deriving (Eq, Show, Read)
|
||||
|
||||
data NB = NB { documents :: [Document]
|
||||
, classes :: [(Class, Double)]
|
||||
@ -44,8 +36,8 @@ module Sibe.NaiveBayes
|
||||
, cgram :: [(Class, [(String, Int)])]
|
||||
} deriving (Eq, Show, Read)
|
||||
|
||||
train :: [Document] -> [Class] -> NB
|
||||
train documents classes =
|
||||
initialize :: [Document] -> [Class] -> NB
|
||||
initialize documents classes =
|
||||
let megadoc = concatDocs documents
|
||||
vocabulary = genericLength ((ordNub . words) megadoc)
|
||||
-- (class, prior probability)
|
||||
@ -83,17 +75,6 @@ module Sibe.NaiveBayes
|
||||
classWordsCounts x = wordsCount (classWords x) (classVocabulary x)
|
||||
classNGramCounts x = wordsCount (classNGramWords x) (ordNub $ classNGramWords x)
|
||||
|
||||
ngram :: Int -> [Document] -> [Document]
|
||||
ngram n documents =
|
||||
map (\(Document text c) -> Document (ngramText n text) c) documents
|
||||
|
||||
ngramText :: Int -> String -> String
|
||||
ngramText n text =
|
||||
let ws = words text
|
||||
pairs = zip [0..] ws
|
||||
grams = map (\(i, w) -> concat . intersperse "_" $ w:((take (n - 1) . drop (i+1)) ws)) pairs
|
||||
in unwords ("<b>_":grams)
|
||||
|
||||
session :: [Document] -> NB -> [(Class, (Class, Double))]
|
||||
session docs nb =
|
||||
let results = map (\(Document text c) -> (c, run text nb)) docs
|
||||
@ -143,91 +124,5 @@ module Sibe.NaiveBayes
|
||||
variance = sum (map ((^2) . subtract avg) x) / (genericLength x - 1)
|
||||
in sqrt variance
|
||||
|
||||
cleanText :: String -> String
|
||||
cleanText string =
|
||||
let puncs = filter (`notElem` ['!', '"', '#', '$', '%', '(', ')', '.', '?']) (trim string)
|
||||
spacify = foldl (\acc x -> replace x ' ' acc) puncs [',', '/', '-', '\n', '\r']
|
||||
stemmed = unwords $ map (stem Porter) (words spacify)
|
||||
nonumber = filter (not . isNumber) stemmed
|
||||
lower = map toLower nonumber
|
||||
in (unwords . words) lower -- remove unnecessary spaces
|
||||
where
|
||||
trim = f . f
|
||||
where
|
||||
f = reverse . dropWhile isSpace
|
||||
replace needle replacement =
|
||||
map (\c -> if c == needle then replacement else c)
|
||||
|
||||
cleanDocuments :: [Document] -> [Document]
|
||||
cleanDocuments documents =
|
||||
let cleaned = map (\(Document text c) -> Document (cleanText text) c) documents
|
||||
in cleaned
|
||||
|
||||
removeWords :: [String] -> [Document] -> [Document]
|
||||
removeWords ws documents =
|
||||
map (\(Document text c) -> Document (rm ws text) c) documents
|
||||
where
|
||||
rm list text =
|
||||
unwords $ filter (`notElem` list) (words text)
|
||||
|
||||
removeStopwords :: Int -> [Document] -> [Document]
|
||||
removeStopwords i documents =
|
||||
let wc = wordCounts (concatDocs documents)
|
||||
wlist = sortBy (\(_, a) (_, b) -> b `compare` a) wc
|
||||
stopwords = map fst (take i wlist)
|
||||
in removeWords stopwords documents
|
||||
where
|
||||
vocabulary x = ordNub (words x)
|
||||
countWordInDoc d w = genericLength (filter (==w) d)
|
||||
wordCounts x =
|
||||
let voc = vocabulary x
|
||||
in zip voc $ map (countWordInDoc (words x)) voc
|
||||
|
||||
concatDocs = concatMap (\(Document text _) -> text ++ " ")
|
||||
|
||||
l :: (Show a) => a -> a
|
||||
l a = trace (show a) a
|
||||
|
||||
ordNub :: (Ord a) => [a] -> [a]
|
||||
ordNub = go Set.empty
|
||||
where
|
||||
go _ [] = []
|
||||
go s (x:xs) = if x `Set.member` s then go s xs
|
||||
else x : go (Set.insert x s) xs
|
||||
|
||||
accuracy :: [(Int, (Int, Double))] -> Double
|
||||
accuracy results =
|
||||
let pairs = map (\(a, b) -> (a, fst b)) results
|
||||
correct = filter (uncurry (==)) pairs
|
||||
in genericLength correct / genericLength results
|
||||
|
||||
recall :: [(Int, (Int, Double))] -> Double
|
||||
recall results =
|
||||
let classes = ordNub (map fst results)
|
||||
s = sum (map rec classes) / genericLength classes
|
||||
in s
|
||||
where
|
||||
rec a =
|
||||
let t = genericLength $ filter (\(c, (r, _)) -> c == r && c == a) results
|
||||
y = genericLength $ filter (\(c, (r, _)) -> c == a) results
|
||||
in t / y
|
||||
|
||||
precision :: [(Int, (Int, Double))] -> Double
|
||||
precision results =
|
||||
let classes = ordNub (map fst results)
|
||||
s = sum (map prec classes) / genericLength classes
|
||||
in s
|
||||
where
|
||||
prec a =
|
||||
let t = genericLength $ filter (\(c, (r, _)) -> c == r && c == a) results
|
||||
y = genericLength $ filter (\(c, (r, _)) -> r == a) results
|
||||
in
|
||||
if y == 0
|
||||
then 0
|
||||
else t / y
|
||||
|
||||
fmeasure :: [(Int, (Int, Double))] -> Double
|
||||
fmeasure results =
|
||||
let r = recall results
|
||||
p = precision results
|
||||
in (2 * p * r) / (p + r)
|
||||
|
Reference in New Issue
Block a user