2016-07-17 12:23:13 +00:00
|
|
|
module Main where
|
|
|
|
|
|
|
|
import Lib
|
|
|
|
import Numeric.LinearAlgebra
|
|
|
|
import Data.List
|
|
|
|
import Debug.Trace
|
|
|
|
|
|
|
|
-- input: 1x2 (each sample is a 2-element vector)
|
|
|
|
-- input -> hidden layer: 2x8 + 1x8 (presumably weights + bias; built with size (2, 8))
|
|
|
|
-- hidden -> output layer: 8x1 + 1x1 (presumably weights + bias; built with size (8, 1))
|
|
|
|
|
2016-07-17 20:00:17 +00:00
|
|
|
-- | Train a network made of two layers, sized (2, 8) and (8, 1), on the XOR
-- truth table. Then print the network's output for each training input.
--
-- NOTE(review): the exact meaning of the layer sizes, and of the @(2, 1000)@
-- argument passed to 'session', is defined in "Lib"; confirm there.
main :: IO ()
main = print results
  where
    learningRate = 0.5

    -- Layers come from Lib's randomLayer. The leading 0 / 1 is presumably a
    -- per-layer seed, so the two layers differ (TODO confirm in Lib).
    inputHidden  = randomLayer 0 (2, 8)
    hiddenOutput = randomLayer 1 (8, 1)

    network = inputHidden :- O hiddenOutput

    -- XOR truth table: the label is 1 exactly when the two inputs differ.
    inputs = [vector [0, 1], vector [1, 0], vector [1, 1], vector [0, 0]]
    labels = [vector [1], vector [1], vector [0], vector [0]]

    trained = session inputs network labels learningRate (2, 1000)

    -- Run each training input through the trained network.
    results = map (`forward` trained) inputs
|