summaryrefslogtreecommitdiff
path: root/mnist/Neuronet.hs
diff options
context:
space:
mode:
authorMiguel <m.i@gmx.at>2019-03-23 13:15:09 +0100
committerMiguel <m.i@gmx.at>2019-03-23 13:15:09 +0100
commit1afb966ff3f995b4ac08b9ad30a77caee85721fd (patch)
tree93f968fb0868ed652d45dd2d1274c0cd15885a45 /mnist/Neuronet.hs
parent8281304e3a7bea0cb1678f899e371f8d4776f34f (diff)
more cleaning
Diffstat (limited to 'mnist/Neuronet.hs')
-rw-r--r--mnist/Neuronet.hs20
1 files changed, 6 insertions, 14 deletions
diff --git a/mnist/Neuronet.hs b/mnist/Neuronet.hs
index 6c3ea32..8a77622 100644
--- a/mnist/Neuronet.hs
+++ b/mnist/Neuronet.hs
@@ -1,12 +1,8 @@
module Neuronet
( Neuronet -- the neuronet
- ,neuronet -- initalize neuronet
- ,train -- train with one sample
- ,trainBatch -- train with batch
+ ,newnet -- initalize neuronet
+ ,train -- train with batch
,asknet -- ask the neuroal net
-
- ,wghtact
- ,backprop
)where
import Data.List
@@ -23,8 +19,8 @@ type Neuronet = [Layer]
-- | Initialize a fresh neuronal network given the number of neurons on
-- each layer, as a list. Weights and biases are initialized randomly
-- using gaussian distribution with mean 0 and standard deviation 1.
-neuronet :: [Int] -> IO Neuronet
-neuronet l = mapM nl $ zip l (tail l)
+newnet :: [Int] -> IO Neuronet
+newnet l = mapM nl $ zip l (tail l)
where nl (i,l) = (,) <$> randn l i <*>
(randn 1 l >>= return.fromList.head.toLists)
@@ -62,13 +58,9 @@ sigmoid' x = sigmoid x * (1-sigmoid x)
cost_derivative :: Vector Double -> Vector Double -> Vector Double
cost_derivative a y = a-y
--- | Train on one single sample
-train :: Double -> Neuronet -> Vector Double -> Vector Double -> Neuronet
-train r net x y = zipWith (upd r) net (backprop net x y)
-
-- | Train on a batch of samples
-trainBatch :: Double -> Neuronet -> [Vector Double] -> [Vector Double] -> Neuronet
-trainBatch r net xs ys = zipWith (upd r) net bp
+train :: Double -> Neuronet -> [Vector Double] -> [Vector Double] -> Neuronet
+train r net xs ys = zipWith (upd r) net bp
where bp = foldl1' fc $ map (uncurry $ backprop net) (zip xs ys)
fc v a = zipWith ff v a
ff (a,b) (c,d) = (a+c,b+d)