fix(crossEntropy): implement crossEntropy' to be used in output layer
fix(softmax'): softmax was not correct
This commit is contained in:
		@@ -44,7 +44,7 @@ A simple Machine Learning library.
 | 
				
			|||||||
# neural network examples
 | 
					# neural network examples
 | 
				
			||||||
stack exec example-xor
 | 
					stack exec example-xor
 | 
				
			||||||
stack exec example-424
 | 
					stack exec example-424
 | 
				
			||||||
# notMNIST dataset, achieves ~87% accuracy using exponential learning rate decay
 | 
					# notMNIST dataset, achieves ~87.5% accuracy after 9 epochs (2 minutes)
 | 
				
			||||||
stack exec example-notmnist
 | 
					stack exec example-notmnist
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Naive Bayes document classifier, using Reuters dataset
 | 
					# Naive Bayes document classifier, using Reuters dataset
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -21,11 +21,12 @@ module Main where
 | 
				
			|||||||
  import Graphics.Rendering.Chart.Backend.Cairo
 | 
					  import Graphics.Rendering.Chart.Backend.Cairo
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  main = do
 | 
					  main = do
 | 
				
			||||||
 | 
					    -- fixed random seed for reproducible runs; comment out this line to get different random results on each run
 | 
				
			||||||
    setStdGen (mkStdGen 100)
 | 
					    setStdGen (mkStdGen 100)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    let a         = (sigmoid, sigmoid')
 | 
					    let a         = (sigmoid, sigmoid')
 | 
				
			||||||
        o         = (softmax, one)
 | 
					        o         = (softmax, crossEntropy')
 | 
				
			||||||
        rnetwork  = randomNetwork 0 (-1, 1) (28*28) [(100, a)] (10, a)
 | 
					        rnetwork  = randomNetwork 0 (-1, 1) (28*28) [(100, a)] (10, o)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    (inputs, labels) <- dataset
 | 
					    (inputs, labels) <- dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -42,7 +43,7 @@ module Main where
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    let session = def { learningRate = 0.5
 | 
					    let session = def { learningRate = 0.5
 | 
				
			||||||
                      , batchSize    = 32
 | 
					                      , batchSize    = 32
 | 
				
			||||||
                      , epochs = 24
 | 
					                      , epochs       = 10
 | 
				
			||||||
                      , network      = rnetwork
 | 
					                      , network      = rnetwork
 | 
				
			||||||
                      , training     = zip trinputs trlabels
 | 
					                      , training     = zip trinputs trlabels
 | 
				
			||||||
                      , test         = zip teinputs telabels
 | 
					                      , test         = zip teinputs telabels
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										
											BIN
										
									
								
								notmnist.png
									
									
									
									
									
								
							
							
						
						
									
										
											BIN
										
									
								
								notmnist.png
									
									
									
									
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 Before Width: | Height: | Size: 28 KiB After Width: | Height: | Size: 33 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								sgd.png
									
									
									
									
									
								
							
							
						
						
									
										
											BIN
										
									
								
								sgd.png
									
									
									
									
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 Before Width: | Height: | Size: 28 KiB After Width: | Height: | Size: 31 KiB  | 
							
								
								
									
										19
									
								
								src/Sibe.hs
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								src/Sibe.hs
									
									
									
									
									
								
							@@ -22,10 +22,10 @@ module Sibe
 | 
				
			|||||||
     sigmoid',
 | 
					     sigmoid',
 | 
				
			||||||
     softmax,
 | 
					     softmax,
 | 
				
			||||||
     softmax',
 | 
					     softmax',
 | 
				
			||||||
     one,
 | 
					 | 
				
			||||||
     relu,
 | 
					     relu,
 | 
				
			||||||
     relu',
 | 
					     relu',
 | 
				
			||||||
     crossEntropy,
 | 
					     crossEntropy,
 | 
				
			||||||
 | 
					     crossEntropy',
 | 
				
			||||||
     genSeed,
 | 
					     genSeed,
 | 
				
			||||||
     replaceVector,
 | 
					     replaceVector,
 | 
				
			||||||
     Session(..),
 | 
					     Session(..),
 | 
				
			||||||
@@ -143,11 +143,10 @@ module Sibe
 | 
				
			|||||||
        where
 | 
					        where
 | 
				
			||||||
          s = V.sum $ exp x
 | 
					          s = V.sum $ exp x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      one :: a -> Double
 | 
					 | 
				
			||||||
      one x = 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      softmax' :: Vector Double -> Vector Double
 | 
					      softmax' :: Vector Double -> Vector Double
 | 
				
			||||||
      softmax' x = softmax x * (1 - softmax x)
 | 
					      softmax' = cmap (\a -> sig a * (1 - sig a))
 | 
				
			||||||
 | 
					        where
 | 
				
			||||||
 | 
					          sig x = 1 / max (1 + exp (-x)) 1e-10
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      relu :: Vector Double -> Vector Double
 | 
					      relu :: Vector Double -> Vector Double
 | 
				
			||||||
      relu = cmap (max 0.1)
 | 
					      relu = cmap (max 0.1)
 | 
				
			||||||
@@ -165,11 +164,13 @@ module Sibe
 | 
				
			|||||||
            outputs = map (toList . (`forward` session)) inputs
 | 
					            outputs = map (toList . (`forward` session)) inputs
 | 
				
			||||||
            pairs = zip outputs labels
 | 
					            pairs = zip outputs labels
 | 
				
			||||||
            n = genericLength pairs
 | 
					            n = genericLength pairs
 | 
				
			||||||
 | 
					 | 
				
			||||||
        in sum (map set pairs) / n
 | 
					        in sum (map set pairs) / n
 | 
				
			||||||
        where
 | 
					        where
 | 
				
			||||||
          set (os, ls) = (-1 / genericLength os) * sum (zipWith (curry f) os ls)
 | 
					          set (os, ls) = (-1 / genericLength os) * sum (zipWith f os ls)
 | 
				
			||||||
          f (a, y) = y * log (max 1e-10 a) + (1 - y) * log (max (1 - a) 1e-10)
 | 
					          f a y = y * log (max 1e-10 a)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      crossEntropy' :: Vector Double -> Vector Double
 | 
				
			||||||
 | 
					      crossEntropy' x = 1 / fromIntegral (V.length x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      train :: Input
 | 
					      train :: Input
 | 
				
			||||||
            -> Network
 | 
					            -> Network
 | 
				
			||||||
@@ -184,7 +185,7 @@ module Sibe
 | 
				
			|||||||
                o = fn y
 | 
					                o = fn y
 | 
				
			||||||
                delta = o - target 
 | 
					                delta = o - target 
 | 
				
			||||||
                de = delta * fn' y
 | 
					                de = delta * fn' y
 | 
				
			||||||
                -- de = delta -- cross entropy cost
 | 
					                -- de = delta / fromIntegral (V.length o) -- cross entropy cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                biases'  = biases  - scale alpha de
 | 
					                biases'  = biases  - scale alpha de
 | 
				
			||||||
                weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly
 | 
					                weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user