fix(crossEntropy): implement crossEntropy' to be used in the output layer

fix(softmax'): softmax' was not correct
Mahdi Dibaiee 2016-09-10 17:43:45 +04:30
parent c23fd14771
commit f379f208db
5 changed files with 20 additions and 18 deletions


@@ -44,7 +44,7 @@ A simple Machine Learning library.
 # neural network examples
 stack exec example-xor
 stack exec example-424
-# notMNIST dataset, achieves ~87% accuracy using exponential learning rate decay
+# notMNIST dataset, achieves ~87.5% accuracy after 9 epochs (2 minutes)
 stack exec example-notmnist
 # Naive Bayes document classifier, using Reuters dataset


@@ -21,11 +21,12 @@ module Main where
 import Graphics.Rendering.Chart.Backend.Cairo

 main = do
+  -- random seed, you might comment this line to get real random results
   setStdGen (mkStdGen 100)

   let a = (sigmoid, sigmoid')
-      o = (softmax, one)
-      rnetwork = randomNetwork 0 (-1, 1) (28*28) [(100, a)] (10, a)
+      o = (softmax, crossEntropy')
+      rnetwork = randomNetwork 0 (-1, 1) (28*28) [(100, a)] (10, o)

   (inputs, labels) <- dataset
@@ -41,11 +42,11 @@ module Main where
       telabels = take tep . drop trp $ labels

   let session = def { learningRate = 0.5
                     , batchSize = 32
-                    , epochs = 24
+                    , epochs = 10
                     , network = rnetwork
                     , training = zip trinputs trlabels
                     , test = zip teinputs telabels
                     } :: Session

   let initialCost = crossEntropy session
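For context on the `o = (softmax, crossEntropy')` change above: per the output-layer hunk at the end of this commit, each `Layer` carries a `(fn, fn')` pair, and backpropagation multiplies the output error by `fn'`. A minimal sketch of that convention, assuming hmatrix's `Vector` and the exports shown in this commit; the `Activation` alias is hypothetical, added here only for illustration:

    import Sibe (sigmoid, sigmoid', softmax, crossEntropy')
    import Numeric.LinearAlgebra (Vector)

    -- hypothetical alias for the (forward, backward-factor) pairs used above
    type Activation = (Vector Double -> Vector Double, Vector Double -> Vector Double)

    hiddenAct, outputAct :: Activation
    hiddenAct = (sigmoid, sigmoid')       -- `a` in the example
    outputAct = (softmax, crossEntropy')  -- `o`: softmax forward, 1/n factor backward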

BIN (image, filename not rendered): 28 KiB -> 33 KiB

BIN sgd.png: 28 KiB -> 31 KiB


@@ -22,10 +22,10 @@ module Sibe
     sigmoid',
     softmax,
     softmax',
-    one,
     relu,
     relu',
     crossEntropy,
+    crossEntropy',
     genSeed,
     replaceVector,
     Session(..),
@@ -143,11 +143,10 @@ module Sibe
   where
     s = V.sum $ exp x

-one :: a -> Double
-one x = 1
-
 softmax' :: Vector Double -> Vector Double
-softmax' x = softmax x * (1 - softmax x)
+softmax' = cmap (\a -> sig a * (1 - sig a))
+  where
+    sig x = 1 / max (1 + exp (-x)) 1e-10

 relu :: Vector Double -> Vector Double
 relu = cmap (max 0.1)
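The rewritten softmax' applies the logistic (sigmoid) derivative to each component instead of reusing softmax itself, with the `max ... 1e-10` serving as a defensive lower bound on the denominator. For reference, the standard identity it computes elementwise:

    sigma(x)  = 1 / (1 + exp (-x))
    sigma'(x) = sigma(x) * (1 - sigma(x))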
@@ -165,11 +164,13 @@ module Sibe
       outputs = map (toList . (`forward` session)) inputs
       pairs = zip outputs labels
       n = genericLength pairs
   in sum (map set pairs) / n
   where
-    set (os, ls) = (-1 / genericLength os) * sum (zipWith (curry f) os ls)
-    f (a, y) = y * log (max 1e-10 a) + (1 - y) * log (max (1 - a) 1e-10)
+    set (os, ls) = (-1 / genericLength os) * sum (zipWith f os ls)
+    f a y = y * log (max 1e-10 a)
+
+crossEntropy' :: Vector Double -> Vector Double
+crossEntropy' x = 1 / fromIntegral (V.length x)

 train :: Input
       -> Network
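A plausible reading of why a constant factor suffices here: the simplified per-example cost above is C = -(1/n) * sum_i (y_i * log o_i) over n output units. Assuming a softmax output layer and one-hot targets (sum_i y_i = 1), the gradient with respect to the pre-activations reduces to

    dC/dz_j = (o_j - y_j) / n

so the backward pass only needs `(o - target)` scaled by the constant `1 / n` that crossEntropy' returns, which matches the updated comment in the next hunk.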
@@ -182,9 +183,9 @@ module Sibe
 run input (O l@(Layer biases weights (fn, fn'))) =
   let y = runLayer input l
       o = fn y
       delta = o - target
       de = delta * fn' y
-      -- de = delta -- cross entropy cost
+      -- de = delta / fromIntegral (V.length o) -- cross entropy cost
       biases' = biases - scale alpha de
       weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly
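A quick numeric check of the resulting output-layer delta, using plain lists and hypothetical values so it stands alone in GHCi:

    -- approximates what `delta * fn' y` computes when fn' = crossEntropy'
    let o = [0.7, 0.2, 0.1]   -- softmax output
        y = [1.0, 0.0, 0.0]   -- one-hot target
        n = fromIntegral (length o)
     in [(oi - yi) / n | (oi, yi) <- zip o y]
    -- evaluates to approximately [-0.1, 0.0667, 0.0333]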