feat(pca): implement PCA and visualize data using it
This commit is contained in:
parent
7f90afba7f
commit
6da6e4fd09
20
README.md
20
README.md
@ -87,11 +87,27 @@ the thief robs the man
|
||||
the thief robs the woman
|
||||
```
|
||||
|
||||
The computed vectors are transformed to two dimensions using SVD:
|
||||
The computed vectors are transformed to two dimensions using PCA:
|
||||
|
||||
`king` and `queen` have a relation with `man` and `woman`, `love` and `hate` are close to each other,
|
||||
and `dwarf` and `thief` have a relation with `poisons` and `robs`, also, `dwarf` is close to `queen` and `king` while
|
||||
`thief` is closer to `man` and `woman`. `the` doesn't relate to anything.
|
||||
![word2vec results](https://raw.githubusercontent.com/mdibaiee/sibe/master/w2v.png)
|
||||
|
||||
This is a very small dataset and I have to test it on larger datasets.
|
||||
_You can reproduce this result using these parameters:_
|
||||
```haskell
|
||||
let session = def { learningRate = 0.1
|
||||
, batchSize = 1
|
||||
, epochs = 10000
|
||||
, debug = True
|
||||
} :: Session
|
||||
w2v = def { docs = ds
|
||||
, dimensions = 30
|
||||
, method = SkipGram
|
||||
, window = 2
|
||||
, w2vDrawChart = True
|
||||
, w2vChartName = "w2v.png"
|
||||
} :: Word2Vec
|
||||
```
|
||||
|
||||
This is a very small development dataset and I have to test it on larger datasets.
|
||||
|
@ -32,32 +32,32 @@ module Main where
|
||||
sws <- lines <$> readFile "examples/stopwords"
|
||||
|
||||
-- real data, takes a lot of time to train
|
||||
ds <- do
|
||||
files <- filter ((/= "xml") . take 1 . reverse) <$> listDirectory "examples/blogs-corpus/"
|
||||
contents <- mapM (rf . ("examples/blogs-corpus/" ++)) files
|
||||
{-ds <- do-}
|
||||
{-files <- filter ((/= "xml") . take 1 . reverse) <$> listDirectory "examples/blogs-corpus/"-}
|
||||
{-contents <- mapM (rf . ("examples/blogs-corpus/" ++)) files-}
|
||||
|
||||
let texts = map (unwords . splitOn " ") contents
|
||||
let tags = ["<Blog>", "</Blog>", "<date>", "</date>", "<post>", "</post>", " "]
|
||||
return $ map cleanText $ removeWords (sws ++ tags) texts
|
||||
{-let texts = map (unwords . splitOn " ") contents-}
|
||||
{-let tags = ["<Blog>", "</Blog>", "<date>", "</date>", "<post>", "</post>", " "]-}
|
||||
{-return $ map cleanText $ removeWords (sws ++ tags) texts-}
|
||||
|
||||
{-let ds = ["the king loves the queen", "the queen loves the king",-}
|
||||
{-"the dwarf hates the king", "the queen hates the dwarf",-}
|
||||
{-"the dwarf poisons the king", "the dwarf poisons the queen",-}
|
||||
{-"the man loves the woman", "the woman loves the man",-}
|
||||
{-"the thief hates the man", "the woman hates the thief",-}
|
||||
{-"the thief robs the man", "the thief robs the woman"]-}
|
||||
let ds = ["the king loves the queen", "the queen loves the king",
|
||||
"the dwarf hates the king", "the queen hates the dwarf",
|
||||
"the dwarf poisons the king", "the dwarf poisons the queen",
|
||||
"the man loves the woman", "the woman loves the man",
|
||||
"the thief hates the man", "the woman hates the thief",
|
||||
"the thief robs the man", "the thief robs the woman"]
|
||||
|
||||
let session = def { learningRate = 5e-1
|
||||
let session = def { learningRate = 0.1
|
||||
, batchSize = 1
|
||||
, epochs = 200
|
||||
, epochs = 10000
|
||||
, debug = True
|
||||
} :: Session
|
||||
w2v = def { docs = ds
|
||||
, dimensions = 300
|
||||
, dimensions = 30
|
||||
, method = SkipGram
|
||||
, window = 2
|
||||
, w2vDrawChart = True
|
||||
, w2vChartName = "w2v-big-data.png"
|
||||
, w2vChartName = "w2v.png"
|
||||
} :: Word2Vec
|
||||
|
||||
(computed, vocvec) <- word2vec w2v session
|
||||
|
@ -16,7 +16,6 @@ cabal-version: >=1.10
|
||||
library
|
||||
hs-source-dirs: src
|
||||
exposed-modules: Sibe, Sibe.NaiveBayes, Sibe.NLP, Sibe.Word2Vec, Sibe.Utils
|
||||
ghc-options: -threaded -rtsopts -with-rtsopts=-N
|
||||
build-depends: base >= 4.7 && < 5
|
||||
, hmatrix
|
||||
, random
|
||||
|
@ -3,6 +3,7 @@ module Sibe.Utils
|
||||
, ordNub
|
||||
, onehot
|
||||
, average
|
||||
, pca
|
||||
) where
|
||||
import qualified Data.Vector.Storable as V
|
||||
import qualified Data.Set as Set
|
||||
@ -26,3 +27,15 @@ module Sibe.Utils
|
||||
|
||||
average :: Vector Double -> Vector Double
|
||||
average v = cmap (/ (V.sum v)) v
|
||||
|
||||
pca :: Matrix Double -> Int -> Matrix Double
|
||||
pca m d =
|
||||
let rs = toRows m
|
||||
means = map (\v -> V.sum v / fromIntegral (V.length v)) rs
|
||||
meanReduced = map (\(a, b) -> V.map (+ (negate b)) a) $ zip rs means
|
||||
mat = fromRows meanReduced
|
||||
|
||||
(u, s, v) = svd mat
|
||||
diagS = diagRect 0 s (rows mat) (cols mat)
|
||||
|
||||
in u ?? (All, Take d) <> diagS ?? (Take d, Take d)
|
||||
|
@ -1,3 +1,5 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Sibe.Word2Vec
|
||||
( word2vec
|
||||
, Word2Vec (..)
|
||||
@ -9,7 +11,7 @@ module Sibe.Word2Vec
|
||||
import Data.Char
|
||||
import Data.Maybe
|
||||
import Data.List
|
||||
import Numeric.LinearAlgebra hiding (find)
|
||||
import Numeric.LinearAlgebra as H hiding (find)
|
||||
import qualified Data.Vector.Storable as V
|
||||
import Data.Default.Class
|
||||
import Data.Function (on)
|
||||
@ -59,21 +61,17 @@ module Sibe.Word2Vec
|
||||
let computedVocVec = map (\(w, v) -> (w, runLayer' v hidden)) vocvec
|
||||
|
||||
when (w2vDrawChart w2v) $ do
|
||||
let mat = fromColumns . map snd $ computedVocVec
|
||||
(u, s, v) = svd mat
|
||||
cut = subMatrix (0, 0) (2, cols mat)
|
||||
diagS = diagRect 0 (V.take 2 s) (rows mat) (cols mat)
|
||||
|
||||
twoDimensions = cut $ u <> diagS <> tr v
|
||||
textData = zipWith (\s l -> (V.head l, V.last l, s)) (map fst computedVocVec) (toColumns twoDimensions)
|
||||
let m = fromRows . map snd $ computedVocVec
|
||||
twoDimensions = pca m 2
|
||||
textData = zipWith (\s l -> (V.head l, V.last l, s)) (map fst computedVocVec) (toRows twoDimensions)
|
||||
|
||||
chart = toRenderable layout
|
||||
where
|
||||
textP = plot_annotation_values .~ textData
|
||||
textP = plot_annotation_values .~ textData
|
||||
$ def
|
||||
layout = layout_title .~ "word vectors"
|
||||
$ layout_plots .~ [toPlot textP]
|
||||
$ def
|
||||
$ layout_plots .~ [toPlot textP]
|
||||
$ def
|
||||
|
||||
renderableToFile def (w2vChartName w2v) chart
|
||||
return ()
|
||||
|
Loading…
Reference in New Issue
Block a user