from typing import Iterable

import chex
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax


class MLP(nn.Module):
  """A simple multilayer perceptron model."""

  @nn.compact
  def __call__(self, x):
    # Flatten each example in the batch to a single feature vector.
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=512)(x)
    x = nn.relu(x)
    x = nn.Dense(features=512)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x


net = MLP()


def loss_fn(params, batch):
  """Computes the mean softmax cross-entropy loss over a mini-batch."""
  logits = net.apply(params, batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']
  ).mean()
  return loss


def build_train_step(optimizer: optax.GradientTransformation):
  """Builds a jitted function that executes a single optimization step."""

  @jax.jit
  def update(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

  return update


def fit(
    optimizer: optax.GradientTransformation,
    params: optax.Params,
    batches: Iterable[dict[str, jnp.ndarray]],
) -> optax.Params:
  """Runs a training loop over the given batches using the given optimizer."""
  train_step = build_train_step(optimizer)
  opt_state = optimizer.init(params)
  for batch in batches:
    params, opt_state = train_step(params, opt_state, batch)
  return params


# A toy dataset: 9 random "images" with random integer labels in [0, 10).
EXAMPLES = jax.random.uniform(jax.random.PRNGKey(0), (9, 28, 28, 1))
LABELS = jax.random.randint(jax.random.PRNGKey(0), (9,), minval=0, maxval=10)

optimizer = optax.sgd(1e-4)
params = net.init(jax.random.PRNGKey(0), EXAMPLES)

# One SGD step on the full batch of 9 examples...
new_params_single_batch = fit(
    optimizer,
    params,
    batches=[dict(image=EXAMPLES, label=LABELS)],
)

# ...should match three accumulation steps over mini-batches of 3. With
# `every_k_schedule=3`, optax.MultiSteps averages the gradients of 3
# consecutive mini-batches before applying a single SGD update, and the
# average of the three mini-batch gradients equals the full-batch gradient
# because all mini-batches have the same size.
new_params_gradient_accumulation = fit(
    optax.MultiSteps(optimizer, every_k_schedule=3),
    params,
    batches=[
        dict(image=EXAMPLES[0:3], label=LABELS[0:3]),
        dict(image=EXAMPLES[3:6], label=LABELS[3:6]),
        dict(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)

chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation,
    atol=1e-7,
)

# The equivalence also holds with a learning-rate schedule: MultiSteps only
# steps the inner optimizer (and hence the schedule) once per accumulated
# update, not once per mini-batch.
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=1.0,
    boundaries_and_scales={
        0: 1e-4,
        1: 1e-1,
    },
)
optimizer = optax.sgd(learning_rate_schedule)

new_params_single_batch = fit(
    optimizer,
    params,
    batches=[dict(image=EXAMPLES, label=LABELS)],
)
new_params_gradient_accumulation = fit(
    optax.MultiSteps(optimizer, every_k_schedule=3),
    params,
    batches=[
        dict(image=EXAMPLES[0:3], label=LABELS[0:3]),
        dict(image=EXAMPLES[3:6], label=LABELS[3:6]),
        dict(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)

chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation,
    atol=1e-7,
)
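
# ---------------------------------------------------------------------------
# A minimal sketch (not part of the checks above) of how the MultiSteps
# bookkeeping can be inspected directly. It assumes optax's public MultiSteps
# API: `update` returns zero updates until `every_k_schedule` gradients have
# been accumulated, and `has_updated(state)` reports whether the last call
# applied a real inner update. It reuses `params`, `loss_fn`, `EXAMPLES`,
# and `LABELS` defined above.
# ---------------------------------------------------------------------------
multi_steps = optax.MultiSteps(optax.sgd(1e-4), every_k_schedule=3)
ms_state = multi_steps.init(params)
for i in range(3):
  mini_batch = dict(
      image=EXAMPLES[3 * i:3 * (i + 1)], label=LABELS[3 * i:3 * (i + 1)]
  )
  grads = jax.grad(loss_fn)(params, mini_batch)
  _, ms_state = multi_steps.update(grads, ms_state, params)
  # Expected: False for the first two mini-batches, True on the third, when
  # the averaged gradient is finally applied to the inner optimizer.
  print(f'mini-batch {i}: has_updated={bool(multi_steps.has_updated(ms_state))}')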