JAX can work with a variety of data loaders, including Grain, TensorFlow Datasets and TorchData, but for simplicity this example uses the well-known scikit-learn digits dataset.
from sklearn.datasets import load_digits
digits = load_digits()
print(f"{digits.data.shape=}")
print(f"{digits.target.shape=}")
digits.data.shape=(1797, 64)
digits.target.shape=(1797,)
This dataset consists of 8x8 pixelated images of hand-written digits and their corresponding labels. Let’s visualize a handful of them with matplotlib:
import matplotlib.pyplot as plt
fig, axes = plt.subplots(10, 10, figsize=(6, 6),
                         subplot_kw={'xticks':[], 'yticks':[]},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i, ax in enumerate(axes.flat):
    ax.imshow(digits.images[i], cmap='binary', interpolation='gaussian')
    ax.text(0.05, 0.05, str(digits.target[i]), transform=ax.transAxes, color='green')
Next, we split the dataset into a training set and a test set, and convert these splits into jax.Arrays before we feed them into the model.
We’ll use the jax.numpy module, which provides a familiar NumPy-style API around JAX operations:
from sklearn.model_selection import train_test_split
splits = train_test_split(digits.images, digits.target, random_state=0)
import jax.numpy as jnp
images_train, images_test, label_train, label_test = map(jnp.asarray, splits)
print(f"{images_train.shape=} {label_train.shape=}")
print(f"{images_test.shape=} {label_test.shape=}")
images_train.shape=(1347, 8, 8) label_train.shape=(1347,)
images_test.shape=(450, 8, 8) label_test.shape=(450,)
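One detail worth knowing about this conversion: scikit-learn returns 64-bit arrays, but JAX uses 32-bit types by default (unless 64-bit mode is enabled via the jax_enable_x64 flag). As a result, jnp.asarray converts float64 images to float32 and int64 labels to int32. The following minimal check uses small stand-in NumPy arrays rather than the digits data:

```python
import numpy as np
import jax.numpy as jnp

# Stand-ins for digits.images (float64) and digits.target (int64).
images_np = np.zeros((2, 8, 8), dtype=np.float64)
labels_np = np.array([0, 7], dtype=np.int64)

images_jax = jnp.asarray(images_np)
labels_jax = jnp.asarray(labels_np)

# In JAX's default 32-bit mode, 64-bit inputs are converted to 32-bit types.
print(images_np.dtype, "->", images_jax.dtype)
print(labels_np.dtype, "->", labels_jax.dtype)
```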
We can now use Flax NNX to create a simple feed-forward neural network by subclassing flax.nnx.Module. The network is built from flax.nnx.Linear layers, with the scaled exponential linear unit (SELU) activation function provided by the built-in flax.nnx.selu:
from flax import nnx
class SimpleNN(nnx.Module):

    def __init__(self, n_features: int = 64, n_hidden: int = 100, n_targets: int = 10,
                 *, rngs: nnx.Rngs):
        self.n_features = n_features
        self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
        self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)
        self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)

    def __call__(self, x):
        x = x.reshape(x.shape[0], self.n_features)  # Flatten images.
        x = nnx.selu(self.layer1(x))
        x = nnx.selu(self.layer2(x))
        x = self.layer3(x)
        return x
model = SimpleNN(rngs=nnx.Rngs(0))
nnx.display(model) # Interactive display if penzai is installed.
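To make the structure concrete, here is a minimal pure-JAX sketch of the same forward pass using explicit parameter arrays. It assumes that each linear layer computes x @ W + b, which is what nnx.Linear does. The init_params helper and its scaled-normal initialization are hypothetical and do not reproduce Flax's default initializers exactly.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(64, 100, 100, 10)):
    """Hypothetical initializer: one (W, b) pair per linear layer."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        W = jax.random.normal(subkey, (n_in, n_out)) / jnp.sqrt(n_in)
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params

def forward(params, x):
    x = x.reshape(x.shape[0], -1)  # flatten (batch, 8, 8) -> (batch, 64)
    for W, b in params[:-1]:
        x = jax.nn.selu(x @ W + b)  # hidden layers use SELU
    W, b = params[-1]
    return x @ W + b  # output layer: raw logits, no activation

params = init_params(jax.random.key(0))
logits = forward(params, jnp.ones((5, 8, 8)))
print(logits.shape)  # (5, 10)
```

Keeping the parameters in an explicit list is the purely functional style. In this style, the parameters are passed in and out of every function, whereas nnx.Module stores them on the object.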
With the SimpleNN model created and instantiated, we can now choose the loss function and the optimizer with the Optax package, and then define the training step function. Use:
optax.softmax_cross_entropy_with_integer_labels as the loss, since the output layer has one node (logit) per digit class and the labels are integers.
optax.sgd for the stochastic gradient descent optimizer.
flax.nnx.Optimizer to instantiate the optimizer and set the train state.
import jax
import optax
optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))
def loss_fun(
    model: nnx.Module,
    data: jax.Array,
    labels: jax.Array):
  logits = model(data)
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=labels
  ).mean()
  return loss, logits

@nnx.jit  # JIT-compile the function
def train_step(
    model: nnx.Module,
    optimizer: nnx.Optimizer,
    data: jax.Array,
    labels: jax.Array):
  loss_gradient = nnx.grad(loss_fun, has_aux=True)  # gradient transform!
  grads, logits = loss_gradient(model, data, labels)
  optimizer.update(grads)  # in-place update
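For intuition about the loss: for each example, optax.softmax_cross_entropy_with_integer_labels computes the negative log of the softmax probability assigned to the true class. loss_fun then averages these values over the batch with .mean(). The following minimal sketch reimplements that per-example computation in pure JAX for illustration only. It is our own reimplementation, not Optax's source.

```python
import jax
import jax.numpy as jnp

def xent_integer_labels(logits, labels):
    """Per-example cross-entropy: -log(softmax(logits)[label])."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log-probability of the correct class for each row.
    return -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]

logits = jnp.array([[0.0] * 10,            # uniform logits: loss is log(10)
                    [10.0] + [0.0] * 9])   # confident and correct for class 0
labels = jnp.array([3, 0])
print(xent_integer_labels(logits, labels))
```

A uniform prediction over the 10 classes costs log(10), roughly 2.30. A confident, correct prediction costs close to zero.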
Notice here the use of flax.nnx.jit and flax.nnx.grad, which are Flax NNX transformations built on the jax.jit and jax.grad transformations.
jax.jit is a just-in-time (JIT) compilation transformation: it causes the function to be passed to the XLA compiler for fast repeated execution.
jax.grad is a gradient transformation that uses JAX's automatic differentiation for fast optimization of large networks.
We will return to these transformations later in the tutorial.
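To see what a training step amounts to without Flax's mutable objects, here is a minimal functional sketch. It assumes plain SGD, which is what optax.sgd applies when no momentum is set: each parameter is updated as p - learning_rate * grad. The sketch uses jax.grad with has_aux=True, mirroring nnx.grad in train_step, and jax.tree_util.tree_map to apply the update to every array in the parameter pytree. The toy linear-regression loss and all names here are illustrative, not part of the tutorial's model.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    loss = jnp.mean((pred - y) ** 2)
    return loss, pred  # (value to differentiate, auxiliary output)

@jax.jit
def sgd_step(params, x, y, learning_rate=0.05):
    grads, _ = jax.grad(loss_fn, has_aux=True)(params, x, y)
    # Plain SGD: p <- p - lr * g, applied leaf by leaf.
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

x = jnp.linspace(-1.0, 1.0, 20)[:, None]  # 20 samples, 1 feature
y = 3.0 * x[:, 0] + 1.0                   # targets from a known line
params = {"w": jnp.zeros(1), "b": jnp.array(0.0)}

initial_loss, _ = loss_fn(params, x, y)
for _ in range(200):
    params = sgd_step(params, x, y)
final_loss, _ = loss_fn(params, x, y)
print(initial_loss, final_loss)
```

Because sgd_step returns new parameters instead of mutating them, the loop rebinds params on every iteration. nnx.Optimizer hides that bookkeeping behind optimizer.update.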
Now that we have a training step function, let's define a training loop to repeatedly perform this training step over the training data, periodically printing the loss on the test set to monitor convergence:
for i in range(301):  # 300 training epochs
    train_step(model, optimizer, images_train, label_train)
    if i % 50 == 0:  # Print metrics.
        loss, _ = loss_fun(model, images_test, label_test)
        print(f"epoch {i}: loss={loss:.2f}")
epoch 0: loss=5.68
epoch 50: loss=0.16
epoch 100: loss=0.12
epoch 150: loss=0.11
epoch 200: loss=0.10
epoch 250: loss=0.10
epoch 300: loss=0.10
After 300 training epochs, our model should have converged to a test loss of around 0.10. We can check what this implies for the model's classification accuracy on the test images:
label_pred = model(images_test).argmax(axis=1)
num_matches = jnp.count_nonzero(label_pred == label_test)
num_total = len(label_test)
accuracy = num_matches / num_total
print(f"{num_matches} labels match out of {num_total}:"
      f" accuracy = {accuracy:%}")
438 labels match out of 450: accuracy = 97.333336%
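The same accuracy can also be computed in one line, because the mean of a boolean array is the fraction of True entries. Here is a small sketch with made-up predictions and labels:

```python
import jax.numpy as jnp

label_pred = jnp.array([3, 1, 4, 1, 5])  # hypothetical predictions
label_true = jnp.array([3, 1, 4, 2, 5])  # hypothetical ground truth

accuracy = jnp.mean(label_pred == label_true)  # 4 of 5 match
print(f"accuracy = {accuracy:%}")
```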
The simple feed-forward network has achieved approximately 97% accuracy on the test set. We can make a visualization similar to the one above to review some examples that the model predicted correctly (in green) and incorrectly (in red):
fig, axes = plt.subplots(10, 10, figsize=(6, 6),
                         subplot_kw={'xticks':[], 'yticks':[]},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i, ax in enumerate(axes.flat):
    ax.imshow(images_test[i], cmap='binary', interpolation='gaussian')
    color = 'green' if label_pred[i] == label_test[i] else 'red'
    ax.text(0.05, 0.05, str(label_pred[i]), transform=ax.transAxes, color=color)
In this tutorial, we have only scratched the surface of JAX, Flax NNX, and Optax. The Flax NNX package includes a number of useful APIs for tracking metrics during training, which are featured in the Flax MNIST tutorial on the Flax website.
The Flax NNX neural network API demonstrated above takes advantage of a number of key JAX features, designed into the library from the ground up. In particular:
JAX provides a familiar NumPy-like API for array computing.
This means that when processing data and outputs, we can reach for APIs like jax.numpy.count_nonzero, which mirror the familiar APIs of the NumPy package; in this case numpy.count_nonzero.
JAX provides just-in-time (JIT) compilation.
This means that we can implement our code easily in Python, but count on fast compiled execution on CPU, GPU, and TPU backends via the XLA compiler by wrapping the code with a simple jax.jit transformation.
JAX provides automatic differentiation (autodiff).
This means that when fitting models, optax and flax can compute closed-form gradient functions for fast optimization of models, using the jax.grad transformation.
JAX provides automatic vectorization.
While we didn't use this directly in the code above, under the hood Flax takes advantage of JAX's vectorized map (jax.vmap) to automatically convert loss and gradient functions to efficient batch-aware functions that are just as fast as hand-written versions. This makes JAX implementations simpler and less error-prone.
We will learn more about these features through brief examples in the following sections.
The foundational array computing package in Python is NumPy, and JAX provides a matching API via the jax.numpy subpackage.
Additionally, JAX arrays (jax.Array) behave much like NumPy arrays in their attributes, and in terms of indexing and broadcasting semantics.
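A quick sketch of that parity compares the same broadcasting and indexing operations on a NumPy array and a JAX array. One caveat: boolean-mask indexing like x_jnp[x_jnp > 9] works here because it runs eagerly. It is not supported inside jax.jit, because the output shape would depend on the data.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(12.0).reshape(3, 4)
x_jnp = jnp.arange(12.0).reshape(3, 4)

# Broadcasting: a (3, 4) array minus its per-column mean of shape (4,).
centered_np = x_np - x_np.mean(axis=0)
centered_jnp = x_jnp - x_jnp.mean(axis=0)
print(np.allclose(centered_np, centered_jnp))  # True

# Indexing: slices and boolean masks follow NumPy semantics.
print(x_jnp[1, 1:3])     # [5. 6.]
print(x_jnp[x_jnp > 9])  # [10. 11.]
```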
In the previous example, we used Flax's built-in flax.nnx.selu implementation. We can also implement SELU using JAX's NumPy API as follows:
import jax.numpy as jnp
def selu(x, alpha=1.67, lam=1.05):
    return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(5.0)
print(selu(x))
[0. 1.05 2.1 3.1499999 4.2 ]
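As a sanity check, this hand-written selu can be compared against JAX's built-in jax.nn.selu, which uses the more precise constants alpha ≈ 1.6733 and scale ≈ 1.0507. With the rounded defaults above, the two functions agree only to about two decimal places. Passing the precise constants brings them into agreement up to float32 rounding:

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lam=1.05):
    return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.linspace(-3.0, 3.0, 7)

# Small but nonzero difference, due to the rounded constants.
print(jnp.max(jnp.abs(selu(x) - jax.nn.selu(x))))

# With the precise SELU constants, the results match to float32 precision.
print(jnp.allclose(selu(x, alpha=1.6732632423543772, lam=1.0507009873554805),
                   jax.nn.selu(x)))
```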
Despite the broad similarities, be aware that JAX does have some well-motivated differences from NumPy that you can read about in 🔪 JAX – The Sharp Bits 🔪 on the JAX site.
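One of the best-known differences is that JAX arrays are immutable. NumPy-style in-place assignment raises a TypeError, and the indexed-update syntax x.at[idx].set(value) returns a new array instead:

```python
import jax.numpy as jnp

x = jnp.zeros(5)
try:
    x[0] = 10.0  # in-place item assignment is not supported on jax.Array
    assignment_failed = False
except TypeError:
    assignment_failed = True

y = x.at[0].set(10.0)   # functional update: returns a new array
z = y.at[1:3].add(1.0)  # other update methods include add, multiply, min, max
print(assignment_failed)  # True
print(x)  # the original array is unchanged (all zeros)
print(z)
```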
As mentioned before, JAX is built on the XLA compiler, and allows sequences of operations to be just-in-time (JIT) compiled using the jax.jit transformation.
In the neural network example above, we used the similar flax.nnx.jit transform, which has some special handling for Flax NNX objects for speed in neural network training.
Returning to the previously defined selu function in JAX, we can create a jax.jit-compiled version this way:
import jax
selu_jit = jax.jit(selu)
selu_jit is now a compiled version of the original function, which returns the same result to typical floating-point precision:
x = jnp.arange(1E6)
jnp.allclose(selu(x), selu_jit(x)) # results match
Array(True, dtype=bool)
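A key detail of jax.jit is that the Python function body runs only while JAX traces it for a given input shape and dtype. Later calls with matching shapes and dtypes reuse the cached compiled version. The following minimal sketch counts traces with a Python side effect. The counter is purely illustrative: Python side effects like this are not a reliable way to do work inside jitted code.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return jnp.sum(2.0 * x)

scaled_sum(jnp.arange(3.0))        # first call: traces and compiles
scaled_sum(jnp.arange(3.0) + 1.0)  # same shape and dtype: reuses the cache
scaled_sum(jnp.arange(4.0))        # new shape: traces again
print(trace_count)  # 2
```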
We can use IPython's %timeit magic to observe the speedup (note the call to the block_until_ready() method, which is needed to account for JAX's asynchronous dispatch):
%timeit selu(x).block_until_ready()
8.32 ms ± 489 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit selu_jit(x).block_until_ready()
1.38 ms ± 184 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
For this computation, running on CPU, jax.jit compilation gives roughly a 6x speedup in the timings above (8.32 ms vs. 1.38 ms per loop).
JAX's documentation has more discussion of JIT compilation at Just-in-time compilation.
For efficient optimization of neural network models, fast gradient computations are essential. JAX enables this via its automatic differentiation transformations like jax.grad, which computes a closed-form gradient of a JAX function. In the neural network example, we used the similar flax.nnx.grad function, which has special handling for flax.nnx objects.
Here's how to compute the gradient of a function with jax.grad:
x = jnp.float32(-1.0)
jax.grad(selu)(x)
Array(0.6450766, dtype=float32)
We can briefly check with a finite-difference approximation that this is giving the expected value:
eps = 1E-3
(selu(x + eps) - selu(x)) / eps
Array(0.64539903, dtype=float32)
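We can also check jax.grad against the analytic derivative. For x > 0 the function is lam * x, with derivative lam. For x <= 0 it is lam * (alpha * exp(x) - alpha), with derivative lam * alpha * exp(x). At x = -1 this gives 1.05 * 1.67 * exp(-1) ≈ 0.6451, matching the jax.grad result above.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lam=1.05):
    return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

def selu_grad_analytic(x, alpha=1.67, lam=1.05):
    # Derivative of each branch of selu, selected the same way as in selu.
    return jnp.where(x > 0, lam, lam * alpha * jnp.exp(x))

for v in [-1.0, -0.5, 0.5, 2.0]:
    xv = jnp.float32(v)
    print(v, float(jax.grad(selu)(xv)), float(selu_grad_analytic(xv)))
```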
Importantly, the automatic differentiation approach is both more accurate and more efficient than computing numerical gradients. JAX's documentation has more discussion of autodiff at Automatic differentiation and Advanced automatic differentiation.
In the training loop example earlier, we defined the loss function in terms of a single input data vector of shape n_features, but trained the model by passing batches of data (of shape [n_samples, n_features]). Rather than relying on a naive and slow loop over batches, Flax and Optax internals use JAX's automatic vectorization via the jax.vmap transformation to construct a batched version of the kernel automatically.
Consider a simple loss function that looks like this:
def loss(x: jax.Array, x0: jax.Array):
    return jnp.sum((x - x0) ** 2)
We can evaluate it on a single data vector this way:
x = jnp.arange(3.)
x0 = jnp.ones(3)
loss(x, x0)
Array(2., dtype=float32)
But if we attempt to evaluate it on a batch of vectors, it does not correctly return a batch of 4 losses:
batched_x = jnp.arange(12).reshape(4, 3) # batch of 4 vectors
loss(batched_x, x0) # wrong!
Array(386., dtype=float32)
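The problem is that this loss function is not batch-aware. Without automatic vectorization, there are two ways we can address this:
1. Rewrite the loss function by hand to operate on batched data. This is fast, but becomes difficult and error-prone as functions grow more complicated.
2. Loop over unbatched calls to the original function. This is easy to code, but can be slow.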
The jax.vmap transformation offers a third way: it automatically transforms our original function into a batch-aware version, so we get the speed of option 1 with the ease of option 2:
loss_batched = jax.vmap(loss, in_axes=(0, None)) # batch x over axis 0, do not batch x0
loss_batched(batched_x, x0)
Array([ 2., 29., 110., 245.], dtype=float32)
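As a check, the jax.vmap result agrees both with a hand-written batched loss and with a plain Python loop over unbatched calls. This sketch uses a float batch (jnp.arange(12.0)) rather than the integer one above. The function loss_batched_by_hand is an illustrative name, not part of any library.

```python
import jax
import jax.numpy as jnp

def loss(x, x0):
    return jnp.sum((x - x0) ** 2)

def loss_batched_by_hand(batched_x, x0):
    # Sum only over the feature axis, keeping one loss per row.
    return jnp.sum((batched_x - x0) ** 2, axis=-1)

batched_x = jnp.arange(12.0).reshape(4, 3)
x0 = jnp.ones(3)

via_vmap = jax.vmap(loss, in_axes=(0, None))(batched_x, x0)
via_hand = loss_batched_by_hand(batched_x, x0)
via_loop = jnp.stack([loss(row, x0) for row in batched_x])  # slow for large batches

print(via_vmap, via_hand, via_loop)  # the same four losses, three ways
```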
In the neural network example earlier, both flax and optax make use of JAX's vmap to allow for efficient batched computations over our unbatched loss function.
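JAX transformations also compose. For example, wrapping jax.grad in jax.vmap gives per-example gradients of the unbatched loss, with one gradient row per input vector. For this loss, the gradient with respect to x is 2 * (x - x0), which the sketch below checks:

```python
import jax
import jax.numpy as jnp

def loss(x, x0):
    return jnp.sum((x - x0) ** 2)

batched_x = jnp.arange(12.0).reshape(4, 3)
x0 = jnp.ones(3)

# Gradient w.r.t. the first argument (x), vectorized over the rows of batched_x.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(0, None))(batched_x, x0)
print(per_example_grads.shape)  # (4, 3)
print(jnp.allclose(per_example_grads, 2 * (batched_x - x0)))  # True
```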
JAX's documentation has more discussion of automatic vectorization at Automatic vectorization.