feat(rnn): recurrent neural networks, experimental
WIP: runs out of memory quickly
examples/recurrent.hs (new file, 65 lines)
@@ -0,0 +1,65 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Main where

import Numeric.LinearAlgebra
import Numeric.Sibe.Recurrent
import Numeric.Sibe.Utils
import System.IO
import Data.Default.Class
import Data.List (genericLength)
import qualified Data.ByteString.Lazy.Char8 as BSL
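
-- Train a word-level recurrent network on sentences from the Reddit
-- comment dump and inspect its prediction for the first sentence.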
main :: IO ()
main = do
  texts <- lines <$> readFile "examples/reddit.csv"
  let (vocabulary, indexes) = processData texts
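  -- (processData presumably maps each sentence to a list of indices
  -- into `vocabulary`)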

  let settings = def { wordD = length vocabulary }
      r = randomRecurrent 0 settings

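  -- next-word prediction: the input is a sentence minus its last
  -- token, the target is the same sentence shifted one token ahead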
  let x0 = reverse . drop 1 . reverse $ indexes !! 0
      y0 = drop 1 $ indexes !! 0

  print x0
  print y0

  let xs = map (reverse . drop 1 . reverse) indexes
      ys = map (drop 1) indexes

  let tov = fromList . map fromIntegral
  let vys = map tov ys
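
  -- train on just the first sentence for a single iteration; the full
  -- dataset currently exhausts memory (see the commit message)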
  let newr = sgd r (take 1 xs) (take 1 vys) 0.005 1
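
  -- `show`ing the whole network yields a very large String, so it is
  -- written out in fixed-size chunks rather than with one writeFile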
  saveRecurrent "recurrent.trained" (show newr) 512
  --writeFile "recurrent.trained" (show newr)

  let newpredicted = predict newr x0
  print y0
  print newpredicted
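
  -- loss (presumably exported by Numeric.Sibe.Recurrent) scores the
  -- prediction against the target sequence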
  print $ loss (tov y0) (tov newpredicted)

  {-let (dU, dV, dW) = backprop r x0 (fromList $ map fromIntegral y0)-}
  {-print $ seq dU "dU"-}
  {-print $ seq dV "dV"-}
  {-print $ seq dW "dW"-}

  --print dW
  putStrLn "done"
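
-- Write a (possibly huge) String to a file in fixed-size chunks,
-- flushing after every chunk so partial output survives a crash and
-- progress is echoed to stdout.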
saveRecurrent :: FilePath -> String -> Int -> IO ()
saveRecurrent path str chunkSize = do
  handle <- openFile path WriteMode -- overwrite any earlier model instead of appending
  hSetBuffering handle NoBuffering
  loop handle str
  hClose handle
  where
    loop _ [] = return ()
    loop handle s = do
      hPutStr handle $ take chunkSize s
      hFlush handle
      putStr $ take chunkSize s
      loop handle $ drop chunkSize s

examples/reddit.csv (new file, 20087 lines)
File diff suppressed because it is too large

@@ -31,7 +31,7 @@ module Main where
   setStdGen (mkStdGen 100)
   sws <- lines <$> readFile "examples/stopwords"
 
-  -- real data, takes a lot of time to train
+  -- real data, currently faces a memory problem
   {-ds <- do-}
   {-files <- filter ((/= "xml") . take 1 . reverse) <$> listDirectory "examples/blogs-corpus/"-}
   {-contents <- mapM (rf . ("examples/blogs-corpus/" ++)) files-}