crossEntropy chart for now
This commit is contained in:
parent
6def5f6197
commit
313e120f25
@ -19,6 +19,8 @@ A simple Machine Learning library.
|
|||||||
, epochs = 1000
|
, epochs = 1000
|
||||||
, training = zip inputs labels
|
, training = zip inputs labels
|
||||||
, test = zip inputs labels
|
, test = zip inputs labels
|
||||||
|
, drawChart = True
|
||||||
|
, chartName = "nn.png" -- draws chart of loss over time
|
||||||
} :: Session
|
} :: Session
|
||||||
|
|
||||||
initialCost = crossEntropy session
|
initialCost = crossEntropy session
|
||||||
|
@ -256,6 +256,7 @@ module Sibe
|
|||||||
shuffled <- shuffleM pairs
|
shuffled <- shuffleM pairs
|
||||||
|
|
||||||
let newnet = foldl' (\n (input, label) -> train input n label alpha) net pairs
|
let newnet = foldl' (\n (input, label) -> train input n label alpha) net pairs
|
||||||
|
cost = crossEntropy (session { network = newnet })
|
||||||
|
|
||||||
let el = map (\(e, l, _) -> (e, l)) (chart session)
|
let el = map (\(e, l, _) -> (e, l)) (chart session)
|
||||||
ea = map (\(e, _, a) -> (e, a)) (chart session)
|
ea = map (\(e, _, a) -> (e, a)) (chart session)
|
||||||
@ -268,6 +269,7 @@ module Sibe
|
|||||||
|
|
||||||
return session { network = newnet
|
return session { network = newnet
|
||||||
, epoch = epoch session + 1
|
, epoch = epoch session + 1
|
||||||
|
, chart = (epoch session, cost, learningRate session):chart session
|
||||||
}
|
}
|
||||||
|
|
||||||
sgd :: Session -> IO Session
|
sgd :: Session -> IO Session
|
||||||
|
Loading…
Reference in New Issue
Block a user