diff --git a/README.md b/README.md index dafd963..f7ceacd 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,8 @@ A simple Machine Learning library. , epochs = 1000 , training = zip inputs labels , test = zip inputs labels + , drawChart = True + , chartName = "nn.png" -- draws chart of loss over time } :: Session initialCost = crossEntropy session diff --git a/src/Sibe.hs b/src/Sibe.hs index e30b6ac..eb97ca5 100644 --- a/src/Sibe.hs +++ b/src/Sibe.hs @@ -256,6 +256,7 @@ module Sibe shuffled <- shuffleM pairs let newnet = foldl' (\n (input, label) -> train input n label alpha) net pairs + cost = crossEntropy (session { network = newnet }) let el = map (\(e, l, _) -> (e, l)) (chart session) ea = map (\(e, _, a) -> (e, a)) (chart session) @@ -268,6 +269,7 @@ module Sibe return session { network = newnet , epoch = epoch session + 1 + , chart = (epoch session, cost, learningRate session):chart session } sgd :: Session -> IO Session