import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
Consider the following min-max problem:
$$ \min_{x \in \mathbb R^m} \max_{y\in\mathbb R^n} f(x,y), $$where $f: \mathbb R^m \times \mathbb R^n \to \mathbb R$ is a convex-concave function (convex in $x$, concave in $y$). The solution to such a problem is a saddle point $(x^\star, y^\star)\in \mathbb R^m \times \mathbb R^n$ such that
$$ f(x^\star, y) \leq f(x^\star, y^\star) \leq f(x, y^\star) \quad \text{for all } x \in \mathbb R^m,\ y \in \mathbb R^n. $$Standard gradient descent-ascent (GDA) updates $x$ and $y$ simultaneously according to the following rule at step $k$:
$$ \begin{aligned} x_{k+1} &= x_k - \eta_k \nabla_x f(x_k, y_k), \\ y_{k+1} &= y_k + \eta_k \nabla_y f(x_k, y_k), \end{aligned} $$where $\eta_k$ is a step size. However, it is well documented that GDA can fail to converge in this setting: on a bilinear objective such as $f(x, y) = xy$, its iterates spiral away from the saddle point. This is an important issue because gradient-based min-max optimisation is increasingly prevalent in machine learning (e.g., GANs, constrained RL). Optimistic GDA (OGDA) addresses this shortcoming by introducing a form of memory-based negative momentum:
$$ \begin{aligned} x_{k+1} &= x_k - 2 \eta_k \nabla_x f(x_k, y_k) + \eta_k \nabla_x f(x_{k-1}, y_{k-1}), \\ y_{k+1} &= y_k + 2 \eta_k \nabla_y f(x_k, y_k) - \eta_k \nabla_y f(x_{k-1}, y_{k-1}). \end{aligned} $$Thus, to implement optimistic gradient descent (OGD) or ascent (OGA), the optimiser needs to keep track of the gradient from the previous step. OGDA has been formally shown to converge to the optimum, $(x_k, y_k) \to (x^\star, y^\star)$, in this setting. The generalised form of the OGDA update rule is given by
$$ \begin{aligned} x_{k+1} &= x_k - (\alpha + \beta) \eta_k \nabla_x f(x_k, y_k) + \beta \eta_k \nabla_x f(x_{k-1}, y_{k-1}), \\ y_{k+1} &= y_k + (\alpha + \beta) \eta_k \nabla_y f(x_k, y_k) - \beta \eta_k \nabla_y f(x_{k-1}, y_{k-1}), \end{aligned} $$which recovers standard OGDA when $\alpha=\beta=1$ (and plain GDA when $\alpha=1$, $\beta=0$). See Mokhtari et al., 2019 for more details.
Define a bilinear min-max objective function, $f(x, y) = xy$, i.e. $\min_x \max_y xy$.
def f(params: dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Bilinear objective ``f(x, y) = x * y`` for the problem min_x max_y xy.

    Args:
        params: Mapping with key ``"x"`` (the minimising variable) and key
            ``"y"`` (the maximising variable).

    Returns:
        The product ``params["x"] * params["y"]``.
    """
    return params["x"] * params["y"]
Define an optimisation loop.
def optimise(
    params: optax.Params,
    x_optimiser: optax.GradientTransformation,
    y_optimiser: optax.GradientTransformation,
    n_steps: int = 1000,
    display_every: int = 100,
) -> tuple[list[optax.Params], list[jnp.ndarray]]:
    """Run simultaneous descent on ``x`` and ascent on ``y`` for objective ``f``.

    Args:
        params: Dict with keys ``"x"`` and ``"y"`` holding the initial point.
        x_optimiser: Optax transformation for ``x``. It receives the gradient
            of ``f`` w.r.t. ``x``, so applying its updates descends ``f``.
        y_optimiser: Optax transformation for ``y``. It receives the *negated*
            gradient of ``f`` w.r.t. ``y``, so applying its updates ascends ``f``.
        n_steps: Number of update steps to run.
        display_every: Print progress every ``display_every`` steps; a value
            ``<= 0`` disables printing (previously ``0`` raised
            ``ZeroDivisionError``).

    Returns:
        ``(param_hist, f_hist)``: ``param_hist`` holds ``n_steps + 1`` parameter
        dicts (the initial point followed by each updated point), and
        ``f_hist[k]`` is ``f`` evaluated at ``param_hist[k]``.
    """
    x_opt_state = x_optimiser.init(params["x"])
    y_opt_state = y_optimiser.init(params["y"])
    param_hist = [params]
    f_hist = []

    @jax.jit
    def step(params, x_opt_state, y_opt_state):
        # f_value is f evaluated at the incoming (pre-update) params.
        f_value, grads = jax.value_and_grad(f)(params)
        x_update, x_opt_state = x_optimiser.update(grads["x"], x_opt_state, params["x"])
        # We're maximising over y, so feed the negated gradient to the y optimiser:
        # descending -f is the same as ascending f.
        y_update, y_opt_state = y_optimiser.update(-grads["y"], y_opt_state, params["y"])
        updates = {"x": x_update, "y": y_update}
        params = optax.apply_updates(params, updates)
        return params, x_opt_state, y_opt_state, f_value

    for k in range(n_steps):
        new_params, x_opt_state, y_opt_state, f_value = step(params, x_opt_state, y_opt_state)
        if display_every > 0 and k % display_every == 0:
            # Report the point at which f_value was actually evaluated.
            print(f"step {k}, f(x, y) = {f_value}, (x, y) = ({params['x']}, {params['y']})")
        params = new_params
        param_hist.append(params)
        f_hist.append(f_value)
    return param_hist, f_hist
Initialise $x$ and $y$, as well as optimisers for each.
# Starting point (x_0, y_0) = (1, 1) shared by both methods.
initial_params = dict(x=jnp.array(1.0), y=jnp.array(1.0))

# GDA: plain SGD for each variable. optimise() negates the y-gradient, so the
# "descent" optimiser for y performs gradient ascent on f.
x_gd_optimiser, y_ga_optimiser = (
    optax.sgd(learning_rate=0.1),
    optax.sgd(learning_rate=0.1),
)

# OGDA: optax's optimistic gradient descent, leaving alpha/beta at optax's
# defaults (documented as alpha = beta = 1, i.e. standard OGDA — confirm
# against the installed optax version).
x_ogd_optimiser, y_oga_optimiser = (
    optax.optimistic_gradient_descent(learning_rate=0.1),
    optax.optimistic_gradient_descent(learning_rate=0.1),
)
Run each method.
# Run both methods from the same starting point; each call returns
# (parameter history, objective history).
gda_hist, gda_f_hist = optimise(
    initial_params, x_optimiser=x_gd_optimiser, y_optimiser=y_ga_optimiser
)
ogda_hist, ogda_f_hist = optimise(
    initial_params, x_optimiser=x_ogd_optimiser, y_optimiser=y_oga_optimiser
)
Visualise the optimisation trajectories. The optimal solution is $(0, 0)$.
# Split each parameter history into x and y coordinate sequences.
gda_xs, gda_ys = [p["x"] for p in gda_hist], [p["y"] for p in gda_hist]
ogda_xs, ogda_ys = [p["x"] for p in ogda_hist], [p["y"] for p in ogda_hist]
plt.plot(gda_xs, gda_ys, alpha=0.6, color="C0", label="GDA")
plt.plot(ogda_xs, ogda_ys, alpha=0.6, color="C1", label="OGDA")
# Take the start marker from initial_params rather than hard-coding (1, 1), so it
# stays correct if the starting point is changed.
plt.scatter([initial_params["x"]], [initial_params["y"]], color="r", label=r"$(x_0, y_0)$", s=30)
# Saddle point of f(x, y) = xy.
plt.scatter([0], [0], color="k", label=r"$(x^\star, y^\star)$", s=30)
# Fix the view window around the saddle point.
plt.xlim([-2.0, 2.0])
plt.ylim([-2.0, 2.0])
plt.xlabel(r"$x$")
plt.ylabel(r"$y$")
plt.legend(loc="lower right")
plt.show()