perf(word2vec): better word2vec
This commit is contained in:
@ -1,7 +1,8 @@
|
||||
module Sibe.Utils
|
||||
(similarity,
|
||||
ordNub,
|
||||
onehot
|
||||
( similarity
|
||||
, ordNub
|
||||
, onehot
|
||||
, average
|
||||
) where
|
||||
import qualified Data.Vector.Storable as V
|
||||
import qualified Data.Set as Set
|
||||
@ -22,3 +23,6 @@ module Sibe.Utils
|
||||
go _ [] = []
|
||||
go s (x:xs) = if x `Set.member` s then go s xs
|
||||
else x : go (Set.insert x s) xs
|
||||
|
||||
-- | Scale a vector so that its elements sum to 1.
--
-- Despite the name, this is not an arithmetic mean: every element is
-- divided by the sum of all elements, not by the element count.
--
-- A vector whose elements sum to zero (including the empty vector) is
-- returned unchanged, because dividing by a zero total would fill the
-- result with NaN or Infinity.
average :: Vector Double -> Vector Double
average v
  | total == 0 = v
  | otherwise  = cmap (/ total) v
  where
    -- Computed once and shared by the guard and the division.
    total = V.sum v
|
||||
|
@ -75,12 +75,15 @@ module Sibe.Word2Vec
|
||||
| i == length vocvec - 1 = before
|
||||
| otherwise = before ++ after
|
||||
vectorized = map (\w -> snd . fromJust $ find ((== w) . fst) vocvec) ns
|
||||
new = cmap (max 1) $ foldl1 (+) vectorized
|
||||
new = foldl1 (+) vectorized
|
||||
in
|
||||
case method w2v of
|
||||
SkipGram -> zip (repeat v) vectorized
|
||||
CBOW -> zip vectorized (repeat v)
|
||||
_ -> error "unsupported word2vec method"
|
||||
if length wds <= 1
|
||||
then []
|
||||
else
|
||||
case method w2v of
|
||||
SkipGram -> [(v, average new)]
|
||||
CBOW -> [(average new, v)]
|
||||
_ -> error "unsupported word2vec method"
|
||||
|
||||
cleanText :: String -> String
|
||||
cleanText string =
|
||||
|
Reference in New Issue
Block a user