feat(w2v): draw text charts for words
This commit is contained in:
		| @@ -17,9 +17,6 @@ module Main where | ||||
|   import System.Random.Shuffle | ||||
|   import Data.Default.Class | ||||
|  | ||||
|   import qualified Graphics.Rendering.Chart.Easy as Chart | ||||
|   import Graphics.Rendering.Chart.Backend.Cairo | ||||
|  | ||||
|   main = do | ||||
|     -- random seed, you might comment this line to get real random results | ||||
|     setStdGen (mkStdGen 100) | ||||
|   | ||||
| @@ -42,28 +42,25 @@ module Main where | ||||
|  | ||||
|     let ds = ["the king loves the queen", "the queen loves the king", | ||||
|               "the dwarf hates the king", "the queen hates the dwarf", | ||||
|               "the dwarf poisons the king", "the dwarf poisons the queen"] | ||||
|               "the dwarf poisons the king", "the dwarf poisons the queen", | ||||
|               "the man loves the woman", "the woman loves the man", | ||||
|               "the thief hates the man", "the woman hates the thief", | ||||
|               "the thief robs the man", "the thief robs the woman"] | ||||
|  | ||||
|     let session = def { learningRate = 1e-1 | ||||
|     let session = def { learningRate = 5e-1 | ||||
|                       , batchSize = 1 | ||||
|                       , epochs = 200 | ||||
|                       , epochs = 1000 | ||||
|                       , debug = True | ||||
|                       } :: Session | ||||
|         w2v = def { docs = ds | ||||
|                   , dimensions = 25 | ||||
|                   , method = SkipGram | ||||
|                   , window = 2 | ||||
|                   , w2vDrawChart = True | ||||
|                   , w2vChartName = "w2v.png" | ||||
|                   } :: Word2Vec | ||||
|  | ||||
|  | ||||
|     (computed, vocvec) <- word2vec w2v session | ||||
|      | ||||
|     mapM_ (\(w, v) -> do | ||||
|                     putStr $ w ++ ": " | ||||
|                     let similarities = map (similarity v . snd) computed | ||||
|                     let sorted = sortBy (compare `on` similarity v . snd) computed | ||||
|                     print . take 2 . drop 1 . reverse $ map fst sorted | ||||
|           ) computed | ||||
|  | ||||
|     return () | ||||
|  | ||||
|   | ||||
| @@ -32,6 +32,7 @@ library | ||||
|                      , data-default-class | ||||
|                      , Chart | ||||
|                      , Chart-cairo | ||||
|                      , lens | ||||
|   default-language:    Haskell2010 | ||||
|  | ||||
| executable example-xor | ||||
|   | ||||
| @@ -26,7 +26,6 @@ module Sibe | ||||
|      sigmoid', | ||||
|      softmax, | ||||
|      softmax', | ||||
|      sampledSoftmax, | ||||
|      relu, | ||||
|      relu', | ||||
|      crossEntropy, | ||||
| @@ -183,10 +182,10 @@ module Sibe | ||||
|           sig x = 1 / max (1 + exp (-x)) 1e-10 | ||||
|  | ||||
|       -- used for negative sampling | ||||
|       sampledSoftmax :: Int -> Vector Double -> Vector Double | ||||
|       sampledSoftmax n x = cmap (\a -> exp a / s) x | ||||
|         where | ||||
|           s = V.sum . exp $ V.take n x | ||||
|       {-sampledSoftmax :: Vector Double -> Vector Double-} | ||||
|       {-sampledSoftmax x = cmap (\a -> exp a / s) x-} | ||||
|         {-where-} | ||||
|           {-s = V.sum . exp $ x-} | ||||
|  | ||||
|       relu :: Vector Double -> Vector Double | ||||
|       relu = cmap (max 0.1) | ||||
|   | ||||
| @@ -16,19 +16,28 @@ module Sibe.Word2Vec | ||||
|     import Control.Monad | ||||
|     import System.Random | ||||
|  | ||||
|     import Graphics.Rendering.Chart as Chart | ||||
|     import Graphics.Rendering.Chart.Backend.Cairo | ||||
|     import Control.Lens | ||||
|  | ||||
|     data W2VMethod = SkipGram | CBOW | ||||
|     data Word2Vec = Word2Vec { docs :: [String] | ||||
|                              , window :: Int | ||||
|                              , dimensions :: Int | ||||
|                              , method :: W2VMethod | ||||
|                              , w2vChartName :: String | ||||
|                              , w2vDrawChart :: Bool | ||||
|                              } | ||||
|     instance Default Word2Vec where | ||||
|       def = Word2Vec { docs = [] | ||||
|                      , window = 2 | ||||
|                      , w2vChartName = "w2v.png" | ||||
|                      , w2vDrawChart = False | ||||
|                      } | ||||
|  | ||||
|     word2vec w2v session = do | ||||
|       seed <- newStdGen | ||||
|  | ||||
|       let s = session { training = trainingData | ||||
|                       , network = randomNetwork 0 (-1, 1) v [(dimensions w2v, (id, one))] (v, (softmax, crossEntropy')) | ||||
|                       } | ||||
| @@ -49,6 +58,26 @@ module Sibe.Word2Vec | ||||
|       -- run words through the hidden layer alone to get the word vector | ||||
|       let computedVocVec = map (\(w, v) -> (w, runLayer' v hidden)) vocvec | ||||
|  | ||||
|       when (w2vDrawChart w2v) $ do | ||||
|         let mat = fromColumns . map snd $ computedVocVec | ||||
|             (u, s, v) = svd mat | ||||
|             cut = subMatrix (0, 0) (2, cols mat) | ||||
|             diagS = diagRect 0 (V.take 2 s) (rows mat) (cols mat) | ||||
|  | ||||
|             twoDimensions = cut $ u <> diagS <> tr v | ||||
|             textData = zipWith (\s l -> (V.head l, V.last l, s)) (map fst computedVocVec) (toColumns twoDimensions) | ||||
|  | ||||
|             chart = toRenderable layout | ||||
|               where | ||||
|                 textP  = plot_annotation_values .~ textData | ||||
|                       $ def | ||||
|                 layout = layout_title .~ "word vectors" | ||||
|                       $ layout_plots .~ [toPlot textP] | ||||
|                       $ def | ||||
|                      | ||||
|         renderableToFile def (w2vChartName w2v) chart | ||||
|         return () | ||||
|  | ||||
|       return (computedVocVec, vocvec) | ||||
|       where | ||||
|         -- clean documents | ||||
|   | ||||
		Reference in New Issue
	
	Block a user