diff options
Diffstat (limited to 'mnist')
| -rw-r--r-- | mnist/Neuronet.hs | 5 | ||||
| -rw-r--r-- | mnist/main.hs | 54 |
2 files changed, 31 insertions, 28 deletions
diff --git a/mnist/Neuronet.hs b/mnist/Neuronet.hs index e3344c7..ece288a 100644 --- a/mnist/Neuronet.hs +++ b/mnist/Neuronet.hs @@ -24,7 +24,7 @@ module Neuronet ,backprop )where -import Numeric.LinearAlgebra (Matrix,Vector,tr,scale,cmap,(#>),randn,toList,fromList,toLists,fromLists,Container) +import Numeric.LinearAlgebra (Matrix,Vector,tr,scale,cmap,(#>),randn,toList,fromList,toLists,fromLists,outer) import Data.List -- | A layer of our network consists of a weight matrix with input weights @@ -58,10 +58,9 @@ asknet net x = snd . last $ wghtact net x -- split in the weight and bias partial derivatives respectively). -- Keep the required assumptions about the cost function in mind! backprop :: Neuronet -> Vector Double -> Vector Double -> [(Matrix Double,Vector Double)] -backprop net x y = zipWith (\a e->(wm a e,e)) (x:map snd wa) (go $ zip ws wa) +backprop net x y = zipWith (\a e->(outer e a,e)) (x:map snd wa) (go $ zip ws wa) where ws = (++[fromLists []]) . tail . map fst $ net wa = wghtact net x - wm a e = fromLists $ map (\e->map (*e) (toList a)) (toList e) go [(w,(z,a))] = [cost_derivative a y * cmap sigmoid' z] go ((w,(z,a)):lx) =let r@(e:_)=go lx in tr w #> e * cmap sigmoid' z:r diff --git a/mnist/main.hs b/mnist/main.hs index a91984a..2810e34 100644 --- a/mnist/main.hs +++ b/mnist/main.hs @@ -2,6 +2,7 @@ import Neuronet import System.Random(randomRIO) import Numeric.LinearAlgebra import Data.List +import Data.Foldable(foldlM) import Data.List.Split import Data.Tuple.Extra import System.Random.Shuffle @@ -24,10 +25,13 @@ epochs :: Int -> Int -> Samples -> IO [[Samples]] epochs 0 _ _ = return [] epochs e n s = (:) <$> batches n s <*> epochs (e-1) n s --- train for multiple epochs -training :: Double -> Neuronet -> [[Samples]] -> Neuronet -training r net s = foldl' f net (concat s) - where f a v = trainBatch r a (fst.unzip $ v) (snd.unzip $ v) +-- train for multiple epochs, optionally . testing after each. +training :: Bool -> (Neuronet->String) -> Double -> Neuronet -> [[Samples]] -> IO Neuronet +training tst tstf r net s = foldlM f net (zip s [1..]) + where f nt (v,i) = do putStr $ "Epoch "++ show i ++ "...." + let n = foldl' (\n x->trainBatch r n (fst.unzip $ x) (snd.unzip $ x)) nt v + putStrLn $ tstf n + return n -- test with given samples and return number of correct answers testing :: (Vector Double -> Vector Double -> Bool) -> Neuronet -> Samples -> Int @@ -54,29 +58,29 @@ read_samples f1 f2 = do -- MNIST main function mainMNIST :: IO () mainMNIST =do - let ep = 1 -- number of epochs - let mbs = 10 -- mini-batch size - let lr = 3 -- learning rate + + let ep = 20 -- number of epochs + let mbs = 10 -- mini-batch size + let lr = 2 -- learning rate + let cap = 999999 -- cap number of training samples + + putStrLn "= Init =" + str "Initializing Net" nt <- neuronet [28*28,30,10] - smpl_train <- read_samples "train-images-idx3-ubyte" "train-labels-idx1-ubyte" + done + + str "Reading Samples" + smpl_train <- take cap <$> read_samples "train-images-idx3-ubyte" "train-labels-idx1-ubyte" smpl_test <- read_samples "t10k-images-idx3-ubyte" "t10k-labels-idx1-ubyte" - tr <- epochs ep mbs smpl_train >>= return . training (lr/fromIntegral mbs) nt - let passed = testing chk tr smpl_test - print $ show passed ++ "/10000 (" ++ show (fromIntegral passed/100)++ "%)" - where chk y1 y2 = val y1 == val y2 - val x=snd . last . sort $ zip (toList x) [0..9] + done --- just a quick and simple network created manually, used for experimenting -mainMANUAL :: IO () -mainMANUAL = do + putStrLn "= Training =" + tr <- epochs ep mbs smpl_train >>= training True (tst smpl_test) (lr/fromIntegral mbs) nt - let nt =[ ((2><2)[0.2,0.3,0.4,0.5],fromList[0.6,-0.6]) -- L1 - ,((2><2)[-0.5,-0.4,-0.3,-0.2],fromList[0.4,0.5]) -- L2 - ,((1><2)[0.25,0.35],fromList[0.9]) -- L3 - ] + putStrLn "= THE END =" - print nt - print $ wghtact nt $ fromList [0.8,0.9] - print $ backprop nt (fromList [0.8,0.9]) (fromList [1]) - print $ train 0.3 nt (fromList [0.8,0.9]) (fromList [1]) - print $ trainBatch 0.15 nt [fromList [0.8,0.9],fromList [0.8,0.9]] [fromList [1],fromList [1]] + where chk y1 y2 = val y1 == val y2 + val x=snd . last . sort $ zip (toList x) [0..9] + done = putStrLn "...[\ESC[32m\STXDone\ESC[m\STX]" + str v = putStr $ take 20 (v++repeat '.') + tst smpl n = "... \ESC[32m\STX" ++ show (fromIntegral (testing chk n smpl) / 100) ++ "\ESC[m\STX%" |
