2016-07-17 20:00:17 +00:00
|
|
|
module Main where
|
2016-10-16 22:24:35 +00:00
|
|
|
import Numeric.Sibe
|
2016-07-17 20:00:17 +00:00
|
|
|
import Numeric.LinearAlgebra
|
|
|
|
import Data.List
|
|
|
|
import Debug.Trace
|
2016-09-09 20:06:15 +00:00
|
|
|
import Data.Default.Class
|
2016-07-17 20:00:17 +00:00
|
|
|
|
2016-07-18 12:37:12 +00:00
|
|
|
-- | Train a small network on the XOR truth table and print a report.
--
-- The report lists the session parameters and the cross-entropy cost before
-- training, then the network's raw and rounded outputs for every input and
-- the cross-entropy cost after training.
main :: IO ()
main = do
  let -- Sigmoid paired with @sigmoid'@ (by its name, the derivative);
      -- the same pair is used for the hidden and the output layer.
      a = (sigmoid, sigmoid')

      -- two inputs, 2 nodes in a single hidden layer, 1 output
      -- NOTE(review): the leading @0@ and @(-1, 1)@ are presumably the random
      -- seed and the initial weight bounds -- confirm against Numeric.Sibe.
      rnetwork = randomNetwork 0 (-1, 1) 2 [(2, a)] (1, a)

      -- XOR truth table: the label is 1 exactly when the two inputs differ.
      inputs = [vector [0, 1], vector [1, 0], vector [1, 1], vector [0, 0]]
      labels = [vector [1], vector [1], vector [0], vector [0]]

      -- The same four pairs serve as both the training set and the test set.
      session = def { network = rnetwork
                    , learningRate = 0.8
                    , epochs = 1000
                    , training = zip inputs labels
                    , test = zip inputs labels
                    } :: Session

      -- Cost of the untrained session, kept for comparison in the report.
      initialCost = crossEntropy session

  newsession <- run gd session

  let -- Feed every input through the session returned by 'run'.
      results = map (`forward` newsession) inputs
      -- Round each output component to the nearest integer for easy
      -- comparison with the 0/1 labels.
      rounded = map (map round . toList) results

      cost = crossEntropy newsession

  putStrLn "parameters: "
  putStrLn $ "- inputs: " ++ show inputs
  putStrLn $ "- labels: " ++ show labels
  putStrLn $ "- learning rate: " ++ show (learningRate session)
  putStrLn $ "- epochs: " ++ show (epochs session)
  putStrLn $ "- initial cost (cross-entropy): " ++ show initialCost
  putStrLn "results: "
  putStrLn $ "- actual result: " ++ show results
  putStrLn $ "- rounded result: " ++ show rounded
  putStrLn $ "- cost (cross-entropy): " ++ show cost
|