Here’s Philip Brown, an Elixir engineer at Yellow Flag, with a tutorial on building a full-stack machine learning application using Nx, Axon, and LiveView. Coincidentally, Fly.io is the perfect place to run your LiveView apps. Get started.
(Updated March 2023 by Tom Berman for the official Axon releases.)
Machine learning allows you to solve problems that were once totally unimaginable. The ability for a computer to take an image and tell you what it sees was once only possible in science fiction.
Now, it’s possible to build machine learning models that can do amazing things. However, part of the challenge of machine learning is that there are a lot of moving parts to learn. This means that solving a problem with machine learning can be a difficult task for an individual engineer.
One of the big advantages Elixir has over similar programming languages is the integrated nature of what you have available to you. You can do a lot in Elixir without ever leaving the comfort of the language you love.
What are we going to build?
In this tutorial we’re going to look at building out an end-to-end machine learning project using only Elixir. Boom! As if that wasn’t enough, we’re going to build a machine learning model that can recognize a handwritten digit. We’ll train the model so that it will predict the digit from an image. We’ll also build an application that can accept new handwritten digits from the user, and then display the prediction.
HOT TIP! Be sure to grab the full code for this tutorial here!
Here’s a preview of what it looks like:
Let’s get started!
Setting up the project
We’re going to build this project using Phoenix, so the first thing we need is to create a new Phoenix project.
If you don’t already have Elixir installed on your computer, you can find instructions for your operating system on the Elixir Website.
Once you have Elixir installed, you can install Phoenix by running the following command in a terminal:
$ mix archive.install hex phx_new
With Elixir and Phoenix installed, we can create a new Phoenix project:
$ mix phx.new digits --no-ecto
I’m including the --no-ecto flag because we don’t need a database for this project. This command should prompt you to install the project’s dependencies. Hit Y on that prompt and wait for the dependencies to be installed.
Once the dependencies are installed, follow the onscreen instructions to run your new Phoenix application and verify that everything was set up correctly.
I’m also going to add the Tailwind package for styling the application. If you want to add Tailwind to your project, add the following to the list of dependencies in your mix.exs file:
{:tailwind, "~> 0.1", runtime: Mix.env() == :dev}
Then follow the configuration instructions listed here.
Where will we get our training data?
One of the most important aspects of machine learning is having good, quality data to train on. When working on real life machine learning projects, expect to spend the majority of your time on the data.
Fortunately for us, there is already a ready-made dataset we can use. The MNIST Database is a large dataset of handwritten digits from 0 - 9 that have already been prepared and labeled, and it is commonly used for training image recognition machine learning models.
Prepare the project for machine learning
Next, we need to set up the machine learning model. The Elixir ecosystem has a number of exciting packages that can be used for training machine learning models.
The Nx package is the foundation of machine learning in Elixir. Nx allows us to manipulate our data using tensors, which are essentially efficient multi-dimensional arrays. When we say “tensor” below, just think “multi-dimensional array”.
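If tensors are new to you, here’s a quick throwaway example (not part of the project) you can try in iex once the dependencies below are installed. It creates a small tensor and reshapes it:

Nx.tensor([0, 1, 2, 3, 4, 5])
|> Nx.reshape({2, 3})

This returns a 2x3 tensor:

#Nx.Tensor<
  s64[2][3]
  [
    [0, 1, 2],
    [3, 4, 5]
  ]
>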
Next, we have EXLA, which provides hardware acceleration for training our models. Crunching the numbers of machine learning is a very intensive process, but EXLA makes that much faster.
Axon builds on top of Nx and makes it possible for us to create neural networks in Elixir.
Finally we have the Scidata package, which provides conveniences for working with machine learning datasets, including MNIST.
So, the first thing we need to do is to add those dependencies to our mix.exs file:
{:axon, "~> 0.5.1"},
{:exla, "~> 0.5.1"},
{:nx, "~> 0.5.1"},
{:scidata, "~> 0.1.5"}
Then we can install our new dependencies from a terminal:
$ mix deps.get
We also need to set the default backend for Nx in config.exs:
import Config
# Set the backend for Nx
config :nx, :default_backend, EXLA.Backend
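To confirm the backend took effect, you can run the following in iex once the dependencies are compiled; it should report EXLA.Backend:

Nx.default_backend()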
Working with our training data
There’s a couple of steps required for getting and transforming the training data, so we’ll start building out a module that can encapsulate everything that we’re building:
defmodule Digits.Model do
@moduledoc """
The Digits Machine Learning model
"""
end
First up we’ll add a download/0 function that downloads the training data for us. We’re just delegating to the Scidata package for that.
def download do
Scidata.MNIST.download()
end
This function returns a tuple of {images, labels}. However, we want to transform the images and labels so we can use them in our model.
First, we’ll use the following function to transform the images:
def transform_images({binary, type, shape}) do
binary
|> Nx.from_binary(type)
|> Nx.reshape(shape)
|> Nx.divide(255)
end
The image data from the download includes the following:
- Binary data - this is the image data as a binary.
- The type of the data - in this example the type is {:u, 8}, an unsigned integer.
- The shape of the data - in this example the shape is {60000, 1, 28, 28}. This means there are 60000 images, which all have 1 channel (i.e. they’re black and white) and have a dimension of 28x28.
We can convert the binary into a tensor using Nx.
If we open up iex we can visualize the image data. Run the following command in a terminal to open up iex with our project loaded:
$ iex -S mix
Next, we run the following code:
{images, labels} = Digits.Model.download()
images
|> Digits.Model.transform_images()
|> Nx.slice_along_axis(0, 1, axis: 0)
|> Nx.reshape({1, 1, 28, 28})
|> Nx.to_heatmap()
You should see the first handwritten digit of the dataset. This is what it looks like:
We can see the corresponding label for the image too. Let’s see how to do that.
First, we pattern match the binary data and type from the downloaded label data.
{binary, type, _} = labels
Then we convert the binary to a tensor and “slice” off the first item as our example.
binary
|> Nx.from_binary(type)
|> Nx.new_axis(-1)
|> Nx.slice_along_axis(0, 1, axis: 0)
The first label should be a 5. We’ll refactor that code into our transform function for the labels:
def transform_labels({binary, type, _}) do
binary
|> Nx.from_binary(type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
end
The labels of the training data are used as targets for the model’s predictions. For each image, we know how it was labelled. During training, the model uses the labels to compare its predictions with the actual correct result, and the model is adjusted to give better results in the future.
Currently, the labels are integers from 0 - 9. You can think of them as 10 different categories. In our case, the categories are integers, but when training a machine learning model, you might have categories such as colors, sizes, types of animals, etc.
So we need to convert our categories into something that the machine learning model can understand. The way we do this is to convert the label into a tensor of size {1, 10}, where 10 is the number of categories you have.
For example:
#Nx.Tensor<
u8[1][10]
[
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
]
>
In this example, the long list of numbers has a 1 in the first position. This represents the first category. In our case, that is the number “0”, but it could also be the color “red”, the size “small”, or the type of animal “dog”.
The second category would be:
#Nx.Tensor<
u8[1][10]
[
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
]
>
And so on.
This process is called one-hot encoding.
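The Nx.equal/2 call in transform_labels/1 does the heavy lifting here: comparing a column of labels against the row tensor of 0..9 broadcasts the comparison, producing one row per label with a 1 in the matching position. Here’s a small sketch you can try in iex with a single made-up label of 3:

Nx.tensor([[3]])
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))

#Nx.Tensor<
  u8[1][10]
  [
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
  ]
>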
You can see what the first label of the training data is when it’s been one-hot encoded using the following chunk of code (still in iex):
labels
|> Digits.Model.transform_labels()
|> Nx.slice_along_axis(0, 1, axis: 0)
This should output the following tensor:
#Nx.Tensor<
u8[1][10]
[
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
]
>
Remember, we’re working with the number 5 right now. This tensor is an array of zeros with a 1 in the index for 5. Counting from 0, the 1 is in the fifth position.
Next, we convert the images and labels into batches. During training, we feed the data into the model in batches rather than all at once. In this example we’re using a batch size of 32, which means each batch will include 32 examples.
batch_size = 32
images =
images
|> Digits.Model.transform_images()
|> Nx.to_batched(batch_size)
|> Enum.to_list()
labels =
labels
|> Digits.Model.transform_labels()
|> Nx.to_batched(batch_size)
|> Enum.to_list()
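As a quick sanity check (optional, still in iex), you can look at the shape of the first batch. Given the shapes above, the first image batch should be {32, 1, 28, 28} and the first label batch {32, 10}:

images |> hd() |> Nx.shape()
labels |> hd() |> Nx.shape()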
Next, we zip the images and labels together using Enum.zip. Then we split the dataset into training, testing, and validation datasets. We use the majority of the data for training and keep a portion back to test the accuracy of the model. In this example we’re using 80% of the data for training and validation, and the remaining 20% of unseen data will be used for testing.
data = Enum.zip(images, labels)
training_count = floor(0.8 * Enum.count(data))
validation_count = floor(0.2 * training_count)
{training_data, test_data} = Enum.split(data, training_count)
{validation_data, training_data} = Enum.split(training_data, validation_count)
Phew! That may seem pretty heavy, but we’ve already achieved a lot! We’ve downloaded our training data, preprocessed it, and got it ready for building the model. During a real-life machine learning project you will likely spend a lot of time acquiring, cleaning, and manipulating the data. We’re now in a great position to build and train the model!
Building the model
Next up we’ll use Axon to build the machine learning model. Add a new function to the Digits.Model module with the following code:
def new({channels, height, width}) do
Axon.input("input_0", shape: {nil, channels, height, width})
|> Axon.flatten()
|> Axon.dense(128, activation: :relu)
|> Axon.dense(10, activation: :softmax)
end
First we set the input shape of the model to fit our training data. Next we flatten the input and add a dense layer that uses relu as its activation function. Finally, the output layer returns one of 10 labels (because our digits are 0 - 9).
You can experiment with different model configurations to get different results.
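For example, a slightly larger variant of the model might add dropout and a second dense layer to help guard against over-fitting. This is just a sketch of the kind of change you could try; the rest of the tutorial doesn’t depend on it:

def new({channels, height, width}) do
  Axon.input("input_0", shape: {nil, channels, height, width})
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dropout(rate: 0.25)
  |> Axon.dense(64, activation: :relu)
  |> Axon.dense(10, activation: :softmax)
end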
Training the model
Now that we have the data and the model, we can start training. Add another function to Digits.Model to train the model:
def train(model, training_data, validation_data) do
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.01))
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.run(training_data, %{}, compiler: EXLA, epochs: 10)
end
We’re using categorical cross entropy because we’re classifying each image into one of several categories, and the “adam” optimizer because it gives fairly good results. We’ll track a single accuracy metric, and we’ll also validate the model with our validation data from earlier to ensure the model is not over-fitting on the training data.
Finally we’ll use EXLA as the compiler and we’ll train for 10 epochs. An epoch is one cycle through the data, so this means we’ll cycle through the data 10 times during training.
Testing our model
We can also test our model after training to get an idea of how well it performs. Add the following function to Digits.Model:
def test(model, state, test_data) do
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.run(test_data, state)
end
This tests the model using previously unseen data to check the accuracy of the predictions.
Saving and loading our model
The final thing to do is to add the ability to save and load the model. Axon can serialize the model and its trained state into a binary, which we can write to disk and read back later:
def save!(model, state) do
contents = Axon.serialize(model, state)
File.write!(path(), contents)
end
def load! do
path()
|> File.read!()
|> Axon.deserialize()
end
def path do
Path.join(Application.app_dir(:digits, "priv"), "model.axon")
end
Running the model
Now that we’ve written all the code to transform the data, train, and test our machine learning model, we’ll write a mix command to put it all together:
NOTE: The mix task locally caches the downloaded MNIST dataset so it doesn’t download it every time it’s run.
defmodule Mix.Tasks.Train do
use Mix.Task
@requirements ["app.start"]
alias Digits
def run(_) do
{images, labels} = load_mnist()
images =
images
|> Digits.Model.transform_images()
|> Nx.to_batched(32)
|> Enum.to_list()
labels =
labels
|> Digits.Model.transform_labels()
|> Nx.to_batched(32)
|> Enum.to_list()
data = Enum.zip(images, labels)
training_count = floor(0.8 * Enum.count(data))
validation_count = floor(0.2 * training_count)
{training_data, test_data} = Enum.split(data, training_count)
{validation_data, training_data} = Enum.split(training_data, validation_count)
model = Digits.Model.new({1, 28, 28})
Mix.Shell.IO.info("training...")
state = Digits.Model.train(model, training_data, validation_data)
Mix.Shell.IO.info("testing...")
Digits.Model.test(model, state, test_data)
Digits.Model.save!(model, state)
:ok
end
defp load_mnist() do
if !File.exists?(path()) do
save_mnist()
end
load!()
end
defp save_mnist do
Digits.Model.download()
|> save!()
end
defp save!(data) do
contents = :erlang.term_to_binary(data)
File.write!(path(), contents)
end
defp load! do
path()
|> File.read!()
|> :erlang.binary_to_term()
end
defp path do
Path.join(Application.app_dir(:digits, "priv"), "mnist.axon")
end
end
We can run the training with the following command:
mix train
Setting up the LiveView
Now that we have a trained machine learning model, we can set up a LiveView to accept new handwritten digits, and then display the predicted results.
First, we add a new live route to our router file in lib/digits_web/router.ex:
scope "/", DigitsWeb do
pipe_through :browser
live "/", PageLive, :index
end
Next, we create a new file under lib/digits_web/live called page_live.ex. This is our LiveView module where all the interactivity happens:
defmodule DigitsWeb.PageLive do
@moduledoc """
PageLive LiveView
"""
use DigitsWeb, :live_view
end
When a user submits a new handwritten digit, the machine learning model makes a prediction on what digit was written and then the LiveView displays the prediction to the user. However, when the LiveView is first loaded, there isn’t a prediction to display. So, first, we need to initialize the prediction assign to nil inside the mount/3 callback:
def mount(_params, _session, socket) do
{:ok, assign(socket, %{prediction: nil})}
end
Next, the render/1 function is responsible for rendering the LiveView:
def render(assigns) do
~H"""
<div id="wrapper" phx-update="ignore">
<div id="canvas" phx-hook="Draw"></div>
</div>
<div>
<button phx-click="reset">Reset</button>
<button phx-click="predict">Predict</button>
</div>
<%= if @prediction do %>
<div>
<div>
Prediction:
</div>
<div>
<%= @prediction %>
</div>
</div>
<% end %>
"""
end
Notice above that we have a div with the id of “canvas”. This will have an HTML canvas attached. The phx-hook uses Javascript to let us interact with the canvas. The canvas div is wrapped in another div with phx-update="ignore" because we don’t want Phoenix to update it.
Next are two buttons, one to reset the canvas and one to make a prediction from what the user drew. Each of these buttons is wired up to a phx-click trigger.
Finally, if we have a prediction, it is displayed.
Adding the canvas
Next, we need some input from the user. We could let the user upload images using Phoenix’s LiveView upload functionality, but a better (and way cooler) experience is to let the user draw new examples directly into the LiveView.
There’s a handy NPM package called draw-on-canvas that makes this part easy.
To install it, cd into the assets directory and run the following command in a terminal:
$ npm i draw-on-canvas
This installs the draw-on-canvas package into the project.
Now we connect the draw-on-canvas package to our LiveView via a hook. Open up assets/js/app.js and import the draw-on-canvas package:
import Draw from 'draw-on-canvas'
Let’s create a new Hooks object:
let Hooks = {}
Remember to register the hook object in the LiveSocket:
let liveSocket = new LiveSocket("/live", Socket, {
params: {_csrf_token: csrfToken},
hooks: Hooks
})
Next we add a new Draw hook:
Hooks.Draw = {}
We need to implement the mounted function, which is called when the hook is mounted. This is where we set up the canvas:
Hooks.Draw = {
mounted() {
this.draw = new Draw(this.el, 384, 384, {
backgroundColor: "black",
strokeColor: "white",
strokeWeight: 10
})
}
}
When we open the app in a browser, we should see a black square canvas that we can draw on!
Interacting with the canvas
Remember back in our PageLive module, we added two buttons for interacting with the canvas.
The first button is used to reset the canvas. When the button is pressed we send a message to the client to reset the canvas. The push_event function makes this easy. Our new “reset” event handler in PageLive looks like this:
def handle_event("reset", _params, socket) do
{:noreply,
socket
|> assign(prediction: nil)
|> push_event("reset", %{})
}
end
When the reset button is clicked, the phx-click trigger sends the reset event to the server. We then push an event called reset to the client. We also set the prediction to nil in the socket assigns.
On the Javascript side, we add a handleEvent that listens for the reset event and resets the canvas:
this.handleEvent("reset", () => {
this.draw.reset()
})
Next, let’s make our “predict” button work. We want to grab the contents of the canvas as an image. Again, we send a message to the client from the PageLive LiveView module:
def handle_event("predict", _params, socket) do
{:noreply, push_event(socket, "predict", %{})}
end
In the mounted callback, we add another handleEvent. This grabs the contents of the canvas as a data URL and sends it to the server using pushEvent:
this.handleEvent("predict", () => {
this.pushEvent("image", this.draw.canvas.toDataURL('image/png'))
})
Making predictions
Now that we hooked up the buttons to reset the canvas and send up the canvas contents to make a prediction, we will use the image from the canvas as a new input to our machine learning model.
We can accept the image data URL from the client using another handle_event/3 callback function:
def handle_event("image", "data:image/png;base64," <> raw, socket) do
name = Base.url_encode64(:crypto.strong_rand_bytes(10), padding: false)
path = Path.join(System.tmp_dir!(), "#{name}.png")
File.write!(path, Base.decode64!(raw))
prediction = Digits.Model.predict(path)
File.rm!(path)
{:noreply, assign(socket, prediction: prediction)}
end
In this function, we use binary pattern matching on the params to get the image data. Next, we generate a random file name and create a path to a temporary directory for storing the image. Then we decode the image data and write it to the path.
Next we pass the path into the Digits.Model.predict/1 function and get back a prediction. The prediction result is a number between 0 and 9. We’ll write that function next.
Finally, we delete the image file and assign the prediction to the socket for display in our LiveView.
Before we can use the user’s drawing with our model, we need to prepare the image. We need to:
- Convert it to grayscale to reduce the number of channels from 3 to 1
- Resize it to 28 x 28
The Evision library can do these changes for us. Let’s add it as a dependency in our mix.exs file now:
{:evision, "~> 0.1.28"}
Install the dependency using:
mix deps.get
In the Digits.Model module, let’s add a new function for making a prediction:
def predict(path) do
{:ok, mat} = Evision.imread(path, flags: Evision.Constant.cv_IMREAD_GRAYSCALE())
{:ok, mat} = Evision.resize(mat, [28, 28])
data =
Evision.Nx.to_nx(mat)
|> Nx.reshape({1, 28, 28})
|> List.wrap()
|> Nx.stack()
|> Nx.backend_transfer()
{model, state} = load!()
model
|> Axon.predict(state, data)
|> Nx.argmax()
|> Nx.to_number()
end
First, we read the image path and convert it to grayscale. This reduces the number of channels from 3 to 1. Then we resize the image to 28 x 28.
We also need to convert the image data to an Nx tensor and reshape it to the shape the model expects. Our machine learning model expects a “batch” of inputs, so we wrap the tensor using List.wrap/1 and then stack it using Nx.stack/1.
Next, we load the model and the state using the load!/0 function from earlier. Ideally you wouldn’t be loading the model and state for each prediction, but it’s fine for our basic example.
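If you did want to avoid re-reading the file on every prediction, one option (just a sketch, not something the rest of this tutorial relies on) is to load the model once and cache it, for example in :persistent_term:

# A hypothetical helper: load the serialized model once and cache it
# in :persistent_term so later predictions skip the disk read.
def cached_model do
  case :persistent_term.get({__MODULE__, :model}, nil) do
    nil ->
      model_and_state = load!()
      :persistent_term.put({__MODULE__, :model}, model_and_state)
      model_and_state

    model_and_state ->
      model_and_state
  end
end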
We pass the model, state, and data into the Axon.predict/4 function. One thing to note: you will need to add require Axon to the Digits.Model module because Axon.predict/4 is actually a macro.
The Axon.predict/4 function returns the prediction as a tensor of 10 values, one per digit. We use the Nx.argmax/1 function to get the index of the largest value, which is a single scalar between 0 and 9, and then we use Nx.to_number/1 to return the value as a plain number.
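To make that concrete, here’s a tiny sketch (with made-up values) of what argmax does to a prediction like ours:

Nx.tensor([[0.01, 0.02, 0.01, 0.0, 0.0, 0.9, 0.02, 0.02, 0.01, 0.01]])
|> Nx.argmax()
|> Nx.to_number()

This returns 5, the index of the largest value.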
Our predicted number is set as the prediction in the LiveView assigns, displaying it to the user.
We built an end-to-end machine learning application in Elixir!
Wow! Check out what we just did!
We built an end-to-end machine learning application using Elixir! We trained a model from scratch. We used LiveView for interactive, real-time application input from the user. We ran predictions and displayed the results interactively.
One of the most amazing things here was that we did it all using Elixir and didn’t need external machine learning tools or languages. Machine learning in Elixir is still maturing, but I hope this inspires you to try something new in your own project.
Full code for this tutorial is found at philipbrown/handwritten-digits-elixir.