feat(w2v): draw text charts for words

This commit is contained in:
Mahdi Dibaiee 2016-10-01 12:24:36 +03:30
parent d9d24f69a6
commit 85971bc84d
6 changed files with 42 additions and 19 deletions

View File

@ -17,9 +17,6 @@ module Main where
import System.Random.Shuffle import System.Random.Shuffle
import Data.Default.Class import Data.Default.Class
import qualified Graphics.Rendering.Chart.Easy as Chart
import Graphics.Rendering.Chart.Backend.Cairo
main = do main = do
-- random seed, you might comment this line to get real random results -- random seed, you might comment this line to get real random results
setStdGen (mkStdGen 100) setStdGen (mkStdGen 100)

View File

@ -42,29 +42,26 @@ module Main where
let ds = ["the king loves the queen", "the queen loves the king", let ds = ["the king loves the queen", "the queen loves the king",
"the dwarf hates the king", "the queen hates the dwarf", "the dwarf hates the king", "the queen hates the dwarf",
"the dwarf poisons the king", "the dwarf poisons the queen"] "the dwarf poisons the king", "the dwarf poisons the queen",
"the man loves the woman", "the woman loves the man",
"the thief hates the man", "the woman hates the thief",
"the thief robs the man", "the thief robs the woman"]
let session = def { learningRate = 1e-1 let session = def { learningRate = 5e-1
, batchSize = 1 , batchSize = 1
, epochs = 200 , epochs = 1000
, debug = True , debug = True
} :: Session } :: Session
w2v = def { docs = ds w2v = def { docs = ds
, dimensions = 25 , dimensions = 25
, method = SkipGram , method = SkipGram
, window = 2 , window = 2
, w2vDrawChart = True
, w2vChartName = "w2v.png"
} :: Word2Vec } :: Word2Vec
(computed, vocvec) <- word2vec w2v session (computed, vocvec) <- word2vec w2v session
mapM_ (\(w, v) -> do
putStr $ w ++ ": "
let similarities = map (similarity v . snd) computed
let sorted = sortBy (compare `on` similarity v . snd) computed
print . take 2 . drop 1 . reverse $ map fst sorted
) computed
return () return ()
cleanText :: String -> String cleanText :: String -> String

View File

@ -32,6 +32,7 @@ library
, data-default-class , data-default-class
, Chart , Chart
, Chart-cairo , Chart-cairo
, lens
default-language: Haskell2010 default-language: Haskell2010
executable example-xor executable example-xor

View File

@ -26,7 +26,6 @@ module Sibe
sigmoid', sigmoid',
softmax, softmax,
softmax', softmax',
sampledSoftmax,
relu, relu,
relu', relu',
crossEntropy, crossEntropy,
@ -183,10 +182,10 @@ module Sibe
sig x = 1 / max (1 + exp (-x)) 1e-10 sig x = 1 / max (1 + exp (-x)) 1e-10
-- used for negative sampling -- used for negative sampling
sampledSoftmax :: Int -> Vector Double -> Vector Double {-sampledSoftmax :: Vector Double -> Vector Double-}
sampledSoftmax n x = cmap (\a -> exp a / s) x {-sampledSoftmax x = cmap (\a -> exp a / s) x-}
where {-where-}
s = V.sum . exp $ V.take n x {-s = V.sum . exp $ x-}
relu :: Vector Double -> Vector Double relu :: Vector Double -> Vector Double
relu = cmap (max 0.1) relu = cmap (max 0.1)

View File

@ -16,19 +16,28 @@ module Sibe.Word2Vec
import Control.Monad import Control.Monad
import System.Random import System.Random
import Graphics.Rendering.Chart as Chart
import Graphics.Rendering.Chart.Backend.Cairo
import Control.Lens
data W2VMethod = SkipGram | CBOW data W2VMethod = SkipGram | CBOW
data Word2Vec = Word2Vec { docs :: [String] data Word2Vec = Word2Vec { docs :: [String]
, window :: Int , window :: Int
, dimensions :: Int , dimensions :: Int
, method :: W2VMethod , method :: W2VMethod
, w2vChartName :: String
, w2vDrawChart :: Bool
} }
instance Default Word2Vec where instance Default Word2Vec where
def = Word2Vec { docs = [] def = Word2Vec { docs = []
, window = 2 , window = 2
, w2vChartName = "w2v.png"
, w2vDrawChart = False
} }
word2vec w2v session = do word2vec w2v session = do
seed <- newStdGen seed <- newStdGen
let s = session { training = trainingData let s = session { training = trainingData
, network = randomNetwork 0 (-1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, crossEntropy')) , network = randomNetwork 0 (-1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, crossEntropy'))
} }
@ -49,6 +58,26 @@ module Sibe.Word2Vec
-- run words through the hidden layer alone to get the word vector -- run words through the hidden layer alone to get the word vector
let computedVocVec = map (\(w, v) -> (w, runLayer' v hidden)) vocvec let computedVocVec = map (\(w, v) -> (w, runLayer' v hidden)) vocvec
when (w2vDrawChart w2v) $ do
let mat = fromColumns . map snd $ computedVocVec
(u, s, v) = svd mat
cut = subMatrix (0, 0) (2, cols mat)
diagS = diagRect 0 (V.take 2 s) (rows mat) (cols mat)
twoDimensions = cut $ u <> diagS <> tr v
textData = zipWith (\s l -> (V.head l, V.last l, s)) (map fst computedVocVec) (toColumns twoDimensions)
chart = toRenderable layout
where
textP = plot_annotation_values .~ textData
$ def
layout = layout_title .~ "word vectors"
$ layout_plots .~ [toPlot textP]
$ def
renderableToFile def (w2vChartName w2v) chart
return ()
return (computedVocVec, vocvec) return (computedVocVec, vocvec)
where where
-- clean documents -- clean documents

BIN
w2v.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB