This notebook trains a simple multilayer perceptron (MLP) classifier for handwritten digit recognition on the MNIST dataset.
from typing import Sequence
from flax import linen as nn
import jax
import jax.numpy as jnp
import optax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
# @markdown The learning rate for the optimizer:
LEARNING_RATE = 0.002 # @param{type:"number"}
# @markdown Number of samples in each batch:
BATCH_SIZE = 128 # @param{type:"integer"}
# @markdown Total number of epochs to train for:
N_EPOCHS = 1 # @param{type:"integer"}
MNIST is a dataset of 28x28 grayscale images (a single channel). We now load it with tensorflow_datasets, apply min-max normalization to the images, shuffle the training set, and create batches of size BATCH_SIZE.
(train_loader, test_loader), info = tfds.load(
    "mnist", split=["train", "test"], as_supervised=True, with_info=True
)
# Scale pixel values from [0, 255] to [0, 1].
min_max_rgb = lambda image, label: (tf.cast(image, tf.float32) / 255., label)
train_loader = train_loader.map(min_max_rgb)
test_loader = test_loader.map(min_max_rgb)
NUM_CLASSES = info.features["label"].num_classes
IMG_SIZE = info.features["image"].shape
train_loader_batched = train_loader.shuffle(
    buffer_size=10_000, reshuffle_each_iteration=True
).batch(BATCH_SIZE, drop_remainder=True)
test_loader_batched = test_loader.batch(BATCH_SIZE, drop_remainder=True)
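If you want to sanity-check the pipeline, the following optional snippet (not part of the original notebook) peeks at one batch to confirm the expected shapes and the [0, 1] pixel range after normalization.
images, labels = next(train_loader_batched.as_numpy_iterator())
print(images.shape, labels.shape)  # Expected: (128, 28, 28, 1) (128,)
print(images.min(), images.max())  # Pixel values should lie in [0.0, 1.0].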
The data is ready! Next let's define a model. Optax is agnostic to which (if any) neural network library is used. Here we use Flax to implement a simple MLP.
class MLP(nn.Module):
  """A simple multilayer perceptron model for image classification."""

  hidden_sizes: Sequence[int] = (1000, 1000)

  @nn.compact
  def __call__(self, x):
    # Flattens images in the batch.
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=self.hidden_sizes[0])(x)
    x = nn.relu(x)
    x = nn.Dense(features=self.hidden_sizes[1])(x)
    x = nn.relu(x)
    x = nn.Dense(features=NUM_CLASSES)(x)
    return x
net = MLP()
@jax.jit
def predict(params, inputs):
  # Returns per-class logits for a batch of images.
  return net.apply({"params": params}, inputs)
@jax.jit
def loss_accuracy(params, data):
  """Computes loss and accuracy over a mini-batch.

  Args:
    params: parameters of the model.
    data: tuple of (inputs, labels).

  Returns:
    loss: mean softmax cross-entropy loss over the mini-batch.
    aux: dict containing the mini-batch accuracy.
  """
  inputs, labels = data
  logits = predict(params, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=labels
  ).mean()
  accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
  return loss, {"accuracy": accuracy}
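As a quick, optional check of the loss definition (not part of the original notebook): with all-zero logits over the 10 MNIST classes, the softmax cross-entropy should be close to log(10) ≈ 2.30 regardless of the labels.
dummy_logits = jnp.zeros((4, NUM_CLASSES))
dummy_labels = jnp.array([0, 1, 2, 3])
# Uniform predictions give a loss of log(NUM_CLASSES) ≈ 2.3026.
print(optax.softmax_cross_entropy_with_integer_labels(dummy_logits, dummy_labels).mean())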
@jax.jit
def update_model(params, solver_state, grads):
  # Applies one optimizer update given precomputed gradients.
  updates, solver_state = solver.update(grads, solver_state, params)
  params = optax.apply_updates(params, updates)
  return params, solver_state
Next we need to initialize network parameters and solver state. We also define a convenience function dataset_stats that we'll call once per epoch to collect the loss and accuracy of our solver over the test set.
solver = optax.adam(LEARNING_RATE)
rng = jax.random.PRNGKey(0)
dummy_data = jnp.ones((1,) + IMG_SIZE, dtype=jnp.float32)
params = net.init({"params": rng}, dummy_data)["params"]
solver_state = solver.init(params)
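Before training, it can help to confirm that everything is wired up correctly. The following optional snippet (an assumption, not part of the original notebook) checks the logits shape on the dummy input and counts the trainable parameters.
dummy_logits = predict(params, dummy_data)
print(dummy_logits.shape)  # Expected: (1, 10), one logit per class.
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"Number of trainable parameters: {n_params:,}")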
def dataset_stats(params, data_loader):
  """Computes loss and accuracy over the dataset `data_loader`."""
  all_accuracy = []
  all_loss = []
  for batch in data_loader.as_numpy_iterator():
    batch_loss, batch_aux = loss_accuracy(params, batch)
    all_loss.append(batch_loss)
    all_accuracy.append(batch_aux["accuracy"])
  return {"loss": np.mean(all_loss), "accuracy": np.mean(all_accuracy)}
Finally, we do the actual training. The next cell trains the model for N_EPOCHS epochs. Within each epoch we iterate over the batched loader train_loader_batched, and once per epoch we also compute the test set loss and accuracy.
train_accuracy = []
train_losses = []
# Records test set loss and accuracy at initialization.
test_stats = dataset_stats(params, test_loader_batched)
test_accuracy = [test_stats["accuracy"]]
test_losses = [test_stats["loss"]]
@jax.jit
def train_step(params, solver_state, batch):
  # Performs a single optimization step on one mini-batch.
  (loss, aux), grad = jax.value_and_grad(loss_accuracy, has_aux=True)(
      params, batch
  )
  updates, solver_state = solver.update(grad, solver_state, params)
  params = optax.apply_updates(params, updates)
  return params, solver_state, loss, aux
for epoch in range(N_EPOCHS):
  train_accuracy_epoch = []
  train_losses_epoch = []

  for step, train_batch in enumerate(train_loader_batched.as_numpy_iterator()):
    params, solver_state, train_loss, train_aux = train_step(
        params, solver_state, train_batch
    )
    train_accuracy_epoch.append(train_aux["accuracy"])
    train_losses_epoch.append(train_loss)
    if step % 20 == 0:
      print(
          f"step {step}, train loss: {train_loss:.2e}, train accuracy:"
          f" {train_aux['accuracy']:.2f}"
      )

  # Once per epoch, evaluate on the test set and record epoch averages.
  test_stats = dataset_stats(params, test_loader_batched)
  test_accuracy.append(test_stats["accuracy"])
  test_losses.append(test_stats["loss"])
  train_accuracy.append(np.mean(train_accuracy_epoch))
  train_losses.append(np.mean(train_losses_epoch))
f"Improved accuracy on test DS from {test_accuracy[0]} to {test_accuracy[-1]}"