fix(train): fix an error in computing layer's error
feat(examples): add an example (xor)
This commit is contained in:
parent
763faef434
commit
4397f5203a
BIN
app/Main.hi
BIN
app/Main.hi
Binary file not shown.
17
app/Main.hs
17
app/Main.hs
@ -9,19 +9,16 @@ import Debug.Trace
|
|||||||
-- 2x3 + 1x3
|
-- 2x3 + 1x3
|
||||||
-- 3x1 + 1x1
|
-- 3x1 + 1x1
|
||||||
|
|
||||||
main :: IO ()
|
-- main :: IO [()]
|
||||||
main =
|
main =
|
||||||
let learning_rate = 0.01
|
let learning_rate = 0.5
|
||||||
ih = randomLayer 0 (2, 10)
|
ih = randomLayer 0 (2, 8)
|
||||||
ho = randomLayer 1 (10, 1)
|
ho = randomLayer 1 (8, 1)
|
||||||
network = ih :- O ho
|
network = ih :- O ho
|
||||||
inputs = [vector [0, 1], vector [1, 1], vector [1, 0], vector [0, 0]]
|
|
||||||
|
|
||||||
-- result = forward input network
|
inputs = [vector [0, 1], vector [1, 0], vector [1, 1], vector [0, 0]]
|
||||||
|
labels = [vector [1], vector [1], vector [0], vector [0]]
|
||||||
|
|
||||||
labels = [vector [1], vector [0], vector [1], vector [0]]
|
updated_network = session inputs network labels learning_rate (2, 1000)
|
||||||
|
|
||||||
updated_network = session inputs network labels learning_rate 100
|
|
||||||
-- updated_network = train (head inputs) network (head labels) 0.5
|
|
||||||
results = map (`forward` updated_network) inputs
|
results = map (`forward` updated_network) inputs
|
||||||
in print results
|
in print results
|
||||||
|
BIN
app/Main.o
BIN
app/Main.o
Binary file not shown.
BIN
examples/xor
Executable file
BIN
examples/xor
Executable file
Binary file not shown.
32
examples/xor.hs
Normal file
32
examples/xor.hs
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
module Main where
|
||||||
|
import Lib
|
||||||
|
import Numeric.LinearAlgebra
|
||||||
|
import Data.List
|
||||||
|
import Debug.Trace
|
||||||
|
|
||||||
|
-- 1x2
|
||||||
|
-- 2x3 + 1x3
|
||||||
|
-- 3x1 + 1x1
|
||||||
|
|
||||||
|
-- main :: IO [()]
|
||||||
|
main =
|
||||||
|
let learning_rate = 0.5
|
||||||
|
(iterations, epochs) = (2, 1000)
|
||||||
|
ih = randomLayer 0 (2, 8)
|
||||||
|
ho = randomLayer 1 (8, 1)
|
||||||
|
network = ih :- O ho
|
||||||
|
|
||||||
|
inputs = [vector [0, 1], vector [1, 0], vector [1, 1], vector [0, 0]]
|
||||||
|
labels = [vector [1], vector [1], vector [0], vector [0]]
|
||||||
|
|
||||||
|
updated_network = session inputs network labels learning_rate (iterations, epochs)
|
||||||
|
results = map (`forward` updated_network) inputs
|
||||||
|
rounded = map (map round) $ map toList results
|
||||||
|
in sequence [putStrLn "",
|
||||||
|
putStrLn $ "inputs: " ++ show inputs,
|
||||||
|
putStrLn $ "labels: " ++ show labels,
|
||||||
|
putStrLn $ "learning rate: " ++ show learning_rate,
|
||||||
|
putStrLn $ "iterations/epochs: " ++ show (iterations, epochs),
|
||||||
|
putStrLn $ "...",
|
||||||
|
putStrLn $ "rounded result: " ++ show rounded,
|
||||||
|
putStrLn $ "actual result: " ++ show results]
|
@ -27,6 +27,14 @@ executable sibe-exe
|
|||||||
, sibe
|
, sibe
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
executable example-xor
|
||||||
|
hs-source-dirs: examples
|
||||||
|
main-is: xor.hs
|
||||||
|
ghc-options: -threaded -rtsopts -with-rtsopts=-N
|
||||||
|
build-depends: base
|
||||||
|
, sibe
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
test-suite sibe-test
|
test-suite sibe-test
|
||||||
type: exitcode-stdio-1.0
|
type: exitcode-stdio-1.0
|
||||||
hs-source-dirs: test
|
hs-source-dirs: test
|
||||||
|
BIN
src/Lib.hi
BIN
src/Lib.hi
Binary file not shown.
29
src/Lib.hs
29
src/Lib.hs
@ -12,12 +12,13 @@ module Lib
|
|||||||
randomLayer,
|
randomLayer,
|
||||||
train,
|
train,
|
||||||
session,
|
session,
|
||||||
|
shuffle,
|
||||||
) where
|
) where
|
||||||
import Numeric.LinearAlgebra
|
import Numeric.LinearAlgebra
|
||||||
import Control.Monad.Random
|
import Control.Monad.Random
|
||||||
import System.Random
|
import System.Random
|
||||||
import Debug.Trace
|
import Debug.Trace
|
||||||
import Data.List (foldl')
|
import Data.List (foldl', sortBy)
|
||||||
|
|
||||||
type LearningRate = Double
|
type LearningRate = Double
|
||||||
type Input = Vector Double
|
type Input = Vector Double
|
||||||
@ -49,7 +50,7 @@ module Lib
|
|||||||
logistic x = 1 / (1 + exp (-x))
|
logistic x = 1 / (1 + exp (-x))
|
||||||
|
|
||||||
logistic' :: Double -> Double
|
logistic' :: Double -> Double
|
||||||
logistic' x = logistic x / max 1e-8 (1 - logistic x)
|
logistic' x = logistic x / max 1e-10 (1 - logistic x)
|
||||||
|
|
||||||
train :: Input
|
train :: Input
|
||||||
-> Network
|
-> Network
|
||||||
@ -63,7 +64,7 @@ module Lib
|
|||||||
let y = runLayer input l
|
let y = runLayer input l
|
||||||
o = cmap logistic y
|
o = cmap logistic y
|
||||||
delta = o - target
|
delta = o - target
|
||||||
de = delta * cmap logistic' y
|
de = delta * cmap logistic' o
|
||||||
|
|
||||||
biases' = biases - scale alpha de
|
biases' = biases - scale alpha de
|
||||||
weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly
|
weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly
|
||||||
@ -78,7 +79,7 @@ module Lib
|
|||||||
o = cmap logistic y
|
o = cmap logistic y
|
||||||
(n', delta) = run o n
|
(n', delta) = run o n
|
||||||
|
|
||||||
de = delta * cmap logistic' y
|
de = delta * cmap logistic' o
|
||||||
|
|
||||||
biases' = biases - scale alpha de
|
biases' = biases - scale alpha de
|
||||||
weights' = weights - scale alpha (input `outer` de)
|
weights' = weights - scale alpha (input `outer` de)
|
||||||
@ -88,14 +89,22 @@ module Lib
|
|||||||
-- pass = weights #> de
|
-- pass = weights #> de
|
||||||
in (layer :- n', pass)
|
in (layer :- n', pass)
|
||||||
|
|
||||||
session :: [Input] -> Network -> [Output] -> Double -> Int -> Network
|
session :: [Input] -> Network -> [Output] -> Double -> (Int, Int) -> Network
|
||||||
session inputs network labels alpha epochs =
|
session inputs network labels alpha (iterations, epochs) =
|
||||||
let n = length inputs - 1
|
let n = length inputs
|
||||||
in foldl' iter network [0..n * epochs]
|
indexes = shuffle n (map (`mod` n) [0..n * epochs])
|
||||||
|
in foldl' iter network indexes
|
||||||
where
|
where
|
||||||
iter net i =
|
iter net i =
|
||||||
let n = length inputs - 1
|
let n = length inputs
|
||||||
index = i `mod` n
|
index = i `mod` n
|
||||||
input = inputs !! index
|
input = inputs !! index
|
||||||
label = labels !! index
|
label = labels !! index
|
||||||
in train input net label alpha
|
in foldl' (\net _ -> train input net label alpha) net [0..iterations]
|
||||||
|
|
||||||
|
shuffle :: Seed -> [a] -> [a]
|
||||||
|
shuffle seed list =
|
||||||
|
let ords = map ord $ take (length list) (randomRs (0, 1) (mkStdGen seed) :: [Int])
|
||||||
|
in map snd $ sortBy (\x y -> fst x) (zip ords list)
|
||||||
|
where ord x | x == 0 = LT
|
||||||
|
| x == 1 = GT
|
||||||
|
Loading…
Reference in New Issue
Block a user