diff --git a/examples/word2vec.hs b/examples/word2vec.hs index f42ce14..921fa88 100644 --- a/examples/word2vec.hs +++ b/examples/word2vec.hs @@ -16,17 +16,17 @@ module Main where main = do sws <- lines <$> readFile "examples/stopwords" - {-ds <- do-} - {-content <- readFile "examples/doc-classifier-data/data-reuters"-} - {-let splitted = splitOn (replicate 10 '-' ++ "\n") content-} - {-d = concatMap (tail . lines) (take 100 splitted)-} - {-return $ removeWords sws d-} + {-ds <- do + content <- readFile "examples/doc-classifier-data/data-reuters" + let splitted = splitOn (replicate 10 '-' ++ "\n") content + d = concatMap (tail . lines) (take 100 splitted) + return $ removeWords sws d-} --let ds = ["I like deep learning", "I like NLP", "I enjoy flying"] let ds = ["the king loves the queen", "the queen loves the king", "the dwarf hates the king", "the queen hates the dwarf", "the dwarf poisons the king", "the dwarf poisons the queen"] - let session = def { learningRate = 0.1 + let session = def { learningRate = 5e-2 , batchSize = 1 , epochs = 100 , debug = True diff --git a/src/Sibe/Utils.hs b/src/Sibe/Utils.hs index ea3fd24..72108a6 100644 --- a/src/Sibe/Utils.hs +++ b/src/Sibe/Utils.hs @@ -1,7 +1,8 @@ module Sibe.Utils - (similarity, - ordNub, - onehot + ( similarity + , ordNub + , onehot + , average ) where import qualified Data.Vector.Storable as V import qualified Data.Set as Set @@ -22,3 +23,6 @@ module Sibe.Utils go _ [] = [] go s (x:xs) = if x `Set.member` s then go s xs else x : go (Set.insert x s) xs + + average :: Vector Double -> Vector Double + average v = cmap (/ (V.sum v)) v diff --git a/src/Sibe/Word2Vec.hs b/src/Sibe/Word2Vec.hs index cbd1d53..94c3ea7 100644 --- a/src/Sibe/Word2Vec.hs +++ b/src/Sibe/Word2Vec.hs @@ -75,12 +75,15 @@ module Sibe.Word2Vec | i == length vocvec - 1 = before | otherwise = before ++ after vectorized = map (\w -> snd . fromJust $ find ((== w) . fst) vocvec) ns - new = cmap (max 1) $ foldl1 (+) vectorized + new = foldl1 (+) vectorized in - case method w2v of - SkipGram -> zip (repeat v) vectorized - CBOW -> zip vectorized (repeat v) - _ -> error "unsupported word2vec method" + if length wds <= 1 + then [] + else + case method w2v of + SkipGram -> [(v, average new)] + CBOW -> [(average new, v)] + _ -> error "unsupported word2vec method" cleanText :: String -> String cleanText string =