fix(word2vec): simple example of word2vec

This commit is contained in:
Mahdi Dibaiee 2016-09-16 14:03:15 +04:30
parent d4ac90bbd5
commit 0d43814448
4 changed files with 12 additions and 7 deletions

View File

@ -27,13 +27,14 @@ module Main where
"the dwarf poisons the king", "the dwarf poisons the queen"] "the dwarf poisons the king", "the dwarf poisons the queen"]
let session = def { learningRate = 0.1 let session = def { learningRate = 0.1
, batchSize = 16 , batchSize = 1
, epochs = 100 , epochs = 100
, debug = True
} :: Session } :: Session
w2v = def { docs = ds w2v = def { docs = ds
, dimensions = 50 , dimensions = 50
, method = SkipGram , method = SkipGram
, window = 3 , window = 2
} :: Word2Vec } :: Word2Vec

BIN
sgd.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 31 KiB

After

Width:  |  Height:  |  Size: 12 KiB

View File

@ -86,6 +86,7 @@ module Sibe
, batchSize :: Int , batchSize :: Int
, chart :: [(Int, Double, Double)] , chart :: [(Int, Double, Double)]
, momentum :: Double , momentum :: Double
, debug :: Bool
} deriving (Show) } deriving (Show)
emptyNetwork = randomNetwork 0 (0, 0) 0 [] (0, (id, id)) emptyNetwork = randomNetwork 0 (0, 0) 0 [] (0, (id, id))
@ -99,6 +100,7 @@ module Sibe
, batchSize = 0 , batchSize = 0
, chart = [] , chart = []
, momentum = 0 , momentum = 0
, debug = False
} }
saveNetwork :: Network -> String -> IO () saveNetwork :: Network -> String -> IO ()

View File

@ -13,6 +13,7 @@ module Sibe.Word2Vec
import qualified Data.Vector.Storable as V import qualified Data.Vector.Storable as V
import Data.Default.Class import Data.Default.Class
import Data.Function (on) import Data.Function (on)
import Control.Monad
data W2VMethod = SkipGram | CBOW data W2VMethod = SkipGram | CBOW
data Word2Vec = Word2Vec { docs :: [String] data Word2Vec = Word2Vec { docs :: [String]
@ -30,6 +31,7 @@ module Sibe.Word2Vec
, network = randomNetwork 0 (-1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, one)) , network = randomNetwork 0 (-1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, one))
} }
when (debug s) $ do
putStr "vocabulary size: " putStr "vocabulary size: "
print v print v
@ -73,7 +75,7 @@ module Sibe.Word2Vec
| i == length vocvec - 1 = before | i == length vocvec - 1 = before
| otherwise = before ++ after | otherwise = before ++ after
vectorized = map (\w -> snd . fromJust $ find ((== w) . fst) vocvec) ns vectorized = map (\w -> snd . fromJust $ find ((== w) . fst) vocvec) ns
new = foldl1 (+) vectorized new = cmap (max 1) $ foldl1 (+) vectorized
in in
case method w2v of case method w2v of
SkipGram -> zip (repeat v) vectorized SkipGram -> zip (repeat v) vectorized