fix(train): fix an error in computing layer's error
feat(examples): add an example (xor)
This commit is contained in:
		
							
								
								
									
										
											BIN
										
									
								
								app/Main.hi
									
									
									
									
									
								
							
							
						
						
									
										
											BIN
										
									
								
								app/Main.hi
									
									
									
									
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										17
									
								
								app/Main.hs
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								app/Main.hs
									
									
									
									
									
								
							| @@ -9,19 +9,16 @@ import Debug.Trace | ||||
| -- 2x3 + 1x3 | ||||
| -- 3x1 + 1x1 | ||||
|  | ||||
| main :: IO () | ||||
| -- main :: IO [()] | ||||
| main = | ||||
|   let learning_rate = 0.01 | ||||
|       ih = randomLayer 0 (2, 10) | ||||
|       ho = randomLayer 1 (10, 1) | ||||
|   let learning_rate = 0.5 | ||||
|       ih = randomLayer 0 (2, 8) | ||||
|       ho = randomLayer 1 (8, 1) | ||||
|       network = ih :- O ho | ||||
|       inputs = [vector [0, 1], vector [1, 1], vector [1, 0], vector [0, 0]] | ||||
|  | ||||
|       -- result = forward input network | ||||
|       inputs = [vector [0, 1], vector [1, 0], vector [1, 1], vector [0, 0]] | ||||
|       labels = [vector [1], vector [1], vector [0], vector [0]] | ||||
|  | ||||
|       labels = [vector [1], vector [0], vector [1], vector [0]] | ||||
|  | ||||
|       updated_network = session inputs network labels learning_rate 100 | ||||
|       -- updated_network = train (head inputs) network (head labels) 0.5 | ||||
|       updated_network = session inputs network labels learning_rate (2, 1000) | ||||
|       results = map (`forward` updated_network) inputs | ||||
|   in print results | ||||
|   | ||||
							
								
								
									
										
											BIN
										
									
								
								app/Main.o
									
									
									
									
									
								
							
							
						
						
									
										
											BIN
										
									
								
								app/Main.o
									
									
									
									
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								examples/xor
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								examples/xor
									
									
									
									
									
										Executable file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										32
									
								
								examples/xor.hs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								examples/xor.hs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | ||||
| module Main where | ||||
|   import Lib | ||||
|   import Numeric.LinearAlgebra | ||||
|   import Data.List | ||||
|   import Debug.Trace | ||||
|  | ||||
|   -- 1x2 | ||||
|   -- 2x3 + 1x3 | ||||
|   -- 3x1 + 1x1 | ||||
|  | ||||
|   -- main :: IO [()] | ||||
|   main = | ||||
|     let learning_rate = 0.5 | ||||
|         (iterations, epochs) = (2, 1000) | ||||
|         ih = randomLayer 0 (2, 8) | ||||
|         ho = randomLayer 1 (8, 1) | ||||
|         network = ih :- O ho | ||||
|  | ||||
|         inputs = [vector [0, 1], vector [1, 0], vector [1, 1], vector [0, 0]] | ||||
|         labels = [vector [1], vector [1], vector [0], vector [0]] | ||||
|  | ||||
|         updated_network = session inputs network labels learning_rate (iterations, epochs) | ||||
|         results = map (`forward` updated_network) inputs | ||||
|         rounded = map (map round) $ map toList results | ||||
|     in sequence [putStrLn "", | ||||
|                  putStrLn $ "inputs: " ++ show inputs, | ||||
|                  putStrLn $ "labels: " ++ show labels, | ||||
|                  putStrLn $ "learning rate: " ++ show learning_rate, | ||||
|                  putStrLn $ "iterations/epochs: " ++ show (iterations, epochs), | ||||
|                  putStrLn $ "...", | ||||
|                  putStrLn $ "rounded result: " ++ show rounded, | ||||
|                  putStrLn $ "actual result: " ++ show results] | ||||
| @@ -27,6 +27,14 @@ executable sibe-exe | ||||
|                      , sibe | ||||
|   default-language:    Haskell2010 | ||||
|  | ||||
| executable example-xor | ||||
|   hs-source-dirs:      examples | ||||
|   main-is:             xor.hs | ||||
|   ghc-options:         -threaded -rtsopts -with-rtsopts=-N | ||||
|   build-depends:       base | ||||
|                      , sibe | ||||
|   default-language:    Haskell2010 | ||||
|  | ||||
| test-suite sibe-test | ||||
|   type:                exitcode-stdio-1.0 | ||||
|   hs-source-dirs:      test | ||||
|   | ||||
							
								
								
									
										
											BIN
										
									
								
								src/Lib.hi
									
									
									
									
									
								
							
							
						
						
									
										
											BIN
										
									
								
								src/Lib.hi
									
									
									
									
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										29
									
								
								src/Lib.hs
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								src/Lib.hs
									
									
									
									
									
								
							| @@ -12,12 +12,13 @@ module Lib | ||||
|      randomLayer, | ||||
|      train, | ||||
|      session, | ||||
|      shuffle, | ||||
|     ) where | ||||
|       import Numeric.LinearAlgebra | ||||
|       import Control.Monad.Random | ||||
|       import System.Random | ||||
|       import Debug.Trace | ||||
|       import Data.List (foldl') | ||||
|       import Data.List (foldl', sortBy) | ||||
|  | ||||
|       type LearningRate = Double | ||||
|       type Input = Vector Double | ||||
| @@ -49,7 +50,7 @@ module Lib | ||||
|       logistic x = 1 / (1 + exp (-x)) | ||||
|  | ||||
|       logistic' :: Double -> Double | ||||
|       logistic' x = logistic x / max 1e-8 (1 - logistic x) | ||||
|       logistic' x = logistic x / max 1e-10 (1 - logistic x) | ||||
|  | ||||
|       train :: Input | ||||
|             -> Network | ||||
| @@ -63,7 +64,7 @@ module Lib | ||||
|             let y = runLayer input l | ||||
|                 o = cmap logistic y | ||||
|                 delta = o - target | ||||
|                 de = delta * cmap logistic' y | ||||
|                 de = delta * cmap logistic' o | ||||
|  | ||||
|                 biases'  = biases  - scale alpha de | ||||
|                 weights' = weights - scale alpha (input `outer` de) -- small inputs learn slowly | ||||
| @@ -78,7 +79,7 @@ module Lib | ||||
|                 o = cmap logistic y | ||||
|                 (n', delta) = run o n | ||||
|  | ||||
|                 de = delta * cmap logistic' y | ||||
|                 de = delta * cmap logistic' o | ||||
|  | ||||
|                 biases'  = biases  - scale alpha de | ||||
|                 weights' = weights - scale alpha (input `outer` de) | ||||
| @@ -88,14 +89,22 @@ module Lib | ||||
|                 -- pass = weights #> de | ||||
|             in (layer :- n', pass) | ||||
|  | ||||
|       session :: [Input] -> Network -> [Output] -> Double -> Int -> Network | ||||
|       session inputs network labels alpha epochs = | ||||
|         let n = length inputs - 1 | ||||
|         in foldl' iter network [0..n * epochs] | ||||
|       session :: [Input] -> Network -> [Output] -> Double -> (Int, Int) -> Network | ||||
|       session inputs network labels alpha (iterations, epochs) = | ||||
|         let n = length inputs | ||||
|             indexes = shuffle n (map (`mod` n) [0..n * epochs]) | ||||
|         in foldl' iter network indexes | ||||
|         where | ||||
|           iter net i = | ||||
|             let n = length inputs - 1 | ||||
|             let n = length inputs | ||||
|                 index = i `mod` n | ||||
|                 input = inputs !! index | ||||
|                 label = labels !! index | ||||
|             in train input net label alpha | ||||
|             in foldl' (\net _ -> train input net label alpha) net [0..iterations] | ||||
|  | ||||
|       shuffle :: Seed -> [a] -> [a] | ||||
|       shuffle seed list = | ||||
|         let ords = map ord $ take (length list) (randomRs (0, 1) (mkStdGen seed) :: [Int]) | ||||
|         in map snd $ sortBy (\x y -> fst x) (zip ords list) | ||||
|         where ord x | x == 0 = LT | ||||
|                     | x == 1 = GT | ||||
|   | ||||
		Reference in New Issue
	
	Block a user