diff --git a/examples/word2vec.hs b/examples/word2vec.hs index f0b45e5..f42ce14 100644 --- a/examples/word2vec.hs +++ b/examples/word2vec.hs @@ -27,13 +27,14 @@ module Main where "the dwarf poisons the king", "the dwarf poisons the queen"] let session = def { learningRate = 0.1 - , batchSize = 16 + , batchSize = 1 , epochs = 100 + , debug = True } :: Session w2v = def { docs = ds , dimensions = 50 , method = SkipGram - , window = 3 + , window = 2 } :: Word2Vec diff --git a/sgd.png b/sgd.png index 1f69105..54e3396 100644 Binary files a/sgd.png and b/sgd.png differ diff --git a/src/Sibe.hs b/src/Sibe.hs index ba70ed2..7260111 100644 --- a/src/Sibe.hs +++ b/src/Sibe.hs @@ -86,6 +86,7 @@ module Sibe , batchSize :: Int , chart :: [(Int, Double, Double)] , momentum :: Double + , debug :: Bool } deriving (Show) emptyNetwork = randomNetwork 0 (0, 0) 0 [] (0, (id, id)) @@ -99,6 +100,7 @@ module Sibe , batchSize = 0 , chart = [] , momentum = 0 + , debug = False } saveNetwork :: Network -> String -> IO () diff --git a/src/Sibe/Word2Vec.hs b/src/Sibe/Word2Vec.hs index 53b80d2..cbd1d53 100644 --- a/src/Sibe/Word2Vec.hs +++ b/src/Sibe/Word2Vec.hs @@ -13,6 +13,7 @@ module Sibe.Word2Vec import qualified Data.Vector.Storable as V import Data.Default.Class import Data.Function (on) + import Control.Monad data W2VMethod = SkipGram | CBOW data Word2Vec = Word2Vec { docs :: [String] @@ -30,11 +31,12 @@ module Sibe.Word2Vec , network = randomNetwork 0 (-1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, one)) } - putStr "vocabulary size: " - print v + when (debug s) $ do + putStr "vocabulary size: " + print v - putStr "trainingData length: " - print . length $ trainingData + putStr "trainingData length: " + print . length $ trainingData -- biases are not used in skipgram/cbow newses <- run (sgd . ignoreBiases) s @@ -73,7 +75,7 @@ module Sibe.Word2Vec | i == length vocvec - 1 = before | otherwise = before ++ after vectorized = map (\w -> snd . fromJust $ find ((== w) . fst) vocvec) ns - new = foldl1 (+) vectorized + new = cmap (max 1) $ foldl1 (+) vectorized in case method w2v of SkipGram -> zip (repeat v) vectorized