Recognising digits
Having recently finished the excellent book Machine Learning Projects for .NET Developers by Mathias Brandewinder, I took the closing advice and enrolled in the Stanford Machine Learning course by the awesome Andrew Ng.
As part of this course we look at automatically recognising handwritten digits, a topic also covered in Mathias' book.
Performance
In the first chapter of the book we use the Manhattan distance to compare a target image to every single training sample, one pixel at a time, picking the closest match. This understandably doesn't scale well to large training sets and images.
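For context, that brute-force approach looks something like the sketch below. This is a minimal illustration rather than the book's actual code - it assumes each training example carries its label and a float32[] of pixel values, and the type and function names are my own.
// Minimal sketch of the chapter-one approach: nearest neighbour using the
// Manhattan (L1) distance, comparing the target to every training sample
// one pixel at a time and picking the closest match.
type Example = { Label : int; Pixels : float32[] }

// Sum of absolute pixel differences between two images
let manhattanDistance (a : float32[]) (b : float32[]) =
    Array.fold2 (fun acc x y -> acc + abs (x - y)) 0.0f a b

// Predict by taking the label of the closest training example
let classify (trainingSet : Example list) (target : float32[]) =
    trainingSet
    |> List.minBy (fun example -> manhattanDistance example.Pixels target)
    |> fun closest -> closest.Label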
For that reason, in both the final chapter of the book and the online course we are taught to use logistic regression to train a model which can then make predictions for us very quickly.
If you would like an introduction to the concept of regression, you can check out the example .NET Interactive notebook which accompanies my previous blog post. I have also created a notebook version of this blog post for you to experiment with, which you can find in the same repo.
When learning about logistic and linear regression we implemented the gradient descent algorithms and cost functions ourselves.
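To give a flavour of what that hands-on work looks like, here is a simplified sketch of a single batch gradient descent update for logistic regression. This is my own illustration in F# rather than the course's Octave code, and the names (predictProb, gradientDescentStep, alpha and so on) are mine.
// Simplified sketch of one batch gradient descent step for logistic regression
let sigmoid z = 1.0 / (1.0 + exp (-z))

// Hypothesis: predicted probability for one sample x given parameters theta
let predictProb (theta : float[]) (x : float[]) =
    Array.fold2 (fun acc t xi -> acc + t * xi) 0.0 theta x |> sigmoid

// One update of every parameter over the whole training set
// (alpha is the learning rate, xs the samples, ys the 0/1 labels)
let gradientDescentStep alpha (xs : float[][]) (ys : float[]) (theta : float[]) =
    let m = float xs.Length
    theta
    |> Array.mapi (fun j tj ->
        let gradient =
            Array.fold2 (fun acc x y -> acc + (predictProb theta x - y) * x.[j]) 0.0 xs ys
        tj - alpha * gradient / m)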
Whilst writing the algorithms manually is a great way to gain a deep understanding, it often isn't practical for day-to-day work. These are solved problems, and we usually want to spend our time tackling our specific domain issues. Our implementations are also likely to be suboptimal, written as they are for education rather than efficiency.
For this reason we are encouraged to explore third-party machine learning libraries which come with many of the puzzle pieces preassembled for us. We just have to configure and combine them into a useful shape for our particular problem.
ML libraries
- The book uses Accord.Net. Whilst this is a rich and popular library, the author has declared it archived.
- The course is taught using Octave / MATLAB, but I wanted to explore ML in F#.
I considered two options at this point - going full Microsoft with ML.NET or stepping towards the huge Python ML community with the SciSharp stack.
Today I will demonstrate the implementation using ML.NET. I intend to repeat the process with the SciSharp stack and compare the two in a future blog, so stay tuned for that one.
1. Import libraries
#r "nuget: Microsoft.ML"
open Microsoft.ML
open Microsoft.ML.Data
open Microsoft.ML.Transforms
open Microsoft.ML.Trainers
[<Literal>]
let TRAIN_PATH = @"C:\trainingsample.csv"
[<Literal>]
let TEST_PATH = @"C:\validationsample.csv"
2. Import data
ML.NET requires you to define data types which map to your input data set and to the predictions you want out of the model.
Each row of the CSV holds the digit's label followed by its 784 pixel values (a flattened 28x28 image), so we can annotate our fields with column numbers and then load directly from the file.
[<CLIMutable>]
type Digit = {
[<LoadColumn(0)>] Number : float32
[<LoadColumn(1, 784)>] [<VectorType(784)>] PixelValues : float32[]
}
[<CLIMutable>]
type DigitPrediction = {
PredictedLabel : uint
Label : uint
Score : float32 []
}
let context = new MLContext()
let trainData = context.Data.LoadFromTextFile<Digit>(TRAIN_PATH, hasHeader = true, separatorChar = ',')
let testData = context.Data.LoadFromTextFile<Digit>(TEST_PATH, hasHeader = true, separatorChar = ',')
3. Prepare data
By default, the library expects your input fields to be named Label and Features. You can, however, either change the defaults or map the names in your pipeline. We will do the latter.
let labelMap =
context.Transforms.Conversion.MapValueToKey(
"Label",
"Number",
keyOrdinality = ValueToKeyMappingEstimator.KeyOrdinality.ByValue
)
let featureMap =
context.Transforms.Concatenate(
"Features",
"PixelValues"
)
4. Configure logistic regression
Binary classifiers, as you might expect, only allow us to classify something into one of two groups - we get a value in the range 0-1, where a value less than 0.5 indicates group 1, and a value over 0.5 indicates group 2.
In our case, we have 10 groups, one for each digit. We can either compare each digit to each other digit in turn (1 vs 2, 1 vs 3, 1 vs 4 etc., known as one vs one) or compare each digit to all of the other digits combined (1 vs not 1, 2 vs not 2 etc., known as one vs all).
We are going to use the latter.
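The intuition behind one vs all is that we train ten binary models, ask each one how confident it is that the image is its digit, and take the most confident answer. ML.NET's OneVersusAll trainer does this for us below, but a hand-rolled version of that final step might look something like this (purely illustrative):
// Illustrative only: given one "this digit vs everything else" score per
// digit 0-9, one vs all predicts the digit with the most confident model.
let predictOneVsAll (scoresPerDigit : float32[]) =
    scoresPerDigit
    |> Array.mapi (fun digit score -> digit, score)
    |> Array.maxBy snd
    |> fst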
Our next decision is which optimiser we want to use to tune our parameters in this one vs all analysis. In our case we are going to use stochastic gradient descent.
Our final decision is which cost function to choose in order to score our predictions. We will pick the log loss function (sketched briefly after the configuration below).
let lambda = 3.f // over-fitting penalty
let costFunction = LogLoss()
let gradientDescent =
context.BinaryClassification.Trainers.SgdNonCalibrated(
lossFunction = costFunction,
l2Regularization = lambda
)
let oneVsAll = context.MulticlassClassification.Trainers.OneVersusAll gradientDescent
let pipeline =
EstimatorChain()
.Append(labelMap)
.Append(featureMap)
.AppendCacheCheckpoint(context) // cache data to speed up training
.Append(oneVsAll)
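To get a feel for what the log loss cost function we just plugged in is measuring, here is a small sketch for a single binary prediction. This is my own illustration rather than ML.NET's implementation; predicted is the probability the model assigns to the positive class and actual is the true 0/1 label.
// Sketch of log loss for a single prediction: confident, correct predictions
// cost almost nothing, while confident, wrong predictions are punished heavily.
let logLoss (actual : float) (predicted : float) =
    -(actual * log predicted + (1.0 - actual) * log (1.0 - predicted))

// e.g. logLoss 1.0 0.9 is roughly 0.105 (good); logLoss 1.0 0.1 is roughly 2.303 (bad)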
5. Training and testing a model
Training a model is now as simple as
let model = trainData |> pipeline.Fit
We can now feed in our test data and the model will make predictions. These can then be evaluated in order to obtain metrics describing how accurate they are.
let transformedTestData =
testData
|> model.Transform
let metrics =
transformedTestData
|> context.MulticlassClassification.Evaluate
printfn "Evaluation metrics"
printfn " MicroAccuracy: %f" metrics.MicroAccuracy
printfn " MacroAccuracy: %f" metrics.MacroAccuracy
printfn " LogLoss: %f" metrics.LogLoss
printfn " LogLossReduction: %f" metrics.LogLossReduction
Evaluation metrics
MicroAccuracy: 0.888000
MacroAccuracy: 0.885613
LogLoss: 0.436994
LogLossReduction: 0.809586
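MicroAccuracy treats every sample equally, whereas MacroAccuracy averages the per-class accuracies, so it is dragged down by digits the model handles badly. A rough sketch of the difference (my own illustration, not how ML.NET computes the metrics internally):
// Each tuple is (truth, prediction)
let microAccuracy (results : (int * int) list) =
    let correct = results |> List.filter (fun (truth, pred) -> truth = pred)
    float correct.Length / float results.Length

let macroAccuracy (results : (int * int) list) =
    results
    |> List.groupBy fst   // group by the true class
    |> List.map (fun (_, group) ->
        let correct = group |> List.filter (fun (truth, pred) -> truth = pred)
        float correct.Length / float group.Length)
    |> List.average       // average the per-class accuracies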
We can even see the results in a nicely formatted confusion table, where each row shows the recall for a true digit (the proportion of that digit which was correctly identified) and each column shows the precision for a predicted digit (the proportion of those predictions which were correct).
metrics.ConfusionMatrix.GetFormattedConfusionTable()
Confusion table
||================================================================================
PREDICTED || 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | Recall
TRUTH ||================================================================================
0 || 42 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1.0000
1 || 0 | 51 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0.9808
2 || 0 | 1 | 53 | 1 | 0 | 0 | 0 | 3 | 1 | 1 | 0.8833
3 || 0 | 0 | 0 | 36 | 0 | 4 | 0 | 0 | 2 | 1 | 0.8372
4 || 0 | 0 | 0 | 0 | 52 | 1 | 0 | 0 | 0 | 2 | 0.9455
5 || 0 | 0 | 0 | 2 | 0 | 31 | 3 | 1 | 2 | 2 | 0.7561
6 || 0 | 0 | 0 | 0 | 1 | 4 | 44 | 0 | 0 | 0 | 0.8980
7 || 0 | 0 | 2 | 0 | 1 | 0 | 0 | 49 | 0 | 3 | 0.8909
8 || 0 | 2 | 1 | 1 | 0 | 2 | 0 | 2 | 38 | 2 | 0.7917
9 || 0 | 0 | 0 | 0 | 1 | 2 | 0 | 2 | 2 | 48 | 0.8727
||================================================================================
Precision ||1.0000 |0.9444 |0.9298 |0.9000 |0.9455 |0.7045 |0.9362 |0.8596 |0.8444 |0.8136 |
6. Making predictions
The final step is to create a prediction engine. This will allow us to take any data sample and predict its class in a fast and efficient manner.
let engine = context.Model.CreatePredictionEngine<Digit, DigitPrediction>(model)
let predict digit = engine.Predict digit
Let's pick some random examples from our test set and try it out
let randoms =
let r = Random()
[ for i in 0..4 do r.Next(0, 499) ]
let digits =
    context.Data.CreateEnumerable<Digit>(
        testData,
        reuseRowObject = false
    ) |> Array.ofSeq
let testDigits = [
digits.[randoms.[0]]
digits.[randoms.[1]]
digits.[randoms.[2]]
digits.[randoms.[3]]
digits.[randoms.[4]]
]
printf " #\t\t"; [0..9] |> Seq.iter(fun i -> printf "%i\t\t" i); printfn ""
testDigits
|> Seq.iter(
fun digit ->
printf " %i\t" (int digit.Number)
let p = predict digit
p.Score |> Seq.iter (fun s -> printf "%f\t" s)
printfn "")
# 0 1 2 3 4 5 6 7 8 9
6 0.001066 0.000000 0.000000 0.000000 0.000000 0.000023 0.998911 0.000000 0.000000 0.000000
7 0.000006 0.002744 0.024458 0.000020 0.000003 0.002633 0.000076 0.316164 0.593821 0.060076
9 0.000000 0.000358 0.000000 0.038600 0.010955 0.067991 0.000004 0.000143 0.009292 0.872657
1 0.000002 0.915526 0.003066 0.002941 0.000001 0.000020 0.000135 0.000001 0.078100 0.000209
6 0.000003 0.000000 0.000009 0.000000 0.000000 0.000000 0.999987 0.000000 0.000000 0.000000
Not bad, although our 7 was confused with an 8.
Conclusion
It's a funny one, this. On the one hand, it is great that Microsoft have embraced ML and a lot of work has clearly been put into the libraries and documentation. I found it quite easy to use once I got my head around the methodology.
On the other hand, as an F# developer used to pipelining data in a simple, intuitive and immutable way, it is hard to fall in love with the ML.NET .Append syntax, which mutates an object. Creating mutable annotated types is also less than ideal. The whole thing feels very OO-oriented, and I get the feeling something designed with F# in mind would be a much nicer experience.
I am excited to compare and contrast this process using the SciSharp stack, as I am sure it will be quite a different experience. I hope you join me next time!