fix(crossEntropy): implement crossEntropy' for use in the output layer
fix(softmax'): softmax' was incorrect
parent c23fd14771
commit f379f208db
@@ -44,7 +44,7 @@ A simple Machine Learning library.
 # neural network examples
 stack exec example-xor
 stack exec example-424
-# notMNIST dataset, achieves ~87% accuracy using exponential learning rate decay
+# notMNIST dataset, achieves ~87.5% accuracy after 9 epochs (2 minutes)
 stack exec example-notmnist
 
 # Naive Bayes document classifier, using Reuters dataset
@@ -21,11 +21,12 @@ module Main where
 import Graphics.Rendering.Chart.Backend.Cairo
 
 main = do
+  -- random seed, you might comment this line to get real random results
   setStdGen (mkStdGen 100)
 
   let a = (sigmoid, sigmoid')
-      o = (softmax, one)
-      rnetwork = randomNetwork 0 (-1, 1) (28*28) [(100, a)] (10, a)
+      o = (softmax, crossEntropy')
+      rnetwork = randomNetwork 0 (-1, 1) (28*28) [(100, a)] (10, o)
 
   (inputs, labels) <- dataset
 
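An editorial note on this hunk: the old code already defined the output pair o but still passed a (the sigmoid pair) to randomNetwork, so softmax and its derivative were never actually used; the new code both replaces the placeholder derivative one with crossEntropy' and wires o into the output layer. Reading the arguments by their names (an inference from the identifiers, not documented in this diff), randomNetwork 0 (-1, 1) (28*28) [(100, a)] (10, o) builds a 784-100-10 network with weights drawn from (-1, 1) using seed 0.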
@@ -42,7 +43,7 @@ module Main where
   let session = def { learningRate = 0.5
                     , batchSize = 32
-                    , epochs = 24
+                    , epochs = 10
                     , network = rnetwork
                     , training = zip trinputs trlabels
                     , test = zip teinputs telabels
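The epoch reduction from 24 to 10 is consistent with the README change above, which reports that accuracy reaches roughly 87.5% after about 9 epochs.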
notmnist.png (BIN) Binary file not shown. Size: 28 KiB before, 33 KiB after.
sgd.png (BIN) Binary file not shown. Size: 28 KiB before, 31 KiB after.
src/Sibe.hs (19 changed lines)
@@ -22,10 +22,10 @@ module Sibe
     sigmoid',
     softmax,
     softmax',
-    one,
     relu,
     relu',
     crossEntropy,
+    crossEntropy',
     genSeed,
     replaceVector,
     Session(..),
@@ -143,11 +143,10 @@ module Sibe
   where
     s = V.sum $ exp x
 
-one :: a -> Double
-one x = 1
-
 softmax' :: Vector Double -> Vector Double
-softmax' x = softmax x * (1 - softmax x)
+softmax' = cmap (\a -> sig a * (1 - sig a))
+  where
+    sig x = 1 / max (1 + exp (-x)) 1e-10
 
 relu :: Vector Double -> Vector Double
 relu = cmap (max 0.1)
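For context on the softmax' fix (an editorial note, not part of the diff): softmax couples every output to every input, so its exact derivative is a Jacobian rather than an elementwise map,

\[
\frac{\partial s_i}{\partial x_j} = s_i\,(\delta_{ij} - s_j),
\qquad
s_i = \frac{e^{x_i}}{\sum_k e^{x_k}}.
\]

A function of type Vector Double -> Vector Double can at best return the diagonal s_i (1 - s_i). The committed version instead applies the logistic-sigmoid derivative sig(x) * (1 - sig(x)) elementwise to the raw input, with the max ... 1e-10 term clamping the denominator away from zero.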
@@ -165,11 +164,13 @@ module Sibe
       outputs = map (toList . (`forward` session)) inputs
       pairs = zip outputs labels
       n = genericLength pairs
 
   in sum (map set pairs) / n
   where
-    set (os, ls) = (-1 / genericLength os) * sum (zipWith (curry f) os ls)
-    f (a, y) = y * log (max 1e-10 a) + (1 - y) * log (max (1 - a) 1e-10)
+    set (os, ls) = (-1 / genericLength os) * sum (zipWith f os ls)
+    f a y = y * log (max 1e-10 a)
+
+crossEntropy' :: Vector Double -> Vector Double
+crossEntropy' x = 1 / fromIntegral (V.length x)
 
 train :: Input
       -> Network
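Why a constant can serve as the derivative here (a sketch assuming one-hot labels; the identity is standard and not stated in the source): the simplified cost above is L = -(1/n) * sum_i y_i * log a_i, and when a = softmax(z) the usual softmax-plus-cross-entropy identity collapses the Jacobian to

\[
\frac{\partial L}{\partial z_i} = \frac{1}{n}\,(a_i - y_i).
\]

So the error to backpropagate is just (output - target) scaled by 1/n, which is exactly the constant factor crossEntropy' returns; the train hunk below shows where it is consumed.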
@@ -184,7 +185,7 @@ module Sibe
         o = fn y
         delta = o - target
         de = delta * fn' y
-        -- de = delta -- cross entropy cost
+        -- de = delta / fromIntegral (V.length o) -- cross entropy cost
 
         biases' = biases - scale alpha de
         weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly
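A minimal standalone sketch of the output-layer delta these changes produce (assumes the hmatrix and vector packages; softmax and crossEntropy' are restated from the hunks above, and y, o, target, de follow the names used in train; the input values are illustrative):

import Numeric.LinearAlgebra
import qualified Data.Vector.Storable as V

-- softmax over the raw output-layer activations, as in src/Sibe.hs
softmax :: Vector Double -> Vector Double
softmax x = cmap (/ V.sum (exp x)) (exp x)

-- the new derivative hook: a constant 1/n factor, broadcast over the vector
crossEntropy' :: Vector Double -> Vector Double
crossEntropy' x = 1 / fromIntegral (V.length x)

main :: IO ()
main = do
  let y      = vector [0.5, 1.5, 0.1]          -- pre-activations of the output layer
      target = vector [0, 1, 0]                -- one-hot label
      o      = softmax y
      de     = (o - target) * crossEntropy' y  -- delta scaled by 1/n
  print de                                     -- equals (o - target) / 3 here

With fn' = crossEntropy', de reduces to (o - target) / n, matching the updated comment in the hunk above.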