diff --git a/examples/notmnist.hs b/examples/notmnist.hs index 80d384c..865ed55 100644 --- a/examples/notmnist.hs +++ b/examples/notmnist.hs @@ -17,9 +17,6 @@ module Main where import System.Random.Shuffle import Data.Default.Class - import qualified Graphics.Rendering.Chart.Easy as Chart - import Graphics.Rendering.Chart.Backend.Cairo - main = do -- random seed, you might comment this line to get real random results setStdGen (mkStdGen 100) diff --git a/examples/word2vec.hs b/examples/word2vec.hs index 1669c41..96875cc 100644 --- a/examples/word2vec.hs +++ b/examples/word2vec.hs @@ -42,28 +42,25 @@ module Main where let ds = ["the king loves the queen", "the queen loves the king", "the dwarf hates the king", "the queen hates the dwarf", - "the dwarf poisons the king", "the dwarf poisons the queen"] + "the dwarf poisons the king", "the dwarf poisons the queen", + "the man loves the woman", "the woman loves the man", + "the thief hates the man", "the woman hates the thief", + "the thief robs the man", "the thief robs the woman"] - let session = def { learningRate = 1e-1 + let session = def { learningRate = 5e-1 , batchSize = 1 - , epochs = 200 + , epochs = 1000 , debug = True } :: Session w2v = def { docs = ds , dimensions = 25 , method = SkipGram , window = 2 + , w2vDrawChart = True + , w2vChartName = "w2v.png" } :: Word2Vec - (computed, vocvec) <- word2vec w2v session - - mapM_ (\(w, v) -> do - putStr $ w ++ ": " - let similarities = map (similarity v . snd) computed - let sorted = sortBy (compare `on` similarity v . snd) computed - print . take 2 . drop 1 . reverse $ map fst sorted - ) computed return () diff --git a/sibe.cabal b/sibe.cabal index 5bde54d..0f84197 100644 --- a/sibe.cabal +++ b/sibe.cabal @@ -32,6 +32,7 @@ library , data-default-class , Chart , Chart-cairo + , lens default-language: Haskell2010 executable example-xor diff --git a/src/Sibe.hs b/src/Sibe.hs index 7460aab..f4ccede 100644 --- a/src/Sibe.hs +++ b/src/Sibe.hs @@ -26,7 +26,6 @@ module Sibe sigmoid', softmax, softmax', - sampledSoftmax, relu, relu', crossEntropy, @@ -183,10 +182,10 @@ module Sibe sig x = 1 / max (1 + exp (-x)) 1e-10 -- used for negative sampling - sampledSoftmax :: Int -> Vector Double -> Vector Double - sampledSoftmax n x = cmap (\a -> exp a / s) x - where - s = V.sum . exp $ V.take n x + {-sampledSoftmax :: Vector Double -> Vector Double-} + {-sampledSoftmax x = cmap (\a -> exp a / s) x-} + {-where-} + {-s = V.sum . exp $ x-} relu :: Vector Double -> Vector Double relu = cmap (max 0.1) diff --git a/src/Sibe/Word2Vec.hs b/src/Sibe/Word2Vec.hs index b18f8ae..7c52669 100644 --- a/src/Sibe/Word2Vec.hs +++ b/src/Sibe/Word2Vec.hs @@ -16,19 +16,28 @@ module Sibe.Word2Vec import Control.Monad import System.Random + import Graphics.Rendering.Chart as Chart + import Graphics.Rendering.Chart.Backend.Cairo + import Control.Lens + data W2VMethod = SkipGram | CBOW data Word2Vec = Word2Vec { docs :: [String] , window :: Int , dimensions :: Int , method :: W2VMethod + , w2vChartName :: String + , w2vDrawChart :: Bool } instance Default Word2Vec where def = Word2Vec { docs = [] , window = 2 + , w2vChartName = "w2v.png" + , w2vDrawChart = False } word2vec w2v session = do seed <- newStdGen + let s = session { training = trainingData , network = randomNetwork 0 (-1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, crossEntropy')) } @@ -49,6 +58,26 @@ module Sibe.Word2Vec -- run words through the hidden layer alone to get the word vector let computedVocVec = map (\(w, v) -> (w, runLayer' v hidden)) vocvec + when (w2vDrawChart w2v) $ do + let mat = fromColumns . map snd $ computedVocVec + (u, s, v) = svd mat + cut = subMatrix (0, 0) (2, cols mat) + diagS = diagRect 0 (V.take 2 s) (rows mat) (cols mat) + + twoDimensions = cut $ u <> diagS <> tr v + textData = zipWith (\s l -> (V.head l, V.last l, s)) (map fst computedVocVec) (toColumns twoDimensions) + + chart = toRenderable layout + where + textP = plot_annotation_values .~ textData + $ def + layout = layout_title .~ "word vectors" + $ layout_plots .~ [toPlot textP] + $ def + + renderableToFile def (w2vChartName w2v) chart + return () + return (computedVocVec, vocvec) where -- clean documents diff --git a/w2v.png b/w2v.png new file mode 100644 index 0000000..0183921 Binary files /dev/null and b/w2v.png differ