fix(session): missing n
definition throws error
This commit is contained in:
parent
fb01d936c2
commit
763faef434
@ -11,7 +11,7 @@ import Debug.Trace
|
|||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main =
|
main =
|
||||||
let learning_rate = 0.5
|
let learning_rate = 0.01
|
||||||
ih = randomLayer 0 (2, 10)
|
ih = randomLayer 0 (2, 10)
|
||||||
ho = randomLayer 1 (10, 1)
|
ho = randomLayer 1 (10, 1)
|
||||||
network = ih :- O ho
|
network = ih :- O ho
|
||||||
@ -21,7 +21,7 @@ main =
|
|||||||
|
|
||||||
labels = [vector [1], vector [0], vector [1], vector [0]]
|
labels = [vector [1], vector [0], vector [1], vector [0]]
|
||||||
|
|
||||||
updated_network = session inputs network labels 0.01 10000
|
updated_network = session inputs network labels learning_rate 100
|
||||||
-- updated_network = train (head inputs) network (head labels) 0.5
|
-- updated_network = train (head inputs) network (head labels) 0.5
|
||||||
results = map (\x -> forward x updated_network) inputs
|
results = map (`forward` updated_network) inputs
|
||||||
in print results
|
in print results
|
||||||
|
BIN
app/Main.o
BIN
app/Main.o
Binary file not shown.
@ -90,7 +90,8 @@ module Lib
|
|||||||
|
|
||||||
session :: [Input] -> Network -> [Output] -> Double -> Int -> Network
|
session :: [Input] -> Network -> [Output] -> Double -> Int -> Network
|
||||||
session inputs network labels alpha epochs =
|
session inputs network labels alpha epochs =
|
||||||
foldl' iter network [0..n * epochs]
|
let n = length inputs - 1
|
||||||
|
in foldl' iter network [0..n * epochs]
|
||||||
where
|
where
|
||||||
iter net i =
|
iter net i =
|
||||||
let n = length inputs - 1
|
let n = length inputs - 1
|
||||||
|
Loading…
Reference in New Issue
Block a user