feat(rnn): recurrent neural networks, experimental
WIP: runs out of memory quickly
@@ -112,8 +112,8 @@ module Numeric.Sibe.NaiveBayes
          -- in realToFrac (tct * pg + 1) / realToFrac (cvoc + voc) -- uncomment to enable ngrams
          in realToFrac (tct + 1) / realToFrac (cvoc + voc)

    argmax :: (Ord a) => [a] -> Int
    argmax x = fst $ maximumBy (\(_, a) (_, b) -> a `compare` b) (zip [0..] x)
    {-argmax :: (Ord a) => [a] -> Int-}
    {-argmax x = fst $ maximumBy (\(_, a) (_, b) -> a `compare` b) (zip [0..] x)-}

    mean :: [Double] -> Double
    mean x = sum x / genericLength x

src/Numeric/Sibe/Recurrent.hs (new file, 144 lines)
							@@ -0,0 +1,144 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}

module Numeric.Sibe.Recurrent
  ( Recurrent (..)
  , randomRecurrent
  , processData
  , forward
  , predict
  , loss
  , backprop
  , sgd
  ) where
    import Numeric.LinearAlgebra
    import System.Random
    import System.Random.Shuffle
    import Debug.Trace
    import qualified Data.List as L
    import Data.Maybe
    import System.IO
    import Control.DeepSeq
    import Control.Monad
    import qualified Data.Vector.Storable as V
    import Data.Default.Class

    import qualified Graphics.Rendering.Chart.Easy as Chart
    import Graphics.Rendering.Chart.Backend.Cairo
    import Numeric.Sibe.Utils

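    -- processData wraps each sentence in <start>/<end> markers, tokenizes it,
    -- builds an indexed vocabulary of the unique tokens, and returns that
    -- vocabulary together with every sentence encoded as vocabulary indices.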
    processData :: [String] -> ([(Int, String)], [[Int]])
    processData x =
      let setokens = map (\a -> " <start> " ++ a ++ " <end> ") x
          tokenized = map tokenize setokens
          vocabulary = zip [0..] (unique . concat $ tokenized)
          indexes = map (\a -> fst . fromJust $ L.find ((==a) . snd) vocabulary)
      in (vocabulary, map indexes tokenized)

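    -- U maps a word index to the hidden layer (wordD x hiddenD), V maps the
    -- hidden state to output scores (hiddenD x wordD), and W carries the hidden
    -- state between time steps (hiddenD x hiddenD).  The Default instance only
    -- sets bpttThreshold and hiddenD; wordD and the weights are left to the
    -- caller (the weights via randomRecurrent).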
    data Recurrent = Recurrent { bpttThreshold :: Int
                               , wordD :: Int
                               , hiddenD :: Int
                               , u :: Matrix Double
                               , v :: Matrix Double
                               , w :: Matrix Double
                               } deriving (Show, Read)
    instance Default Recurrent where
      def = Recurrent { bpttThreshold = 3
                      , hiddenD = 100
                      }

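    -- randomRecurrent fills U, V and W with uniform samples from
    -- (-1/sqrt n, 1/sqrt n), where n is the input dimension of each matrix.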
    randomRecurrent :: Seed -> Recurrent -> Recurrent
    randomRecurrent seed r = r { u = randomMatrix (wordD r, hiddenD r) (bounds $ wordD r)
                               , v = randomMatrix (hiddenD r, wordD r) (bounds $ hiddenD r)
                               , w = randomMatrix (hiddenD r, hiddenD r) (bounds $ hiddenD r)
                               }
      where
        randomMatrix (wr, wc) (l, u) = uniformSample (seed + wr + wc) wr $ replicate wc (l, u)
        bounds x = (negate . sqrt $ 1 / fromIntegral x, sqrt $ 1 / fromIntegral x)

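    -- forward runs the recurrence h_t = tanh (U[x_t] + W h_(t-1)) with
    -- o_t = softmax (h_t V).  The returned matrices are in time order; the
    -- first row of the hidden-state matrix is the initial zero state.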
    forward :: Recurrent -> [Int] -> (Matrix Double, Matrix Double)
    forward r input =
      let (h, o) = helper [vector (replicate (hiddenD r) 0)] [] input
      in (fromRows (reverse h), fromRows (reverse o)) -- accumulators are built newest-first
      where
        helper hs os [] = (hs, os)
        helper (h:hs) os (i:is) =
          let k = w r #> h
              newh = V.map tanh $ (u r ! i) + k
              o = softmax $ newh <# v r
          in helper (newh:h:hs) (o:os) is

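    -- predict picks, for every position of the input, the index of the
    -- highest-scoring word in the corresponding output row.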
    predict :: Recurrent -> [Int] -> [Int]
    predict r i =
      let (_, o) = forward r i
      in map argmax (toLists o)

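    -- backprop is truncated backpropagation through time: for every position
    -- it accumulates the V gradient from the output error, then walks back at
    -- most bpttThreshold steps to accumulate the U and W gradients.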
    backprop :: Recurrent -> [Int] -> Vector Double -> (Matrix Double, Matrix Double, Matrix Double)
    backprop r input y =
      let dU = zero (u r)
          dV = zero (v r)
          dW = zero (w r)
      in bp dU dV dW (zip [0..] input)
      where
        (hs, os) = forward r input
        -- delta
        dO = fromColumns $ zipWith (\i o -> if i `V.elem` y then o - 1 else o) [0..] (toColumns os)

        bp dU dV dW [] = (dU, dV, dW)
        bp dU dV dW ((i,x):xs) =
          let ndV = dV + (hs ! i) `outer` (dO ! i)
              dT = (v r) #> (dO ! i) -- * (1 - (hs ! i)^2)
              threshold = bpttThreshold r
              (ndU, ndW) = tt dU dW dT [max 0 (i-threshold)..i]
           in bp ndU ndV ndW xs
           where
             tt dU dW dT [] = (dU, dW)
             tt dU dW dT (c:cs) =
               let ndW = dW + (dT `outer` (hs ! (max 0 $ c - 1)))
                   zdT = vector $ replicate (V.length dT) 0
                   mdT = fromRows $ replicate (max 0 $ c - 1) zdT ++ [dT] ++ replicate (min (rows dU - 1) $ rows dU - c) zdT
                   ndU = dU + mdT
                   ndT = (w r) #> dT
               in tt ndU ndW ndT cs

        zero m = ((rows m)><(cols m)) $ repeat 0

    {-gradientCheck :: Recurrent -> [Int] -> Vector Double -> Double-}

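    -- sgdStep performs one gradient-descent update from a single example;
    -- sgd sweeps the whole dataset once per epoch, updating after every example.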
    sgdStep :: Recurrent -> [Int] -> Vector Double -> Double -> Recurrent
    sgdStep r input y learningRate =
      let (dU, dV, dW) = backprop r input y
      in r { u = (u r) - scale learningRate dU
           , v = (v r) - scale learningRate dV
           , w = (w r) - scale learningRate dW
           }

    sgd :: Recurrent -> [[Int]] -> [Vector Double] -> Double -> Int -> Recurrent
    sgd r input y learningRate epochs = run [0..epochs] r
      where
        run [] r = r
        run (i:is) r = run is $ train (zip input y) r

        train [] r = r
        train ((x, y):xs) r = train xs $ sgdStep r x y learningRate

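    -- softmax normalises a score vector into probabilities.  softmax' is a
    -- sigmoid-style derivative helper and is not referenced elsewhere in this
    -- module.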
    softmax :: Vector Double -> Vector Double
    softmax x = cmap (\a -> exp a / s) x
      where
        s = V.sum $ exp x

    softmax' :: Vector Double -> Vector Double
    softmax' = cmap (\a -> sig a * (1 - sig a))
      where
        sig x = 1 / max (1 + exp (-x)) 1e-10

    -- cross-entropy
    loss :: Vector Double -> Vector Double -> Double
    loss ys os = (-1 / fromIntegral (V.length os)) * V.sum (V.zipWith f os ys)
      where
        f a y = y * log (max 1e-10 a)

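    {- Rough usage sketch (not part of this commit, untested): `sentences` is a
       placeholder for the training corpus, and targets are passed as vectors
       of vocabulary indices, which is how `backprop` reads its `y` argument.

         train :: [String] -> Recurrent
         train sentences =
           let (vocab, idxs) = processData sentences
               r  = randomRecurrent 7 (def { wordD = length vocab })
               xs = map init idxs                               -- inputs: all but the last token
               ys = map (vector . map fromIntegral . tail) idxs -- targets: the next tokens
           in sgd r xs ys 0.01 10
    -}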
@@ -4,10 +4,19 @@ module Numeric.Sibe.Utils
  , onehot
  , average
  , pca
  , tokenize
  , frequency
  , unique
  , argmax
  , shape
  ) where
    import qualified Data.Vector.Storable as V
    import qualified Data.Set as Set
    import Numeric.LinearAlgebra
    import Data.List.Split
    import Data.Char (isSpace, isNumber, toLower)
    import Control.Arrow ((&&&))
    import Data.List

    similarity :: Vector Double -> Vector Double -> Double
    similarity a b = (V.sum $ a * b) / (magnitude a * magnitude b)
@@ -24,6 +33,8 @@ module Numeric.Sibe.Utils
        go _ [] = []
        go s (x:xs) = if x `Set.member` s then go s xs
                                          else x : go (Set.insert x s) xs
    unique :: (Ord a) => [a] -> [a]
    unique = ordNub

    average :: Vector Double -> Vector Double
    average v = cmap (/ (V.sum v)) v
@@ -39,3 +50,27 @@ module Numeric.Sibe.Utils
          diagS = diagRect 0 s (rows mat) (cols mat)

      in u ?? (All, Take d) <> diagS ?? (Take d, Take d)

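    -- tokenize surrounds punctuation with spaces and then splits on
    -- whitespace, so punctuation marks become separate tokens.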
    tokenize :: String -> [String]
    tokenize str =
      let spaced = spacify str
          ws = words spaced
      in ws
      where
        puncs = ['!', '"', '#', '$', '%', '(', ')', '.', '?', ',', '\'', '/', '-']
        replace needle replacement =
          concatMap (\c -> if c == needle then replacement else [c])
        spacify = foldl (\acc c -> if c `elem` puncs then acc ++ [' ', c, ' '] else acc ++ [c]) ""

    frequency :: (Ord a) => [a] -> [(a, Int)]
    frequency = map (head &&& length) . group . sort

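    -- argmax returns the position of the largest element; the accumulator
    -- carries (best value so far, its index, current position).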
    argmax :: (Foldable t, Num a, Fractional a, Ord a) => t a -> Int
    argmax v = best $ foldl mx ((-1/0), 0, 0) v
      where
        mx (a, ai, i) b
          | b > a = (b, i, i + 1)
          | otherwise = (a, ai, i + 1)
        best (_, ai, _) = ai

    shape :: Matrix a -> (Int, Int)
    shape x = (rows x, cols x)
 