"""Compare gradient descent-ascent (GDA) with optimistic GDA (OGDA).

Both methods are run on the bilinear min-max problem ``min_x max_y x * y``
from the same starting point, and their (x, y) trajectories are plotted
together with the saddle point at the origin.
"""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax


def f(params: optax.Params) -> jnp.ndarray:
    """Objective: min_x max_y xy.

    Args:
        params: Mapping with scalar entries ``"x"`` (the minimising player)
            and ``"y"`` (the maximising player).

    Returns:
        The product ``params["x"] * params["y"]``.
    """
    return params["x"] * params["y"]


def optimise(
    params: optax.Params,
    x_optimiser: optax.GradientTransformation,
    y_optimiser: optax.GradientTransformation,
    n_steps: int = 1000,
    display_every: int = 100,
) -> tuple[list[optax.Params], list[jnp.ndarray]]:
    """An optimisation loop minimising x and maximising y.

    Both players are updated simultaneously at every step from the gradient
    of ``f`` at the current point.

    Args:
        params: Initial parameters, a mapping with keys ``"x"`` and ``"y"``.
        x_optimiser: Optax transformation applied to ``x`` (descent on f).
        y_optimiser: Optax transformation applied to ``y``. It is fed the
            negated gradient, so a descent optimiser performs ascent on f.
        n_steps: Number of update steps to run.
        display_every: Print progress every ``display_every`` steps. A value
            ``<= 0`` disables printing.

    Returns:
        ``(param_hist, f_hist)``. ``param_hist`` holds ``n_steps + 1``
        parameter dicts, starting with the initial one. ``f_hist[k]`` is the
        objective evaluated at ``param_hist[k]``, so it has ``n_steps`` values.
    """
    x_opt_state = x_optimiser.init(params["x"])
    y_opt_state = y_optimiser.init(params["y"])
    param_hist = [params]
    f_hist = []

    @jax.jit
    def step(params, x_opt_state, y_opt_state):
        # f_value is evaluated at the incoming params, i.e. before the update.
        f_value, grads = jax.value_and_grad(f)(params)
        x_update, x_opt_state = x_optimiser.update(grads["x"], x_opt_state, params["x"])
        # We're maximising y, so the y optimiser (which descends) is fed the
        # negative gradient.
        y_update, y_opt_state = y_optimiser.update(-grads["y"], y_opt_state, params["y"])
        updates = {"x": x_update, "y": y_update}
        params = optax.apply_updates(params, updates)
        return params, x_opt_state, y_opt_state, f_value

    for k in range(n_steps):
        prev_params = params
        params, x_opt_state, y_opt_state, f_value = step(params, x_opt_state, y_opt_state)
        param_hist.append(params)
        f_hist.append(f_value)
        # Guard against display_every <= 0, which would otherwise raise
        # ZeroDivisionError in the modulo.
        if display_every > 0 and k % display_every == 0:
            # Report the point f_value was actually evaluated at (pre-update).
            x, y = prev_params["x"], prev_params["y"]
            print(f"step {k}, f(x, y) = {f_value}, (x, y) = ({x}, {y})")

    return param_hist, f_hist


initial_params = {
    "x": jnp.array(1.0),
    "y": jnp.array(1.0),
}

# Shared step size so both methods are compared on equal footing.
LEARNING_RATE = 0.1

# GDA: plain simultaneous gradient descent on x / ascent on y.
x_gd_optimiser = optax.sgd(learning_rate=LEARNING_RATE)
y_ga_optimiser = optax.sgd(learning_rate=LEARNING_RATE)

# OGDA
x_ogd_optimiser = optax.optimistic_gradient_descent(learning_rate=LEARNING_RATE)
y_oga_optimiser = optax.optimistic_gradient_descent(learning_rate=LEARNING_RATE)

gda_hist, gda_f_hist = optimise(initial_params, x_gd_optimiser, y_ga_optimiser)
ogda_hist, ogda_f_hist = optimise(initial_params, x_ogd_optimiser, y_oga_optimiser)

gda_xs, gda_ys = [p["x"] for p in gda_hist], [p["y"] for p in gda_hist]
ogda_xs, ogda_ys = [p["x"] for p in ogda_hist], [p["y"] for p in ogda_hist]

plt.plot(gda_xs, gda_ys, alpha=0.6, color="C0", label="GDA")
plt.plot(ogda_xs, ogda_ys, alpha=0.6, color="C1", label="OGDA")
# Mark the actual starting point rather than a hard-coded (1, 1).
plt.scatter(
    [initial_params["x"]], [initial_params["y"]], color="r", label=r"$(x_0, y_0)$", s=30
)
plt.scatter([0], [0], color="k", label=r"$(x^\star, y^\star)$", s=30)
plt.xlim([-2.0, 2.0])
plt.ylim([-2.0, 2.0])
plt.legend(loc="lower right")
plt.show()