from typing import Iterator, Tuple

import chex
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax


def generator() -> Iterator[Tuple[chex.Array, chex.Array]]:
  """Yields noisy samples from the line y = 10 * x."""
  rng = jax.random.PRNGKey(0)
  while True:
    rng, k1, k2 = jax.random.split(rng, num=3)
    x = jax.random.uniform(k1, minval=0.0, maxval=10.0)
    y = 10.0 * x + jax.random.normal(k2)
    yield x, y


g = generator()
for _ in range(5):
  x, y = next(g)
  print(f"Sampled y = {y:.3f}, x = {x:.3f}")


def f(theta: chex.Array, x: chex.Array) -> chex.Array:
  """The model: a single scalar parameter theta."""
  return x * theta


theta = jax.random.normal(jax.random.PRNGKey(42))

init_learning_rate = jnp.array(0.1)
meta_learning_rate = jnp.array(0.03)

# inject_hyperparams makes the inner optimizer's learning rate part of its
# state, so it can be modified (and differentiated through) at runtime.
opt = optax.inject_hyperparams(optax.rmsprop)(learning_rate=init_learning_rate)
meta_opt = optax.adam(learning_rate=meta_learning_rate)


def loss(theta, x, y):
  return optax.l2_loss(y, f(theta, x))


def step(theta, state, x, y):
  """One inner-optimization step on a single (x, y) sample."""
  grad = jax.grad(loss)(theta, x, y)
  updates, state = opt.update(grad, state)
  theta = optax.apply_updates(theta, updates)
  return theta, state


def outer_loss(eta, theta, state, samples):
  """Runs the inner loop on all but the last sample, then evaluates on it."""
  # eta parameterizes the learning rate through a sigmoid, keeping it in (0, 1).
  state.hyperparams['learning_rate'] = jax.nn.sigmoid(eta)
  for x, y in samples[:-1]:
    theta, state = step(theta, state, x, y)
  x, y = samples[-1]
  return loss(theta, x, y), (theta, state)


@jax.jit
def outer_step(eta, theta, meta_state, state, samples):
  """One meta step on eta, differentiating through the inner updates."""
  grad, (theta, state) = jax.grad(
      outer_loss, has_aux=True)(eta, theta, state, samples)
  meta_updates, meta_state = meta_opt.update(grad, meta_state)
  eta = optax.apply_updates(eta, meta_updates)
  return eta, theta, meta_state, state


state = opt.init(theta)
# Inverse sigmoid, to match the value we initialized the inner optimizer with.
eta = -np.log(1. / init_learning_rate - 1)
meta_state = meta_opt.init(eta)

N = 7  # Samples per outer step: N - 1 inner updates plus 1 held-out evaluation.
learning_rates = []
thetas = []
for _ in range(2000):
  samples = [next(g) for _ in range(N)]
  eta, theta, meta_state, state = outer_step(
      eta, theta, meta_state, state, samples)
  learning_rates.append(jax.nn.sigmoid(eta))
  thetas.append(theta)

fig, (ax1, ax2) = plt.subplots(2)
fig.suptitle("Meta-learning RMSProp's learning rate")
ax1.semilogy(range(len(learning_rates)), learning_rates)
ax1.set(ylabel='Learning rate')
ax1.label_outer()
ax2.semilogy(range(len(thetas)), thetas)
ax2.set(xlabel='Number of updates', ylabel='Theta')
ax2.label_outer()
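
# The figures above render inline in a notebook; when running this example as a
# standalone script, you would presumably display or save them explicitly, for
# instance as below (the filename is an assumption, not part of the original):
plt.show()
# Or: fig.savefig('meta_learned_learning_rate.png', dpi=150)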