import functools

import jax
import jax.numpy as jnp
import optax


# A simple linear model, vmapped over a batch of inputs.
@functools.partial(jax.vmap, in_axes=(None, 0))
def network(params, x):
  return jnp.dot(params, x)


def compute_loss(params, x, y):
  y_pred = network(params, x)
  loss = jnp.mean(optax.l2_loss(y_pred, y))
  return loss


key = jax.random.PRNGKey(42)
target_params = 0.5

# Generate some data.
xs = jax.random.normal(key, (16, 2))
ys = jnp.sum(xs * target_params, axis=-1)

start_learning_rate = 1e-1
optimizer = optax.adam(start_learning_rate)

# Initialize parameters of the model + optimizer.
params = jnp.array([0.0, 0.0])
opt_state = optimizer.init(params)

# A simple update loop.
for _ in range(1000):
  grads = jax.grad(compute_loss)(params, xs, ys)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

assert jnp.allclose(params, target_params), \
    'Optimization should retrieve the target params used to generate the data.'

# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=start_learning_rate,
    transition_steps=1000,
    decay_rate=0.99)

# Combining gradient transforms using `optax.chain`.
gradient_transform = optax.chain(
    optax.clip_by_global_norm(1.0),  # Clip gradients by their global norm.
    optax.scale_by_adam(),  # Use the updates from adam.
    optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.
    # Scale updates by -1 since optax.apply_updates is additive and we want to
    # descend on the loss.
    optax.scale(-1.0)
)

# Initialize parameters of the model + optimizer.
params = jnp.array([0.0, 0.0])  # Recall target_params=0.5.
opt_state = gradient_transform.init(params)

# A simple update loop.
for _ in range(1000):
  grads = jax.grad(compute_loss)(params, xs, ys)
  updates, opt_state = gradient_transform.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

assert jnp.allclose(params, target_params), \
    'Optimization should retrieve the target params used to generate the data.'

# `optax.inject_hyperparams` makes an otherwise static hyperparameter
# schedulable and visible in the optimizer state.
decaying_global_norm_tx = optax.inject_hyperparams(optax.clip_by_global_norm)(
    max_norm=optax.linear_schedule(1.0, 0.0, transition_steps=99))

opt_state = decaying_global_norm_tx.init(None)
assert opt_state.hyperparams['max_norm'] == 1.0, 'Max norm should start at 1.0'

for _ in range(100):
  _, opt_state = decaying_global_norm_tx.update(None, opt_state)

assert opt_state.hyperparams['max_norm'] == 0.0, 'Max norm should end at 0.0'
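# The same `inject_hyperparams` pattern can wrap a full optimizer alias.
# As a rough sketch (the variable names below are illustrative, not part of
# the examples above), wrapping `optax.adam` exposes the current learning
# rate in `opt_state.hyperparams`, which is handy for logging it during
# training:
adam_with_visible_lr = optax.inject_hyperparams(optax.adam)(
    learning_rate=optax.exponential_decay(
        init_value=start_learning_rate,
        transition_steps=1000,
        decay_rate=0.99))
inspect_params = jnp.array([0.0, 0.0])
inspect_state = adam_with_visible_lr.init(inspect_params)
current_lr = inspect_state.hyperparams['learning_rate']  # 0.1 at step 0.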
# A small MLP trained to predict the parity (odd/even) of random integers,
# using the bits of each integer as input features.
import numpy as np

BATCH_SIZE = 5
NUM_TRAIN_STEPS = 1_000
RAW_TRAINING_DATA = np.random.randint(
    255, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))

TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(
    NUM_TRAIN_STEPS, BATCH_SIZE, 2)

initial_params = {
    'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[32, 2], key=jax.random.PRNGKey(1)),
}


def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['output'])
  return x


def loss(params: optax.Params, batch: jnp.ndarray,
         labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = net(batch, params)

  # optax also provides a number of common loss functions.
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)

  return loss_value.mean()


def fit(params: optax.Params,
        optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params


# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)

# The same model fit with a warmup-cosine-decay learning rate schedule,
# gradient clipping and AdamW.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1.0,
    warmup_steps=50,
    decay_steps=1_000,
    end_value=0.0,
)

optimizer = optax.chain(
    optax.clip(1.0),
    optax.adamw(learning_rate=schedule),
)
params = fit(initial_params, optimizer)

# Gradient transformations can also be used directly: `init` creates the
# transformation state, `update` transforms the gradients.
tx = optax.scale_by_rms()
state = tx.init(params)  # init stats
grads = jax.grad(loss)(params, TRAINING_DATA, LABELS)
updates, state = tx.update(grads, state, params)  # transform & update stats.

# Composing transformations into a custom optimiser with `optax.chain`.
max_norm = 100.
learning_rate = 1e-3
my_optimiser = optax.chain(
    optax.clip_by_global_norm(max_norm),
    optax.scale_by_adam(eps=1e-4),
    optax.scale(-learning_rate))

# `optax.flatten` wraps an optimiser so that its statistics are stored in a
# single flat vector, which can be more efficient for many small parameters.
my_optimiser = optax.flatten(optax.adam(learning_rate))

# Schedules map a step count to a value.
schedule_fn = optax.polynomial_schedule(
    init_value=1., end_value=0., power=1, transition_steps=5)

for step_count in range(6):
  print(schedule_fn(step_count))  # [1., 0.8, 0.6, 0.4, 0.2, 0.]

# A schedule can drive `scale_by_schedule`; note the negative sign so that
# `optax.apply_updates` descends on the loss.
schedule_fn = optax.polynomial_schedule(
    init_value=-learning_rate, end_value=0., power=1, transition_steps=5)
optimiser = optax.chain(
    optax.clip_by_global_norm(max_norm),
    optax.scale_by_adam(eps=1e-4),
    optax.scale_by_schedule(schedule_fn))

# Optimiser aliases accept a schedule directly as the learning rate. Pass a
# positive schedule here: the alias already negates the updates internally.
lr_schedule_fn = optax.polynomial_schedule(
    init_value=learning_rate, end_value=0., power=1, transition_steps=5)
optimiser = optax.adam(learning_rate=lr_schedule_fn)


# A custom AdamW built from individual transformations.
def adamw(learning_rate, b1, b2, eps, weight_decay):
  return optax.chain(
      optax.scale_by_adam(b1=b1, b2=b2, eps=eps),
      optax.add_decayed_weights(weight_decay),
      optax.scale(-learning_rate))


# Applying the transformed updates to the parameters.
updates, state = tx.update(grads, state, params)  # transform & update stats.
new_params = optax.apply_updates(params, updates)  # update the parameters.

# Losses are computed per example; reduce them explicitly as needed.
predictions = net(TRAINING_DATA, params)
loss_values = optax.huber_loss(predictions, LABELS)
avg_loss = jnp.mean(optax.huber_loss(predictions, LABELS))
sum_loss = jnp.sum(optax.huber_loss(predictions, LABELS))

# Stochastic gradient estimators: Monte Carlo estimates of the gradient of an
# expectation with respect to the parameters of a Gaussian.
mean, log_scale, rng, num_samples = 0., 1., jax.random.PRNGKey(0), 100
dist_params = [mean, log_scale]
function = lambda x: jnp.sum(x)
jacobians = optax.monte_carlo.pathwise_jacobians(
    function, dist_params, optax.multi_normal, rng, num_samples)

mean_grads = jnp.mean(jacobians[0], axis=0)
log_scale_grads = jnp.mean(jacobians[1], axis=0)
grads = [mean_grads, log_scale_grads]
optim = optax.adam(1e-3)
optim_state = optim.init(dist_params)
optim_update, optim_state = optim.update(grads, optim_state)
updated_dist_params = optax.apply_updates(dist_params, optim_update)
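# As a sketch, the hand-rolled `adamw` above can be used anywhere a built-in
# optimiser is expected, e.g. with the `fit` function defined earlier (which
# already passes `params` to `update`, as `add_decayed_weights` requires).
# The hyperparameter values below are illustrative only.
custom_adamw = adamw(
    learning_rate=1e-2, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-4)
params = fit(initial_params, custom_adamw)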