From 9c7b00c58ae0b4ece9f46a7226b59248b8b9eba6 Mon Sep 17 00:00:00 2001
From: Miguel
Date: Fri, 22 Mar 2019 23:03:41 +0100
Subject: getting nicer

---
 mnist/main.hs | 54 +++++++++++++++++++++++++++++-------------------------
 1 file changed, 29 insertions(+), 25 deletions(-)

diff --git a/mnist/main.hs b/mnist/main.hs
index a91984a..2810e34 100644
--- a/mnist/main.hs
+++ b/mnist/main.hs
@@ -2,6 +2,7 @@ import Neuronet
 import System.Random(randomRIO)
 import Numeric.LinearAlgebra
 import Data.List
+import Data.Foldable(foldlM)
 import Data.List.Split
 import Data.Tuple.Extra
 import System.Random.Shuffle
@@ -24,10 +25,13 @@ epochs :: Int -> Int -> Samples -> IO [[Samples]]
 epochs 0 _ _ = return []
 epochs e n s = (:) <$> batches n s <*> epochs (e-1) n s
 
--- train for multiple epochs
-training :: Double -> Neuronet -> [[Samples]] -> Neuronet
-training r net s = foldl' f net (concat s)
-  where f a v = trainBatch r a (fst.unzip $ v) (snd.unzip $ v)
+-- train for multiple epochs, optionally testing after each
+training :: Bool -> (Neuronet->String) -> Double -> Neuronet -> [[Samples]] -> IO Neuronet
+training tst tstf r net s = foldlM f net (zip s [1..])
+  where f nt (v,i) = do putStr $ "Epoch "++ show i ++ "...."
+                        let n = foldl' (\n x->trainBatch r n (fst.unzip $ x) (snd.unzip $ x)) nt v
+                        putStrLn $ tstf n
+                        return n
 
 -- test with given samples and return number of correct answers
 testing :: (Vector Double -> Vector Double -> Bool) -> Neuronet -> Samples -> Int
@@ -54,29 +58,29 @@ read_samples f1 f2 = do
 -- MNIST main function
 mainMNIST :: IO ()
 mainMNIST =do
-  let ep  = 1  -- number of epochs
-  let mbs = 10 -- mini-batch size
-  let lr  = 3  -- learning rate
+
+  let ep  = 20     -- number of epochs
+  let mbs = 10     -- mini-batch size
+  let lr  = 2      -- learning rate
+  let cap = 999999 -- cap number of training samples
+
+  putStrLn "= Init ="
+  str "Initializing Net"
   nt <- neuronet [28*28,30,10]
-  smpl_train <- read_samples "train-images-idx3-ubyte" "train-labels-idx1-ubyte"
+  done
+
+  str "Reading Samples"
+  smpl_train <- take cap <$> read_samples "train-images-idx3-ubyte" "train-labels-idx1-ubyte"
   smpl_test <- read_samples "t10k-images-idx3-ubyte" "t10k-labels-idx1-ubyte"
-  tr <- epochs ep mbs smpl_train >>= return . training (lr/fromIntegral mbs) nt
-  let passed = testing chk tr smpl_test
-  print $ show passed ++ "/10000 (" ++ show (fromIntegral passed/100)++ "%)"
-  where chk y1 y2 = val y1 == val y2
-        val x=snd . last . sort $ zip (toList x) [0..9]
+  done
 
--- just a quick and simple network created manually, used for experimenting
-mainMANUAL :: IO ()
-mainMANUAL = do
+  putStrLn "= Training ="
+  tr <- epochs ep mbs smpl_train >>= training True (tst smpl_test) (lr/fromIntegral mbs) nt
 
-  let nt =[ ((2><2)[0.2,0.3,0.4,0.5],fromList[0.6,-0.6])   -- L1
-          ,((2><2)[-0.5,-0.4,-0.3,-0.2],fromList[0.4,0.5]) -- L2
-          ,((1><2)[0.25,0.35],fromList[0.9])               -- L3
-          ]
+  putStrLn "= THE END ="
 
-  print nt
-  print $ wghtact nt $ fromList [0.8,0.9]
-  print $ backprop nt (fromList [0.8,0.9]) (fromList [1])
-  print $ train 0.3 nt (fromList [0.8,0.9]) (fromList [1])
-  print $ trainBatch 0.15 nt [fromList [0.8,0.9],fromList [0.8,0.9]] [fromList [1],fromList [1]]
+  where chk y1 y2 = val y1 == val y2
+        val x=snd . last . sort $ zip (toList x) [0..9]
+        done = putStrLn "...[\ESC[32m\STXDone\ESC[m\STX]"
+        str v = putStr $ take 20 (v++repeat '.')
+        tst smpl n = "... \ESC[32m\STX" ++ show (fromIntegral (testing chk n smpl) / 100) ++ "\ESC[m\STX%"
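
Note: the core of this commit is switching `training` from a pure foldl' over all mini-batches to a foldlM in IO, so a progress line and the test score can be printed after each epoch while the per-epoch work stays a pure fold. Below is a minimal standalone sketch of that foldlM-over-epochs pattern, not part of the patch; Model, step, score, and trainEpochs are hypothetical stand-ins for Neuronet, trainBatch, and the per-epoch test report used above.

-- Sketch: foldlM threads the model through IO between epochs, while each
-- epoch itself is a pure foldl' over its mini-batches, mirroring the shape
-- of the new `training` function. Model/step/score are toy stand-ins.
import Data.Foldable (foldlM)
import Data.List (foldl')

type Model = Double                 -- stand-in for Neuronet

step :: Model -> Double -> Model    -- stand-in for trainBatch on one mini-batch
step m x = m + x

score :: Model -> String            -- stand-in for the per-epoch test report
score m = "model = " ++ show m

trainEpochs :: Model -> [[Double]] -> IO Model
trainEpochs m0 epochsOfBatches = foldlM epoch m0 (zip epochsOfBatches [1 :: Int ..])
  where
    epoch m (xs, i) = do
      putStr ("Epoch " ++ show i ++ "....")
      let m' = foldl' step m xs   -- pure work for the whole epoch
      putStrLn (score m')         -- report in IO after the epoch
      return m'

main :: IO ()
main = trainEpochs 0 [[1,2,3],[4,5],[6]] >>= print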