from typing import Sequence

from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

# @markdown The learning rate for the optimizer:
LEARNING_RATE = 0.002  # @param{type:"number"}
# @markdown Number of samples in each batch:
BATCH_SIZE = 128  # @param{type:"integer"}
# @markdown Total number of epochs to train for:
N_EPOCHS = 1  # @param{type:"integer"}

(train_loader, test_loader), info = tfds.load(
    "mnist", split=["train", "test"], as_supervised=True, with_info=True
)

# Scales pixel values from [0, 255] to [0, 1].
min_max_rgb = lambda image, label: (tf.cast(image, tf.float32) / 255., label)
train_loader = train_loader.map(min_max_rgb)
test_loader = test_loader.map(min_max_rgb)

NUM_CLASSES = info.features["label"].num_classes
IMG_SIZE = info.features["image"].shape

train_loader_batched = train_loader.shuffle(
    buffer_size=10_000, reshuffle_each_iteration=True
).batch(BATCH_SIZE, drop_remainder=True)
test_loader_batched = test_loader.batch(BATCH_SIZE, drop_remainder=True)


class MLP(nn.Module):
  """A simple multilayer perceptron model for image classification."""

  hidden_sizes: Sequence[int] = (1000, 1000)

  @nn.compact
  def __call__(self, x):
    # Flattens images in the batch.
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=self.hidden_sizes[0])(x)
    x = nn.relu(x)
    x = nn.Dense(features=self.hidden_sizes[1])(x)
    x = nn.relu(x)
    x = nn.Dense(features=NUM_CLASSES)(x)
    return x


net = MLP()


@jax.jit
def predict(params, inputs):
  return net.apply({"params": params}, inputs)


@jax.jit
def loss_accuracy(params, data):
  """Computes loss and accuracy over a mini-batch.

  Args:
    params: parameters of the model.
    data: tuple of (inputs, labels).

  Returns:
    loss: mean softmax cross-entropy over the batch.
    aux: dict holding the batch accuracy under the key "accuracy".
  """
  inputs, labels = data
  logits = predict(params, inputs)
  loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=labels
  ).mean()
  accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
  return loss, {"accuracy": accuracy}


solver = optax.adam(LEARNING_RATE)
rng = jax.random.PRNGKey(0)
dummy_data = jnp.ones((1,) + IMG_SIZE, dtype=jnp.float32)
params = net.init({"params": rng}, dummy_data)["params"]
solver_state = solver.init(params)


def dataset_stats(params, data_loader):
  """Computes loss and accuracy over the dataset `data_loader`."""
  all_accuracy = []
  all_loss = []
  for batch in data_loader.as_numpy_iterator():
    batch_loss, batch_aux = loss_accuracy(params, batch)
    all_loss.append(batch_loss)
    all_accuracy.append(batch_aux["accuracy"])
  return {"loss": np.mean(all_loss), "accuracy": np.mean(all_accuracy)}


train_accuracy = []
train_losses = []

# Computes test set loss and accuracy at initialization.
test_stats = dataset_stats(params, test_loader_batched)
test_accuracy = [test_stats["accuracy"]]
test_losses = [test_stats["loss"]]
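# A quick sanity check before training (an illustrative sketch, not part of
# the original pipeline): at initialization, the softmax cross-entropy on any
# batch should be close to log(NUM_CLASSES) = log(10) ~= 2.30, since the
# untrained network's logits are roughly uniform over the 10 classes. The
# `sanity_*` names below are hypothetical and used only here.
sanity_batch = (
    jnp.ones((BATCH_SIZE,) + IMG_SIZE, dtype=jnp.float32),
    jnp.zeros((BATCH_SIZE,), dtype=jnp.int32),
)
sanity_logits = predict(params, sanity_batch[0])
assert sanity_logits.shape == (BATCH_SIZE, NUM_CLASSES)
sanity_loss, _ = loss_accuracy(params, sanity_batch)
print(f"loss at initialization on a synthetic batch: {sanity_loss:.3f}")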
@jax.jit
def train_step(params, solver_state, batch):
  """Performs a one-step update on a single mini-batch."""
  (loss, aux), grad = jax.value_and_grad(loss_accuracy, has_aux=True)(
      params, batch
  )
  updates, solver_state = solver.update(grad, solver_state, params)
  params = optax.apply_updates(params, updates)
  return params, solver_state, loss, aux


for epoch in range(N_EPOCHS):
  train_accuracy_epoch = []
  train_losses_epoch = []

  for step, train_batch in enumerate(train_loader_batched.as_numpy_iterator()):
    params, solver_state, train_loss, train_aux = train_step(
        params, solver_state, train_batch
    )
    train_accuracy_epoch.append(train_aux["accuracy"])
    train_losses_epoch.append(train_loss)
    if step % 20 == 0:
      print(
          f"step {step}, train loss: {train_loss:.2e}, train accuracy:"
          f" {train_aux['accuracy']:.2f}"
      )

  # Evaluates on the test set after each epoch.
  test_stats = dataset_stats(params, test_loader_batched)
  test_accuracy.append(test_stats["accuracy"])
  test_losses.append(test_stats["loss"])
  train_accuracy.append(np.mean(train_accuracy_epoch))
  train_losses.append(np.mean(train_losses_epoch))

print(
    f"Improved accuracy on the test set from {test_accuracy[0]:.2f} to"
    f" {test_accuracy[-1]:.2f}"
)
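# Optional: plot the learning curves (a sketch assuming matplotlib is
# installed; plotting was not part of the code above). Note `test_accuracy`
# has one more entry than `train_accuracy` because the test set was also
# evaluated at initialization, so its x-axis starts at epoch 0.
import matplotlib.pyplot as plt

plt.plot(range(1, N_EPOCHS + 1), train_accuracy, label="train")
plt.plot(range(N_EPOCHS + 1), test_accuracy, label="test")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.show()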