author    Miguel <m.i@gmx.at>  2019-03-22 23:03:41 +0100
committer Miguel <m.i@gmx.at>  2019-03-22 23:03:41 +0100
commit    9c7b00c58ae0b4ece9f46a7226b59248b8b9eba6 (patch)
tree      330d6b5da1d87318628e2ca54242fb07382ee1ae /mnist/main.hs
parent    e1826a4c5975260c784d3f6c43fd53a7092d64e4 (diff)
getting nicer
Diffstat (limited to 'mnist/main.hs')
-rw-r--r--  mnist/main.hs  |  54
1 file changed, 29 insertions, 25 deletions
diff --git a/mnist/main.hs b/mnist/main.hs
index a91984a..2810e34 100644
--- a/mnist/main.hs
+++ b/mnist/main.hs
@@ -2,6 +2,7 @@ import Neuronet
import System.Random(randomRIO)
import Numeric.LinearAlgebra
import Data.List
+import Data.Foldable(foldlM)
import Data.List.Split
import Data.Tuple.Extra
import System.Random.Shuffle
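The newly imported foldlM is the monadic counterpart of foldl'; the rewritten training below uses it to interleave the pure per-epoch fold with IO logging. For reference, the standard signatures from Data.Foldable:

    foldl' :: Foldable t => (b -> a -> b) -> b -> t a -> b
    foldlM :: (Foldable t, Monad m) => (b -> a -> m b) -> b -> t a -> m b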
@@ -24,10 +25,13 @@ epochs :: Int -> Int -> Samples -> IO [[Samples]]
epochs 0 _ _ = return []
epochs e n s = (:) <$> batches n s <*> epochs (e-1) n s
--- train for multiple epochs
-training :: Double -> Neuronet -> [[Samples]] -> Neuronet
-training r net s = foldl' f net (concat s)
- where f a v = trainBatch r a (fst.unzip $ v) (snd.unzip $ v)
+-- train for multiple epochs, optionally testing after each.
+training :: Bool -> (Neuronet->String) -> Double -> Neuronet -> [[Samples]] -> IO Neuronet
+training tst tstf r net s = foldlM f net (zip s [1..])
+ where f nt (v,i) = do putStr $ "Epoch "++ show i ++ "...."
+ let n = foldl' (\a x->trainBatch r a (fst.unzip $ x) (snd.unzip $ x)) nt v
+ putStrLn $ if tst then tstf n else ""
+ return n
-- test with given samples and return number of correct answers
testing :: (Vector Double -> Vector Double -> Bool) -> Neuronet -> Samples -> Int
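The foldlM pattern used by the new training, reduced to a minimal standalone sketch (step, the Int accumulator and the sample data are illustrative, not part of this patch): the accumulator is threaded purely from step to step, while each step may also perform IO.

    import Data.Foldable (foldlM)

    -- Thread a pure accumulator through a list, printing progress per step.
    step :: Int -> (String, Int) -> IO Int
    step acc (label, x) = do
      putStrLn ("Processing " ++ label ++ "....")
      return (acc + x)   -- pure update handed to the next step

    main :: IO ()
    main = do
      total <- foldlM step 0 [("a",1),("b",2),("c",3)]
      print total        -- 6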
@@ -54,29 +58,29 @@ read_samples f1 f2 = do
-- MNIST main function
mainMNIST :: IO ()
mainMNIST = do
- let ep = 1 -- number of epochs
- let mbs = 10 -- mini-batch size
- let lr = 3 -- learning rate
+
+ let ep = 20 -- number of epochs
+ let mbs = 10 -- mini-batch size
+ let lr = 2 -- learning rate
+ let cap = 999999 -- cap number of training samples
+
+ putStrLn "= Init ="
+ str "Initializing Net"
nt <- neuronet [28*28,30,10]
- smpl_train <- read_samples "train-images-idx3-ubyte" "train-labels-idx1-ubyte"
+ done
+
+ str "Reading Samples"
+ smpl_train <- take cap <$> read_samples "train-images-idx3-ubyte" "train-labels-idx1-ubyte"
smpl_test <- read_samples "t10k-images-idx3-ubyte" "t10k-labels-idx1-ubyte"
- tr <- epochs ep mbs smpl_train >>= return . training (lr/fromIntegral mbs) nt
- let passed = testing chk tr smpl_test
- print $ show passed ++ "/10000 (" ++ show (fromIntegral passed/100)++ "%)"
- where chk y1 y2 = val y1 == val y2
- val x=snd . last . sort $ zip (toList x) [0..9]
+ done
--- just a quick and simple network created manually, used for experimenting
-mainMANUAL :: IO ()
-mainMANUAL = do
+ putStrLn "= Training ="
+ tr <- epochs ep mbs smpl_train >>= training True (tst smpl_test) (lr/fromIntegral mbs) nt
- let nt =[ ((2><2)[0.2,0.3,0.4,0.5],fromList[0.6,-0.6]) -- L1
- ,((2><2)[-0.5,-0.4,-0.3,-0.2],fromList[0.4,0.5]) -- L2
- ,((1><2)[0.25,0.35],fromList[0.9]) -- L3
- ]
+ putStrLn "= THE END ="
- print nt
- print $ wghtact nt $ fromList [0.8,0.9]
- print $ backprop nt (fromList [0.8,0.9]) (fromList [1])
- print $ train 0.3 nt (fromList [0.8,0.9]) (fromList [1])
- print $ trainBatch 0.15 nt [fromList [0.8,0.9],fromList [0.8,0.9]] [fromList [1],fromList [1]]
+ where chk y1 y2 = val y1 == val y2
+ val x = snd . last . sort $ zip (toList x) [0..9]
+ done = putStrLn "...[\ESC[32m\STXDone\ESC[m\STX]"
+ str v = putStr $ take 20 (v++repeat '.')
+ tst smpl n = "... \ESC[32m\STX" ++ show (fromIntegral (testing chk n smpl) / 100) ++ "\ESC[m\STX%"
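The chk/val pair at the bottom implements the usual argmax readout of the 10-dimensional output layer: val tags each activation with its digit, sorts the pairs (by activation first, lexicographically), and keeps the digit of the largest one. The same idea on a plain list, as a standalone sketch (predictedDigit is an illustrative name, not from this patch):

    import Data.List (sort)

    -- Digit (0..9) whose output activation is largest.
    predictedDigit :: [Double] -> Int
    predictedDigit xs = snd . last . sort $ zip xs [0..9]

    -- e.g. predictedDigit [0.1,0.05,0.9,0.2,0,0,0,0,0,0] == 2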