feat(w2v): draw text charts for words
This commit is contained in:
parent
d9d24f69a6
commit
85971bc84d
@ -17,9 +17,6 @@ module Main where
|
|||||||
import System.Random.Shuffle
|
import System.Random.Shuffle
|
||||||
import Data.Default.Class
|
import Data.Default.Class
|
||||||
|
|
||||||
import qualified Graphics.Rendering.Chart.Easy as Chart
|
|
||||||
import Graphics.Rendering.Chart.Backend.Cairo
|
|
||||||
|
|
||||||
main = do
|
main = do
|
||||||
-- random seed, you might comment this line to get real random results
|
-- random seed, you might comment this line to get real random results
|
||||||
setStdGen (mkStdGen 100)
|
setStdGen (mkStdGen 100)
|
||||||
|
@ -42,29 +42,26 @@ module Main where
|
|||||||
|
|
||||||
let ds = ["the king loves the queen", "the queen loves the king",
|
let ds = ["the king loves the queen", "the queen loves the king",
|
||||||
"the dwarf hates the king", "the queen hates the dwarf",
|
"the dwarf hates the king", "the queen hates the dwarf",
|
||||||
"the dwarf poisons the king", "the dwarf poisons the queen"]
|
"the dwarf poisons the king", "the dwarf poisons the queen",
|
||||||
|
"the man loves the woman", "the woman loves the man",
|
||||||
|
"the thief hates the man", "the woman hates the thief",
|
||||||
|
"the thief robs the man", "the thief robs the woman"]
|
||||||
|
|
||||||
let session = def { learningRate = 1e-1
|
let session = def { learningRate = 5e-1
|
||||||
, batchSize = 1
|
, batchSize = 1
|
||||||
, epochs = 200
|
, epochs = 1000
|
||||||
, debug = True
|
, debug = True
|
||||||
} :: Session
|
} :: Session
|
||||||
w2v = def { docs = ds
|
w2v = def { docs = ds
|
||||||
, dimensions = 25
|
, dimensions = 25
|
||||||
, method = SkipGram
|
, method = SkipGram
|
||||||
, window = 2
|
, window = 2
|
||||||
|
, w2vDrawChart = True
|
||||||
|
, w2vChartName = "w2v.png"
|
||||||
} :: Word2Vec
|
} :: Word2Vec
|
||||||
|
|
||||||
|
|
||||||
(computed, vocvec) <- word2vec w2v session
|
(computed, vocvec) <- word2vec w2v session
|
||||||
|
|
||||||
mapM_ (\(w, v) -> do
|
|
||||||
putStr $ w ++ ": "
|
|
||||||
let similarities = map (similarity v . snd) computed
|
|
||||||
let sorted = sortBy (compare `on` similarity v . snd) computed
|
|
||||||
print . take 2 . drop 1 . reverse $ map fst sorted
|
|
||||||
) computed
|
|
||||||
|
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
cleanText :: String -> String
|
cleanText :: String -> String
|
||||||
|
@ -32,6 +32,7 @@ library
|
|||||||
, data-default-class
|
, data-default-class
|
||||||
, Chart
|
, Chart
|
||||||
, Chart-cairo
|
, Chart-cairo
|
||||||
|
, lens
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
|
|
||||||
executable example-xor
|
executable example-xor
|
||||||
|
@ -26,7 +26,6 @@ module Sibe
|
|||||||
sigmoid',
|
sigmoid',
|
||||||
softmax,
|
softmax,
|
||||||
softmax',
|
softmax',
|
||||||
sampledSoftmax,
|
|
||||||
relu,
|
relu,
|
||||||
relu',
|
relu',
|
||||||
crossEntropy,
|
crossEntropy,
|
||||||
@ -183,10 +182,10 @@ module Sibe
|
|||||||
sig x = 1 / max (1 + exp (-x)) 1e-10
|
sig x = 1 / max (1 + exp (-x)) 1e-10
|
||||||
|
|
||||||
-- used for negative sampling
|
-- used for negative sampling
|
||||||
sampledSoftmax :: Int -> Vector Double -> Vector Double
|
{-sampledSoftmax :: Vector Double -> Vector Double-}
|
||||||
sampledSoftmax n x = cmap (\a -> exp a / s) x
|
{-sampledSoftmax x = cmap (\a -> exp a / s) x-}
|
||||||
where
|
{-where-}
|
||||||
s = V.sum . exp $ V.take n x
|
{-s = V.sum . exp $ x-}
|
||||||
|
|
||||||
relu :: Vector Double -> Vector Double
|
relu :: Vector Double -> Vector Double
|
||||||
relu = cmap (max 0.1)
|
relu = cmap (max 0.1)
|
||||||
|
@ -16,19 +16,28 @@ module Sibe.Word2Vec
|
|||||||
import Control.Monad
|
import Control.Monad
|
||||||
import System.Random
|
import System.Random
|
||||||
|
|
||||||
|
import Graphics.Rendering.Chart as Chart
|
||||||
|
import Graphics.Rendering.Chart.Backend.Cairo
|
||||||
|
import Control.Lens
|
||||||
|
|
||||||
data W2VMethod = SkipGram | CBOW
|
data W2VMethod = SkipGram | CBOW
|
||||||
data Word2Vec = Word2Vec { docs :: [String]
|
data Word2Vec = Word2Vec { docs :: [String]
|
||||||
, window :: Int
|
, window :: Int
|
||||||
, dimensions :: Int
|
, dimensions :: Int
|
||||||
, method :: W2VMethod
|
, method :: W2VMethod
|
||||||
|
, w2vChartName :: String
|
||||||
|
, w2vDrawChart :: Bool
|
||||||
}
|
}
|
||||||
instance Default Word2Vec where
|
instance Default Word2Vec where
|
||||||
def = Word2Vec { docs = []
|
def = Word2Vec { docs = []
|
||||||
, window = 2
|
, window = 2
|
||||||
|
, w2vChartName = "w2v.png"
|
||||||
|
, w2vDrawChart = False
|
||||||
}
|
}
|
||||||
|
|
||||||
word2vec w2v session = do
|
word2vec w2v session = do
|
||||||
seed <- newStdGen
|
seed <- newStdGen
|
||||||
|
|
||||||
let s = session { training = trainingData
|
let s = session { training = trainingData
|
||||||
, network = randomNetwork 0 (-1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, crossEntropy'))
|
, network = randomNetwork 0 (-1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, crossEntropy'))
|
||||||
}
|
}
|
||||||
@ -49,6 +58,26 @@ module Sibe.Word2Vec
|
|||||||
-- run words through the hidden layer alone to get the word vector
|
-- run words through the hidden layer alone to get the word vector
|
||||||
let computedVocVec = map (\(w, v) -> (w, runLayer' v hidden)) vocvec
|
let computedVocVec = map (\(w, v) -> (w, runLayer' v hidden)) vocvec
|
||||||
|
|
||||||
|
when (w2vDrawChart w2v) $ do
|
||||||
|
let mat = fromColumns . map snd $ computedVocVec
|
||||||
|
(u, s, v) = svd mat
|
||||||
|
cut = subMatrix (0, 0) (2, cols mat)
|
||||||
|
diagS = diagRect 0 (V.take 2 s) (rows mat) (cols mat)
|
||||||
|
|
||||||
|
twoDimensions = cut $ u <> diagS <> tr v
|
||||||
|
textData = zipWith (\s l -> (V.head l, V.last l, s)) (map fst computedVocVec) (toColumns twoDimensions)
|
||||||
|
|
||||||
|
chart = toRenderable layout
|
||||||
|
where
|
||||||
|
textP = plot_annotation_values .~ textData
|
||||||
|
$ def
|
||||||
|
layout = layout_title .~ "word vectors"
|
||||||
|
$ layout_plots .~ [toPlot textP]
|
||||||
|
$ def
|
||||||
|
|
||||||
|
renderableToFile def (w2vChartName w2v) chart
|
||||||
|
return ()
|
||||||
|
|
||||||
return (computedVocVec, vocvec)
|
return (computedVocVec, vocvec)
|
||||||
where
|
where
|
||||||
-- clean documents
|
-- clean documents
|
||||||
|
Loading…
Reference in New Issue
Block a user