fix(word2vec): simple example of word2vec
This commit is contained in:
parent
d4ac90bbd5
commit
0d43814448
@ -27,13 +27,14 @@ module Main where
|
|||||||
"the dwarf poisons the king", "the dwarf poisons the queen"]
|
"the dwarf poisons the king", "the dwarf poisons the queen"]
|
||||||
|
|
||||||
let session = def { learningRate = 0.1
|
let session = def { learningRate = 0.1
|
||||||
, batchSize = 16
|
, batchSize = 1
|
||||||
, epochs = 100
|
, epochs = 100
|
||||||
|
, debug = True
|
||||||
} :: Session
|
} :: Session
|
||||||
w2v = def { docs = ds
|
w2v = def { docs = ds
|
||||||
, dimensions = 50
|
, dimensions = 50
|
||||||
, method = SkipGram
|
, method = SkipGram
|
||||||
, window = 3
|
, window = 2
|
||||||
} :: Word2Vec
|
} :: Word2Vec
|
||||||
|
|
||||||
|
|
||||||
|
BIN
sgd.png
BIN
sgd.png
Binary file not shown.
Before Width: | Height: | Size: 31 KiB After Width: | Height: | Size: 12 KiB |
@ -86,6 +86,7 @@ module Sibe
|
|||||||
, batchSize :: Int
|
, batchSize :: Int
|
||||||
, chart :: [(Int, Double, Double)]
|
, chart :: [(Int, Double, Double)]
|
||||||
, momentum :: Double
|
, momentum :: Double
|
||||||
|
, debug :: Bool
|
||||||
} deriving (Show)
|
} deriving (Show)
|
||||||
|
|
||||||
emptyNetwork = randomNetwork 0 (0, 0) 0 [] (0, (id, id))
|
emptyNetwork = randomNetwork 0 (0, 0) 0 [] (0, (id, id))
|
||||||
@ -99,6 +100,7 @@ module Sibe
|
|||||||
, batchSize = 0
|
, batchSize = 0
|
||||||
, chart = []
|
, chart = []
|
||||||
, momentum = 0
|
, momentum = 0
|
||||||
|
, debug = False
|
||||||
}
|
}
|
||||||
|
|
||||||
saveNetwork :: Network -> String -> IO ()
|
saveNetwork :: Network -> String -> IO ()
|
||||||
|
@ -13,6 +13,7 @@ module Sibe.Word2Vec
|
|||||||
import qualified Data.Vector.Storable as V
|
import qualified Data.Vector.Storable as V
|
||||||
import Data.Default.Class
|
import Data.Default.Class
|
||||||
import Data.Function (on)
|
import Data.Function (on)
|
||||||
|
import Control.Monad
|
||||||
|
|
||||||
data W2VMethod = SkipGram | CBOW
|
data W2VMethod = SkipGram | CBOW
|
||||||
data Word2Vec = Word2Vec { docs :: [String]
|
data Word2Vec = Word2Vec { docs :: [String]
|
||||||
@ -30,11 +31,12 @@ module Sibe.Word2Vec
|
|||||||
, network = randomNetwork 0 (-1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, one))
|
, network = randomNetwork 0 (-1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, one))
|
||||||
}
|
}
|
||||||
|
|
||||||
putStr "vocabulary size: "
|
when (debug s) $ do
|
||||||
print v
|
putStr "vocabulary size: "
|
||||||
|
print v
|
||||||
|
|
||||||
putStr "trainingData length: "
|
putStr "trainingData length: "
|
||||||
print . length $ trainingData
|
print . length $ trainingData
|
||||||
|
|
||||||
-- biases are not used in skipgram/cbow
|
-- biases are not used in skipgram/cbow
|
||||||
newses <- run (sgd . ignoreBiases) s
|
newses <- run (sgd . ignoreBiases) s
|
||||||
@ -73,7 +75,7 @@ module Sibe.Word2Vec
|
|||||||
| i == length vocvec - 1 = before
|
| i == length vocvec - 1 = before
|
||||||
| otherwise = before ++ after
|
| otherwise = before ++ after
|
||||||
vectorized = map (\w -> snd . fromJust $ find ((== w) . fst) vocvec) ns
|
vectorized = map (\w -> snd . fromJust $ find ((== w) . fst) vocvec) ns
|
||||||
new = foldl1 (+) vectorized
|
new = cmap (max 1) $ foldl1 (+) vectorized
|
||||||
in
|
in
|
||||||
case method w2v of
|
case method w2v of
|
||||||
SkipGram -> zip (repeat v) vectorized
|
SkipGram -> zip (repeat v) vectorized
|
||||||
|
Loading…
Reference in New Issue
Block a user