feat(w2v): draw text charts for words
This commit is contained in:
parent
d9d24f69a6
commit
85971bc84d
@ -17,9 +17,6 @@ module Main where
|
||||
import System.Random.Shuffle
|
||||
import Data.Default.Class
|
||||
|
||||
import qualified Graphics.Rendering.Chart.Easy as Chart
|
||||
import Graphics.Rendering.Chart.Backend.Cairo
|
||||
|
||||
main = do
|
||||
-- random seed, you might comment this line to get real random results
|
||||
setStdGen (mkStdGen 100)
|
||||
|
@ -42,28 +42,25 @@ module Main where
|
||||
|
||||
let ds = ["the king loves the queen", "the queen loves the king",
|
||||
"the dwarf hates the king", "the queen hates the dwarf",
|
||||
"the dwarf poisons the king", "the dwarf poisons the queen"]
|
||||
"the dwarf poisons the king", "the dwarf poisons the queen",
|
||||
"the man loves the woman", "the woman loves the man",
|
||||
"the thief hates the man", "the woman hates the thief",
|
||||
"the thief robs the man", "the thief robs the woman"]
|
||||
|
||||
let session = def { learningRate = 1e-1
|
||||
let session = def { learningRate = 5e-1
|
||||
, batchSize = 1
|
||||
, epochs = 200
|
||||
, epochs = 1000
|
||||
, debug = True
|
||||
} :: Session
|
||||
w2v = def { docs = ds
|
||||
, dimensions = 25
|
||||
, method = SkipGram
|
||||
, window = 2
|
||||
, w2vDrawChart = True
|
||||
, w2vChartName = "w2v.png"
|
||||
} :: Word2Vec
|
||||
|
||||
|
||||
(computed, vocvec) <- word2vec w2v session
|
||||
|
||||
mapM_ (\(w, v) -> do
|
||||
putStr $ w ++ ": "
|
||||
let similarities = map (similarity v . snd) computed
|
||||
let sorted = sortBy (compare `on` similarity v . snd) computed
|
||||
print . take 2 . drop 1 . reverse $ map fst sorted
|
||||
) computed
|
||||
|
||||
return ()
|
||||
|
||||
|
@ -32,6 +32,7 @@ library
|
||||
, data-default-class
|
||||
, Chart
|
||||
, Chart-cairo
|
||||
, lens
|
||||
default-language: Haskell2010
|
||||
|
||||
executable example-xor
|
||||
|
@ -26,7 +26,6 @@ module Sibe
|
||||
sigmoid',
|
||||
softmax,
|
||||
softmax',
|
||||
sampledSoftmax,
|
||||
relu,
|
||||
relu',
|
||||
crossEntropy,
|
||||
@ -183,10 +182,10 @@ module Sibe
|
||||
sig x = 1 / max (1 + exp (-x)) 1e-10
|
||||
|
||||
-- used for negative sampling
|
||||
sampledSoftmax :: Int -> Vector Double -> Vector Double
|
||||
sampledSoftmax n x = cmap (\a -> exp a / s) x
|
||||
where
|
||||
s = V.sum . exp $ V.take n x
|
||||
{-sampledSoftmax :: Vector Double -> Vector Double-}
|
||||
{-sampledSoftmax x = cmap (\a -> exp a / s) x-}
|
||||
{-where-}
|
||||
{-s = V.sum . exp $ x-}
|
||||
|
||||
relu :: Vector Double -> Vector Double
|
||||
relu = cmap (max 0.1)
|
||||
|
@ -16,19 +16,28 @@ module Sibe.Word2Vec
|
||||
import Control.Monad
|
||||
import System.Random
|
||||
|
||||
import Graphics.Rendering.Chart as Chart
|
||||
import Graphics.Rendering.Chart.Backend.Cairo
|
||||
import Control.Lens
|
||||
|
||||
data W2VMethod = SkipGram | CBOW
|
||||
data Word2Vec = Word2Vec { docs :: [String]
|
||||
, window :: Int
|
||||
, dimensions :: Int
|
||||
, method :: W2VMethod
|
||||
, w2vChartName :: String
|
||||
, w2vDrawChart :: Bool
|
||||
}
|
||||
instance Default Word2Vec where
|
||||
def = Word2Vec { docs = []
|
||||
, window = 2
|
||||
, w2vChartName = "w2v.png"
|
||||
, w2vDrawChart = False
|
||||
}
|
||||
|
||||
word2vec w2v session = do
|
||||
seed <- newStdGen
|
||||
|
||||
let s = session { training = trainingData
|
||||
, network = randomNetwork 0 (-1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, crossEntropy'))
|
||||
}
|
||||
@ -49,6 +58,26 @@ module Sibe.Word2Vec
|
||||
-- run words through the hidden layer alone to get the word vector
|
||||
let computedVocVec = map (\(w, v) -> (w, runLayer' v hidden)) vocvec
|
||||
|
||||
when (w2vDrawChart w2v) $ do
|
||||
let mat = fromColumns . map snd $ computedVocVec
|
||||
(u, s, v) = svd mat
|
||||
cut = subMatrix (0, 0) (2, cols mat)
|
||||
diagS = diagRect 0 (V.take 2 s) (rows mat) (cols mat)
|
||||
|
||||
twoDimensions = cut $ u <> diagS <> tr v
|
||||
textData = zipWith (\s l -> (V.head l, V.last l, s)) (map fst computedVocVec) (toColumns twoDimensions)
|
||||
|
||||
chart = toRenderable layout
|
||||
where
|
||||
textP = plot_annotation_values .~ textData
|
||||
$ def
|
||||
layout = layout_title .~ "word vectors"
|
||||
$ layout_plots .~ [toPlot textP]
|
||||
$ def
|
||||
|
||||
renderableToFile def (w2vChartName w2v) chart
|
||||
return ()
|
||||
|
||||
return (computedVocVec, vocvec)
|
||||
where
|
||||
-- clean documents
|
||||
|
Loading…
Reference in New Issue
Block a user