from sklearn.datasets import load_digits

digits = load_digits()
print(f"{digits.data.shape=}")
print(f"{digits.target.shape=}")

import matplotlib.pyplot as plt

fig, axes = plt.subplots(10, 10, figsize=(6, 6),
                         subplot_kw={'xticks': [], 'yticks': []},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))

for i, ax in enumerate(axes.flat):
  ax.imshow(digits.images[i], cmap='binary', interpolation='gaussian')
  ax.text(0.05, 0.05, str(digits.target[i]), transform=ax.transAxes, color='green')

from sklearn.model_selection import train_test_split

splits = train_test_split(digits.images, digits.target, random_state=0)

import jax.numpy as jnp

images_train, images_test, label_train, label_test = map(jnp.asarray, splits)
print(f"{images_train.shape=} {label_train.shape=}")
print(f"{images_test.shape=} {label_test.shape=}")

from flax import nnx

class SimpleNN(nnx.Module):

  def __init__(self, n_features: int = 64, n_hidden: int = 100,
               n_targets: int = 10, *, rngs: nnx.Rngs):
    self.n_features = n_features
    self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
    self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)
    self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)

  def __call__(self, x):
    x = x.reshape(x.shape[0], self.n_features)  # Flatten the 8x8 images to 64 features.
    x = nnx.selu(self.layer1(x))
    x = nnx.selu(self.layer2(x))
    x = self.layer3(x)
    return x

model = SimpleNN(rngs=nnx.Rngs(0))
nnx.display(model)  # Interactive display if penzai is installed.

import jax
import optax

optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))

def loss_fun(
    model: nnx.Module,
    data: jax.Array,
    labels: jax.Array):
  logits = model(data)
  loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=labels
  ).mean()
  return loss, logits

@nnx.jit  # JIT-compile the function.
def train_step(
    model: nnx.Module,
    optimizer: nnx.Optimizer,
    data: jax.Array,
    labels: jax.Array):
  loss_gradient = nnx.grad(loss_fun, has_aux=True)  # Gradient transform!
  grads, logits = loss_gradient(model, data, labels)
  optimizer.update(grads)  # In-place update of the model parameters.

for i in range(301):  # 300 training epochs.
  train_step(model, optimizer, images_train, label_train)
  if i % 50 == 0:  # Print metrics every 50 epochs.
    loss, _ = loss_fun(model, images_test, label_test)
    print(f"epoch {i}: loss={loss:.2f}")

label_pred = model(images_test).argmax(axis=1)
num_matches = jnp.count_nonzero(label_pred == label_test)
num_total = len(label_test)
accuracy = num_matches / num_total
print(f"{num_matches} labels match out of {num_total}:"
      f" accuracy = {accuracy:%}")

fig, axes = plt.subplots(10, 10, figsize=(6, 6),
                         subplot_kw={'xticks': [], 'yticks': []},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))

for i, ax in enumerate(axes.flat):
  ax.imshow(images_test[i], cmap='binary', interpolation='gaussian')
  color = 'green' if label_pred[i] == label_test[i] else 'red'
  ax.text(0.05, 0.05, str(label_pred[i]), transform=ax.transAxes, color=color)

def selu(x, alpha=1.67, lam=1.05):
  return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))

selu_jit = jax.jit(selu)  # JIT-compile selu with XLA.
x = jnp.arange(1E6)
jnp.allclose(selu(x), selu_jit(x))  # Results match.

%timeit selu(x).block_until_ready()
%timeit selu_jit(x).block_until_ready()

x = jnp.float32(-1.0)
jax.grad(selu)(x)  # Exact gradient via automatic differentiation.

eps = 1E-3
(selu(x + eps) - selu(x)) / eps  # Numerical check: forward difference.

def loss(x: jax.Array, x0: jax.Array):
  return jnp.sum((x - x0) ** 2)

x = jnp.arange(3.)
x0 = jnp.ones(3)
loss(x, x0)

batched_x = jnp.arange(12).reshape(4, 3)  # Batch of 4 vectors.
loss(batched_x, x0)  # Wrong!
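# Why is this wrong? x0 (shape (3,)) broadcasts against the whole (4, 3)
# batch, and jnp.sum reduces over every axis, so the call returns a single
# scalar instead of one loss per example. A quick check (a sketch added
# here, not part of the original code): compare against an explicit loop.
print(loss(batched_x, x0))                              # One scalar for the whole batch.
print(jnp.array([loss(row, x0) for row in batched_x]))  # Shape (4,): what we actually want.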
loss_batched = jax.vmap(loss, in_axes=(0, None))  # Batch x over axis 0; do not batch x0.
loss_batched(batched_x, x0)  # Shape (4,): one loss per example.
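# These transformations compose. A minimal sketch (assuming the
# definitions above; mean_loss is a hypothetical helper, not part of the
# original code): JIT-compile the batched loss, then differentiate its
# mean with jax.grad. The batch is cast to float32 because jax.grad
# requires floating-point inputs.
loss_batched_jit = jax.jit(loss_batched)
print(loss_batched_jit(batched_x, x0))  # Per-example losses, shape (4,).

def mean_loss(x, x0):
  return loss_batched(x, x0).mean()

print(jax.grad(mean_loss)(batched_x.astype(jnp.float32), x0))  # d(mean loss)/dx, shape (4, 3).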