perf(word2vec): better word2vec

This commit is contained in:
Mahdi Dibaiee 2016-09-16 18:46:21 +04:30
parent 313e120f25
commit f16cc26798
3 changed files with 21 additions and 14 deletions

View File

@ -16,17 +16,17 @@ module Main where
main = do main = do
sws <- lines <$> readFile "examples/stopwords" sws <- lines <$> readFile "examples/stopwords"
{-ds <- do-} {-ds <- do
{-content <- readFile "examples/doc-classifier-data/data-reuters"-} content <- readFile "examples/doc-classifier-data/data-reuters"
{-let splitted = splitOn (replicate 10 '-' ++ "\n") content-} let splitted = splitOn (replicate 10 '-' ++ "\n") content
{-d = concatMap (tail . lines) (take 100 splitted)-} d = concatMap (tail . lines) (take 100 splitted)
{-return $ removeWords sws d-} return $ removeWords sws d-}
--let ds = ["I like deep learning", "I like NLP", "I enjoy flying"] --let ds = ["I like deep learning", "I like NLP", "I enjoy flying"]
let ds = ["the king loves the queen", "the queen loves the king", let ds = ["the king loves the queen", "the queen loves the king",
"the dwarf hates the king", "the queen hates the dwarf", "the dwarf hates the king", "the queen hates the dwarf",
"the dwarf poisons the king", "the dwarf poisons the queen"] "the dwarf poisons the king", "the dwarf poisons the queen"]
let session = def { learningRate = 0.1 let session = def { learningRate = 5e-2
, batchSize = 1 , batchSize = 1
, epochs = 100 , epochs = 100
, debug = True , debug = True

View File

@ -1,7 +1,8 @@
module Sibe.Utils module Sibe.Utils
(similarity, ( similarity
ordNub, , ordNub
onehot , onehot
, average
) where ) where
import qualified Data.Vector.Storable as V import qualified Data.Vector.Storable as V
import qualified Data.Set as Set import qualified Data.Set as Set
@ -22,3 +23,6 @@ module Sibe.Utils
go _ [] = [] go _ [] = []
go s (x:xs) = if x `Set.member` s then go s xs go s (x:xs) = if x `Set.member` s then go s xs
else x : go (Set.insert x s) xs else x : go (Set.insert x s) xs
average :: Vector Double -> Vector Double
average v = cmap (/ (V.sum v)) v

View File

@ -75,11 +75,14 @@ module Sibe.Word2Vec
| i == length vocvec - 1 = before | i == length vocvec - 1 = before
| otherwise = before ++ after | otherwise = before ++ after
vectorized = map (\w -> snd . fromJust $ find ((== w) . fst) vocvec) ns vectorized = map (\w -> snd . fromJust $ find ((== w) . fst) vocvec) ns
new = cmap (max 1) $ foldl1 (+) vectorized new = foldl1 (+) vectorized
in in
if length wds <= 1
then []
else
case method w2v of case method w2v of
SkipGram -> zip (repeat v) vectorized SkipGram -> [(v, average new)]
CBOW -> zip vectorized (repeat v) CBOW -> [(average new, v)]
_ -> error "unsupported word2vec method" _ -> error "unsupported word2vec method"
cleanText :: String -> String cleanText :: String -> String