summaryrefslogtreecommitdiff
path: root/mnist/main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'mnist/main.hs')
-rw-r--r--  mnist/main.hs  |  42
1 file changed, 17 insertions(+), 25 deletions(-)
diff --git a/mnist/main.hs b/mnist/main.hs
index 02ba6f7..d1b5471 100644
--- a/mnist/main.hs
+++ b/mnist/main.hs
@@ -8,6 +8,7 @@ import Data.Tuple.Extra
import System.Random.Shuffle
import qualified Data.ByteString as BS
import System.IO
+import Control.DeepSeq
-- a single data-sample with input and expected output
type Sample = (Vector Double,Vector Double)
@@ -29,7 +30,7 @@ epochs e n s = (:) <$> batches n s <*> epochs (e-1) n s
training :: Bool -> (Neuronet->String) -> Double -> Neuronet -> [[Samples]] -> IO Neuronet
training tst tstf r net s = foldlM f net (zip s [1..])
where f nt (v,i) = do putStr $ "Epoch "++ show i ++ "...."
- let n = foldl' (\n x->trainBatch r n (fst.unzip $ x) (snd.unzip $ x)) nt v
+ let n = foldl' (\n x->train r n (fst.unzip $ x) (snd.unzip $ x)) nt v
putStrLn $ tstf n
return n
@@ -37,11 +38,6 @@ training tst tstf r net s = foldlM f net (zip s [1..])
testing :: (Vector Double -> Vector Double -> Bool) -> Neuronet -> Samples -> Int
testing f net s = length . filter id $ map (\(x,y)->f y (asknet net x)) s
--- finally some learning and testing with MNIST
--- MNIST files from http://yann.lecun.com/exdb/mnist/
-main :: IO ()
-main = mainMNIST
-
-- create Samples given two MNIST files for images and labels
read_samples :: FilePath -> FilePath -> IO Samples
read_samples f1 f2 = do
@@ -55,29 +51,25 @@ read_samples f1 f2 = do
where zrs= take 9 $ repeat 0
val x= take x zrs ++ [1] ++ drop x zrs
--- MNIST main function
-mainMNIST :: IO ()
-mainMNIST =do
-
- let ep = 10 -- number of epochs
- let mbs = 10 -- mini-batch size
- let lr = 3 -- learning rate
- let cap = 999999 -- cap number of training samples
+-- finally some learning and testing with MNIST
+-- MNIST files from http://yann.lecun.com/exdb/mnist/
+main = do s <- read_samples "train-images-idx3-ubyte" "train-labels-idx1-ubyte"
+ print $ s `deepseq` length s
- putStrLn "= Init ="
- str "Initializing Net"
- nt <- neuronet [28*28,30,10]
- done
+main2 :: IO ()
+main2 = do
- str "Reading Samples"
+ let ep = 10 -- number of epochs
+ let mbs = 10 -- mini-batch size
+ let lr = 3 -- learning rate
+ let lay = [28*28,30,10] -- number of neurons by layer
+ let cap = 999999 -- cap number of training samples
+
+ nt <- newnet lay
smpl_train <- take cap <$> read_samples "train-images-idx3-ubyte" "train-labels-idx1-ubyte"
smpl_test <- read_samples "t10k-images-idx3-ubyte" "t10k-labels-idx1-ubyte"
- done
-
- putStrLn "= Training ="
- tr <- epochs ep mbs smpl_train >>= training True (tst smpl_test) (lr/fromIntegral mbs) nt
-
- putStrLn "= THE END ="
+ tr <- epochs ep mbs smpl_train >>= training True (tst smpl_test) (lr/fromIntegral mbs) nt
+ putStrLn "end"
where chk y1 y2 = val y1 == val y2
val x=snd . last . sort $ zip (toList x) [0..9]