fix(train): fix an error in computing layer's error
feat(examples): add an example (xor)
This commit is contained in:
BIN
src/Lib.hi
BIN
src/Lib.hi
Binary file not shown.
29
src/Lib.hs
29
src/Lib.hs
@ -12,12 +12,13 @@ module Lib
|
||||
randomLayer,
|
||||
train,
|
||||
session,
|
||||
shuffle,
|
||||
) where
|
||||
import Numeric.LinearAlgebra
|
||||
import Control.Monad.Random
|
||||
import System.Random
|
||||
import Debug.Trace
|
||||
import Data.List (foldl')
|
||||
import Data.List (foldl', sortBy)
|
||||
|
||||
type LearningRate = Double
|
||||
type Input = Vector Double
|
||||
@ -49,7 +50,7 @@ module Lib
|
||||
logistic x = 1 / (1 + exp (-x))
|
||||
|
||||
logistic' :: Double -> Double
|
||||
logistic' x = logistic x / max 1e-8 (1 - logistic x)
|
||||
logistic' x = logistic x / max 1e-10 (1 - logistic x)
|
||||
|
||||
train :: Input
|
||||
-> Network
|
||||
@ -63,7 +64,7 @@ module Lib
|
||||
let y = runLayer input l
|
||||
o = cmap logistic y
|
||||
delta = o - target
|
||||
de = delta * cmap logistic' y
|
||||
de = delta * cmap logistic' o
|
||||
|
||||
biases' = biases - scale alpha de
|
||||
weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly
|
||||
@ -78,7 +79,7 @@ module Lib
|
||||
o = cmap logistic y
|
||||
(n', delta) = run o n
|
||||
|
||||
de = delta * cmap logistic' y
|
||||
de = delta * cmap logistic' o
|
||||
|
||||
biases' = biases - scale alpha de
|
||||
weights' = weights - scale alpha (input `outer` de)
|
||||
@ -88,14 +89,22 @@ module Lib
|
||||
-- pass = weights #> de
|
||||
in (layer :- n', pass)
|
||||
|
||||
session :: [Input] -> Network -> [Output] -> Double -> Int -> Network
|
||||
session inputs network labels alpha epochs =
|
||||
let n = length inputs - 1
|
||||
in foldl' iter network [0..n * epochs]
|
||||
session :: [Input] -> Network -> [Output] -> Double -> (Int, Int) -> Network
|
||||
session inputs network labels alpha (iterations, epochs) =
|
||||
let n = length inputs
|
||||
indexes = shuffle n (map (`mod` n) [0..n * epochs])
|
||||
in foldl' iter network indexes
|
||||
where
|
||||
iter net i =
|
||||
let n = length inputs - 1
|
||||
let n = length inputs
|
||||
index = i `mod` n
|
||||
input = inputs !! index
|
||||
label = labels !! index
|
||||
in train input net label alpha
|
||||
in foldl' (\net _ -> train input net label alpha) net [0..iterations]
|
||||
|
||||
shuffle :: Seed -> [a] -> [a]
|
||||
shuffle seed list =
|
||||
let ords = map ord $ take (length list) (randomRs (0, 1) (mkStdGen seed) :: [Int])
|
||||
in map snd $ sortBy (\x y -> fst x) (zip ords list)
|
||||
where ord x | x == 0 = LT
|
||||
| x == 1 = GT
|
||||
|
Reference in New Issue
Block a user