fix(train): fix an error in computing layer's error

feat(examples): add an example (xor)
This commit is contained in:
Mahdi Dibaiee
2016-07-18 00:30:17 +04:30
parent 763faef434
commit 4397f5203a
10 changed files with 66 additions and 20 deletions

Binary file not shown.

View File

@ -12,12 +12,13 @@ module Lib
randomLayer,
train,
session,
shuffle,
) where
import Numeric.LinearAlgebra
import Control.Monad.Random
import System.Random
import Debug.Trace
import Data.List (foldl')
import Data.List (foldl', sortBy)
type LearningRate = Double
type Input = Vector Double
@ -49,7 +50,7 @@ module Lib
logistic x = 1 / (1 + exp (-x))
logistic' :: Double -> Double
logistic' x = logistic x / max 1e-8 (1 - logistic x)
logistic' x = logistic x / max 1e-10 (1 - logistic x)
train :: Input
-> Network
@ -63,7 +64,7 @@ module Lib
let y = runLayer input l
o = cmap logistic y
delta = o - target
de = delta * cmap logistic' y
de = delta * cmap logistic' o
biases' = biases - scale alpha de
weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly
@ -78,7 +79,7 @@ module Lib
o = cmap logistic y
(n', delta) = run o n
de = delta * cmap logistic' y
de = delta * cmap logistic' o
biases' = biases - scale alpha de
weights' = weights - scale alpha (input `outer` de)
@ -88,14 +89,22 @@ module Lib
-- pass = weights #> de
in (layer :- n', pass)
session :: [Input] -> Network -> [Output] -> Double -> Int -> Network
session inputs network labels alpha epochs =
let n = length inputs - 1
in foldl' iter network [0..n * epochs]
session :: [Input] -> Network -> [Output] -> Double -> (Int, Int) -> Network
session inputs network labels alpha (iterations, epochs) =
let n = length inputs
indexes = shuffle n (map (`mod` n) [0..n * epochs])
in foldl' iter network indexes
where
iter net i =
let n = length inputs - 1
let n = length inputs
index = i `mod` n
input = inputs !! index
label = labels !! index
in train input net label alpha
in foldl' (\net _ -> train input net label alpha) net [0..iterations]
shuffle :: Seed -> [a] -> [a]
shuffle seed list =
let ords = map ord $ take (length list) (randomRs (0, 1) (mkStdGen seed) :: [Int])
in map snd $ sortBy (\x y -> fst x) (zip ords list)
where ord x | x == 0 = LT
| x == 1 = GT

BIN
src/Lib.o

Binary file not shown.