This example constructs a neural SDE as a generative time series model.
An SDE is, of course, random: it defines a distribution in which each sample is a whole path. In modern machine learning parlance, an SDE is therefore a generative time series model, which means it can be trained as a GAN, for example. Doing so requires a discriminator that takes a path as its input; here we use a neural CDE.
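To make "each sample is a whole path" concrete, here is a minimal sketch that draws a few paths from a toy SDE with the Euler-Maruyama scheme. It uses only the Python standard library rather than diffrax, and the helper name `sample_path` and the coefficients are illustrative assumptions, not the ones used later in this example.

```python
import math
import random


def sample_path(ts, y0, drift, diffusion, rng):
    """Euler-Maruyama: one random path of dY = drift dt + diffusion dW on the grid ts."""
    ys = [y0]
    for t_prev, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - t_prev
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment, distributed N(0, dt)
        y = ys[-1]
        ys.append(y + drift(t_prev, y) * dt + diffusion(t_prev, y) * dw)
    return ys


rng = random.Random(0)
ts = [0.1 * i for i in range(101)]

# Hypothetical toy coefficients: mean reversion towards zero, constant noise.
paths = [
    sample_path(ts, 1.0, lambda t, y: -y, lambda t, y: 0.5, rng)
    for _ in range(3)
]

for path in paths:
    print(f"start={path[0]:.2f}, end={path[-1]:.3f}, length={len(path)}")
```

Every call returns a different path from the same starting point, which is the sense in which the SDE defines a distribution over paths.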
Training an SDE as a GAN is precisely what this example does. Doing so will reproduce the following toy example, which is trained on irregularly-sampled time series:

References:
Training SDEs as GANs:
@inproceedings{kidger2021sde1,
    title={{N}eural {SDE}s as {I}nfinite-{D}imensional {GAN}s},
    author={Kidger, Patrick and Foster, James and Li, Xuechen and Lyons, Terry J},
    booktitle={Proceedings of the 38th International Conference on Machine Learning},
    pages={5453--5463},
    year={2021},
    volume={139},
    series={Proceedings of Machine Learning Research},
    publisher={PMLR},
}
Improved training techniques:
@incollection{kidger2021sde2,
    title={{E}fficient and {A}ccurate {G}radients for {N}eural {SDE}s},
    author={Kidger, Patrick and Foster, James and Li, Xuechen and Lyons, Terry},
    booktitle={Advances in Neural Information Processing Systems 34},
    year={2021},
    publisher={Curran Associates, Inc.},
}
This example is available as a Jupyter notebook here.
!!! warning

    This example will need a GPU to run efficiently.

!!! danger "Advanced example"

    This is an advanced example, due to the complexity of the modelling techniques used.
from typing import Union
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax # https://github.com/deepmind/optax
LipSwish activation functions are a good choice for the discriminator of an SDE-GAN. (Their use here was introduced in the second reference above.) For simplicity we will actually use LipSwish activations everywhere, even in the generator.
def lipswish(x):
    return 0.909 * jnn.silu(x)
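The constant 0.909 is approximately 1/1.1. SiLU has a maximum slope of roughly 1.1, so this rescaling keeps the activation's Lipschitz constant just below one. The sketch below checks that numerically with finite differences. It is written in plain Python (standard library only) rather than JAX, and the helper names `silu` and `max_slope` are illustrative.

```python
import math


def silu(x):
    return x / (1.0 + math.exp(-x))


def lipswish(x):
    return 0.909 * silu(x)


def max_slope(f, lo=-10.0, hi=10.0, n=2001, eps=1e-5):
    """Largest absolute central-difference slope of f on an even grid over [lo, hi]."""
    best = 0.0
    for i in range(n):
        x = lo + (hi - lo) * i / (n - 1)
        slope = (f(x + eps) - f(x - eps)) / (2 * eps)
        best = max(best, abs(slope))
    return best


silu_slope = max_slope(silu)
lipswish_slope = max_slope(lipswish)
print(f"max |silu'|     ~ {silu_slope:.4f}")
print(f"max |lipswish'| ~ {lipswish_slope:.4f}")
```

A grid search like this only gives a numerical estimate, not a proof, but it shows why the unscaled SiLU would slightly violate a Lipschitz-1 constraint.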
Now set up the vector fields appearing on the right hand side of each differential equation.
class VectorField(eqx.Module):
    scale: Union[int, jnp.ndarray]
    mlp: eqx.nn.MLP

    def __init__(self, hidden_size, width_size, depth, scale, *, key, **kwargs):
        super().__init__(**kwargs)
        scale_key, mlp_key = jr.split(key)
        if scale:
            self.scale = jr.uniform(scale_key, (hidden_size,), minval=0.9, maxval=1.1)
        else:
            self.scale = 1
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size + 1,
            out_size=hidden_size,
            width_size=width_size,
            depth=depth,
            activation=lipswish,
            final_activation=jnn.tanh,
            key=mlp_key,
        )

    def __call__(self, t, y, args):
        t = jnp.asarray(t)
        return self.scale * self.mlp(jnp.concatenate([t[None], y]))


class ControlledVectorField(eqx.Module):
    scale: Union[int, jnp.ndarray]
    mlp: eqx.nn.MLP
    control_size: int
    hidden_size: int

    def __init__(
        self, control_size, hidden_size, width_size, depth, scale, *, key, **kwargs
    ):
        super().__init__(**kwargs)
        scale_key, mlp_key = jr.split(key)
        if scale:
            self.scale = jr.uniform(
                scale_key, (hidden_size, control_size), minval=0.9, maxval=1.1
            )
        else:
            self.scale = 1
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size + 1,
            out_size=hidden_size * control_size,
            width_size=width_size,
            depth=depth,
            activation=lipswish,
            final_activation=jnn.tanh,
            key=mlp_key,
        )
        self.control_size = control_size
        self.hidden_size = hidden_size

    def __call__(self, t, y, args):
        t = jnp.asarray(t)
        return self.scale * self.mlp(jnp.concatenate([t[None], y])).reshape(
            self.hidden_size, self.control_size
        )
Now set up the neural SDE (the generator) and the neural CDE (the discriminator).
Note the use of very large step sizes. By using a large step size we essentially "bake in" the discretisation. This is quite a standard thing to do to decrease computational costs, when the vector field is a pure neural network. (You can reduce the step size here if you want to -- which will increase the computational cost, of course.)
Note the clip_weights method on the CDE -- this is part of imposing the Lipschitz condition on the discriminator of a Wasserstein GAN.
(The other part of imposing this condition is the use of the LipSwish activation functions we saw earlier.)
class NeuralSDE(eqx.Module):
    initial: eqx.nn.MLP
    vf: VectorField  # drift
    cvf: ControlledVectorField  # diffusion
    readout: eqx.nn.Linear
    initial_noise_size: int
    noise_size: int

    def __init__(
        self,
        data_size,
        initial_noise_size,
        noise_size,
        hidden_size,
        width_size,
        depth,
        *,
        key,
        **kwargs,
    ):
        super().__init__(**kwargs)
        initial_key, vf_key, cvf_key, readout_key = jr.split(key, 4)

        self.initial = eqx.nn.MLP(
            initial_noise_size, hidden_size, width_size, depth, key=initial_key
        )
        self.vf = VectorField(hidden_size, width_size, depth, scale=True, key=vf_key)
        self.cvf = ControlledVectorField(
            noise_size, hidden_size, width_size, depth, scale=True, key=cvf_key
        )
        self.readout = eqx.nn.Linear(hidden_size, data_size, key=readout_key)

        self.initial_noise_size = initial_noise_size
        self.noise_size = noise_size

    def __call__(self, ts, *, key):
        t0 = ts[0]
        t1 = ts[-1]
        # Very large dt0 for computational speed
        dt0 = 1.0
        init_key, bm_key = jr.split(key, 2)
        init = jr.normal(init_key, (self.initial_noise_size,))
        control = diffrax.VirtualBrownianTree(
            t0=t0, t1=t1, tol=dt0 / 2, shape=(self.noise_size,), key=bm_key
        )
        vf = diffrax.ODETerm(self.vf)  # Drift term
        cvf = diffrax.ControlTerm(self.cvf, control)  # Diffusion term
        terms = diffrax.MultiTerm(vf, cvf)
        # ReversibleHeun is a cheap choice of SDE solver. We could also use Euler etc.
        solver = diffrax.ReversibleHeun()
        y0 = self.initial(init)
        saveat = diffrax.SaveAt(ts=ts)
        sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0, y0, saveat=saveat)
        return jax.vmap(self.readout)(sol.ys)


class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    vf: VectorField
    cvf: ControlledVectorField
    readout: eqx.nn.Linear

    def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        initial_key, vf_key, cvf_key, readout_key = jr.split(key, 4)

        self.initial = eqx.nn.MLP(
            data_size + 1, hidden_size, width_size, depth, key=initial_key
        )
        self.vf = VectorField(hidden_size, width_size, depth, scale=False, key=vf_key)
        self.cvf = ControlledVectorField(
            data_size, hidden_size, width_size, depth, scale=False, key=cvf_key
        )
        self.readout = eqx.nn.Linear(hidden_size, 1, key=readout_key)

    def __call__(self, ts, ys):
        # Interpolate data into a continuous path.
        ys = diffrax.linear_interpolation(
            ts, ys, replace_nans_at_start=0.0, fill_forward_nans_at_end=True
        )
        init = jnp.concatenate([ts[0, None], ys[0]])
        control = diffrax.LinearInterpolation(ts, ys)
        vf = diffrax.ODETerm(self.vf)
        cvf = diffrax.ControlTerm(self.cvf, control)
        terms = diffrax.MultiTerm(vf, cvf)
        solver = diffrax.ReversibleHeun()
        t0 = ts[0]
        t1 = ts[-1]
        dt0 = 1.0
        y0 = self.initial(init)
        # Have the discriminator produce an output at both `t0` *and* `t1`.
        # The output at `t0` has only seen the initial point of a sample. This gives
        # additional supervision to the distribution learnt for the initial condition.
        # The output at `t1` has seen the entire path of a sample. This is needed to
        # actually learn the evolving trajectory.
        saveat = diffrax.SaveAt(t0=True, t1=True)
        sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0, y0, saveat=saveat)
        return jax.vmap(self.readout)(sol.ys)

    @eqx.filter_jit
    def clip_weights(self):
        leaves, treedef = jax.tree_util.tree_flatten(
            self, is_leaf=lambda x: isinstance(x, eqx.nn.Linear)
        )
        new_leaves = []
        for leaf in leaves:
            if isinstance(leaf, eqx.nn.Linear):
                lim = 1 / leaf.out_features
                leaf = eqx.tree_at(
                    lambda x: x.weight, leaf, leaf.weight.clip(-lim, lim)
                )
            new_leaves.append(leaf)
        return jax.tree_util.tree_unflatten(treedef, new_leaves)
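One way to read the `1 / out_features` limit used in `clip_weights` (an interpretation offered here, not a statement from the original text): if every entry of an `(out_features, in_features)` weight matrix has magnitude at most `1 / out_features`, then every column has absolute sum at most one, so the linear map is 1-Lipschitz in the L1 norm. The sketch below checks this on a bare list-of-lists matrix in plain Python. The helper names `clip_weight` and `induced_l1_norm` are illustrative.

```python
import random


def clip_weight(weight):
    """Clip each entry of an (out_features x in_features) matrix to [-1/out, 1/out]."""
    lim = 1 / len(weight)
    return [[min(max(w, -lim), lim) for w in row] for row in weight]


def induced_l1_norm(weight):
    """Largest absolute column sum: the Lipschitz constant of x -> W x in the L1 norm."""
    num_cols = len(weight[0])
    return max(sum(abs(row[j]) for row in weight) for j in range(num_cols))


rng = random.Random(0)
out_features, in_features = 4, 3
weight = [
    [rng.uniform(-2.0, 2.0) for _ in range(in_features)]
    for _ in range(out_features)
]
clipped = clip_weight(weight)

print("before clipping:", induced_l1_norm(weight))
print("after clipping: ", induced_l1_norm(clipped))
```

Note that this only constrains the linear layers; the activations between them need their own Lipschitz bound, which is the role LipSwish plays.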
Next, the dataset. This follows the trajectories you can see in the picture above. (Namely positive drift with mean-reversion and time-dependent diffusion.)
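Written out, the SDE simulated by the `get_data` function below (read off from its `drift` and `diffusion` functions) is:

```latex
\mathrm{d}Y_t = (\mu t - \theta Y_t)\,\mathrm{d}t + \frac{2 \sigma t}{t_1}\,\mathrm{d}W_t,
\qquad Y_0 \sim \mathcal{U}(-1, 1),
\qquad t \in [0, t_1],
```

with $\mu = 0.02$, $\theta = 0.1$, $\sigma = 0.4$ and $t_1 = 63$. The $\mu t$ term gives the upward drift, $-\theta Y_t$ the mean reversion, and the diffusion coefficient grows linearly from $0$ at $t = 0$ to $2\sigma$ at $t = t_1$.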
@jax.jit
@jax.vmap
def get_data(key):
    bm_key, y0_key, drop_key = jr.split(key, 3)

    mu = 0.02
    theta = 0.1
    sigma = 0.4

    t0 = 0
    t1 = 63
    t_size = 64

    def drift(t, y, args):
        return mu * t - theta * y

    def diffusion(t, y, args):
        return 2 * sigma * t / t1

    bm = diffrax.UnsafeBrownianPath(shape=(), key=bm_key)
    drift = diffrax.ODETerm(drift)
    diffusion = diffrax.ControlTerm(diffusion, bm)
    terms = diffrax.MultiTerm(drift, diffusion)
    solver = diffrax.Euler()
    dt0 = 0.1
    y0 = jr.uniform(y0_key, (1,), minval=-1, maxval=1)
    ts = jnp.linspace(t0, t1, t_size)
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        terms, solver, t0, t1, dt0, y0, saveat=saveat, adjoint=diffrax.DirectAdjoint()
    )

    # Make the data irregularly sampled
    to_drop = jr.bernoulli(drop_key, 0.3, (t_size, 1))
    ys = jnp.where(to_drop, jnp.nan, sol.ys)

    return ts, ys


def dataloader(arrays, batch_size, loop, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        key = jr.split(key, 1)[0]
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size
        if not loop:
            break
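One detail of the `dataloader` loop above is worth knowing: because its condition is `end < dataset_size`, the final slice of each pass is not yielded, even when the dataset size is an exact multiple of the batch size. The sketch below mirrors just the loop bounds in plain Python to count batches per pass. The helper name `count_batches` is illustrative.

```python
def count_batches(dataset_size, batch_size):
    """Count the batches yielded in one pass, mirroring the bounds used by `dataloader`."""
    count = 0
    end = batch_size
    while end < dataset_size:
        count += 1
        end += batch_size
    return count


# The sizes used as defaults in `main`: 8192 samples in batches of 1024.
print(count_batches(8192, 1024))
# One extra sample is enough for the eighth full batch to be yielded.
print(count_batches(8193, 1024))
```

With the default sizes this gives 7 batches per pass rather than 8. Since each pass uses a fresh permutation, training still sees all of the data over time, but a single evaluation pass covers 7168 of the 8192 samples.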
Now the usual training step for GAN training.
There is one neural-SDE-specific trick here: we increase the update size (i.e. the learning rate) for those parameters describing (and discriminating) the initial condition of the SDE. Otherwise the model tends to focus on fitting just the rest of the data (i.e. the random evolution over time).
@eqx.filter_jit
def loss(generator, discriminator, ts_i, ys_i, key, step=0):
    batch_size, _ = ts_i.shape
    key = jr.fold_in(key, step)
    key = jr.split(key, batch_size)
    fake_ys_i = jax.vmap(generator)(ts_i, key=key)
    real_score = jax.vmap(discriminator)(ts_i, ys_i)
    fake_score = jax.vmap(discriminator)(ts_i, fake_ys_i)
    return jnp.mean(real_score - fake_score)


@eqx.filter_grad
def grad_loss(g_d, ts_i, ys_i, key, step):
    generator, discriminator = g_d
    return loss(generator, discriminator, ts_i, ys_i, key, step)


def increase_update_initial(updates):
    get_initial_leaves = lambda u: jax.tree_util.tree_leaves(u.initial)
    return eqx.tree_at(get_initial_leaves, updates, replace_fn=lambda x: x * 10)


@eqx.filter_jit
def make_step(
    generator,
    discriminator,
    g_opt_state,
    d_opt_state,
    g_optim,
    d_optim,
    ts_i,
    ys_i,
    key,
    step,
):
    g_grad, d_grad = grad_loss((generator, discriminator), ts_i, ys_i, key, step)
    g_updates, g_opt_state = g_optim.update(g_grad, g_opt_state)
    d_updates, d_opt_state = d_optim.update(d_grad, d_opt_state)
    g_updates = increase_update_initial(g_updates)
    d_updates = increase_update_initial(d_updates)
    generator = eqx.apply_updates(generator, g_updates)
    discriminator = eqx.apply_updates(discriminator, d_updates)
    discriminator = discriminator.clip_weights()
    return generator, discriminator, g_opt_state, d_opt_state
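Both networks are differentiated with respect to the same scalar, the mean of `real_score - fake_score`. The generator should decrease it and the discriminator should increase it, which is why `main` gives the discriminator's optimiser a negative learning rate. The sketch below shows the effect of that sign on a scalar toy problem. It uses plain gradient steps rather than RMSprop, and the function `toy_loss` is a hypothetical stand-in, not the loss from this example.

```python
def gradient_step(param, grad, learning_rate):
    """Plain gradient step: a positive learning rate descends, a negative one ascends."""
    return param - learning_rate * grad


def toy_loss(g, d):
    # Hypothetical scalar stand-in for mean(real_score - fake_score).
    return d * (1.0 - g)


g, d = 0.0, 0.5
grad_g = -d  # derivative of toy_loss with respect to g
grad_d = 1.0 - g  # derivative of toy_loss with respect to d

new_g = gradient_step(g, grad_g, learning_rate=0.1)  # generator: descend
new_d = gradient_step(d, grad_d, learning_rate=-0.1)  # discriminator: ascend

print("loss before:               ", toy_loss(g, d))
print("after generator step:      ", toy_loss(new_g, d))
print("after discriminator step:  ", toy_loss(g, new_d))
```

Sharing one loss and flipping the learning-rate sign means a single call to `grad_loss` yields gradients for both players.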
This is our main entry point. Try running main().
def main(
    initial_noise_size=5,
    noise_size=3,
    hidden_size=16,
    width_size=16,
    depth=1,
    generator_lr=2e-5,
    discriminator_lr=1e-4,
    batch_size=1024,
    steps=10000,
    steps_per_print=200,
    dataset_size=8192,
    seed=5678,
):
    key = jr.PRNGKey(seed)
    (
        data_key,
        generator_key,
        discriminator_key,
        dataloader_key,
        train_key,
        evaluate_key,
        sample_key,
    ) = jr.split(key, 7)
    data_key = jr.split(data_key, dataset_size)

    ts, ys = get_data(data_key)
    _, _, data_size = ys.shape

    generator = NeuralSDE(
        data_size,
        initial_noise_size,
        noise_size,
        hidden_size,
        width_size,
        depth,
        key=generator_key,
    )
    discriminator = NeuralCDE(
        data_size, hidden_size, width_size, depth, key=discriminator_key
    )

    g_optim = optax.rmsprop(generator_lr)
    d_optim = optax.rmsprop(-discriminator_lr)
    g_opt_state = g_optim.init(eqx.filter(generator, eqx.is_inexact_array))
    d_opt_state = d_optim.init(eqx.filter(discriminator, eqx.is_inexact_array))

    infinite_dataloader = dataloader(
        (ts, ys), batch_size, loop=True, key=dataloader_key
    )
    for step, (ts_i, ys_i) in zip(range(steps), infinite_dataloader):
        step = jnp.asarray(step)
        generator, discriminator, g_opt_state, d_opt_state = make_step(
            generator,
            discriminator,
            g_opt_state,
            d_opt_state,
            g_optim,
            d_optim,
            ts_i,
            ys_i,
            train_key,
            step,
        )
        if (step % steps_per_print) == 0 or step == steps - 1:
            total_score = 0
            num_batches = 0
            for ts_i, ys_i in dataloader(
                (ts, ys), batch_size, loop=False, key=evaluate_key
            ):
                score = loss(generator, discriminator, ts_i, ys_i, sample_key)
                total_score += score.item()
                num_batches += 1
            print(f"Step: {step}, Loss: {total_score / num_batches}")

    # Plot samples
    fig, ax = plt.subplots()
    num_samples = min(50, dataset_size)
    ts_to_plot = ts[:num_samples]
    ys_to_plot = ys[:num_samples]

    def _interp(ti, yi):
        return diffrax.linear_interpolation(
            ti, yi, replace_nans_at_start=0.0, fill_forward_nans_at_end=True
        )

    ys_to_plot = jax.vmap(_interp)(ts_to_plot, ys_to_plot)[..., 0]
    ys_sampled = jax.vmap(generator)(ts_to_plot, key=jr.split(sample_key, num_samples))[
        ..., 0
    ]
    kwargs = dict(label="Real")
    for ti, yi in zip(ts_to_plot, ys_to_plot):
        ax.plot(ti, yi, c="dodgerblue", linewidth=0.5, alpha=0.7, **kwargs)
        kwargs = {}
    kwargs = dict(label="Generated")
    for ti, yi in zip(ts_to_plot, ys_sampled):
        ax.plot(ti, yi, c="crimson", linewidth=0.5, alpha=0.7, **kwargs)
        kwargs = {}
    ax.set_title(f"{num_samples} samples from both real and generated distributions.")
    fig.legend()
    fig.tight_layout()
    fig.savefig("neural_sde.png")
    plt.show()


main()
Step: 0, Loss: 0.13390611750738962
Step: 200, Loss: 4.786926678248814
Step: 400, Loss: 7.736175605228969
Step: 600, Loss: 10.103722981044225
Step: 800, Loss: 11.831081799098424
Step: 1000, Loss: 7.418417045048305
Step: 1200, Loss: 6.938951356070382
Step: 1400, Loss: 2.881302390779768
Step: 1600, Loss: 1.5363099915640694
Step: 1800, Loss: 1.0079529796327864
Step: 2000, Loss: 0.936917781829834
Step: 2200, Loss: 0.9594544768333435
Step: 2400, Loss: 1.247592806816101
Step: 2600, Loss: 0.9021680951118469
Step: 2800, Loss: 0.861811808177403
Step: 3000, Loss: 1.1381437267575945
Step: 3200, Loss: 1.5369644505637032
Step: 3400, Loss: 1.3387839964457922
Step: 3600, Loss: 1.0477747491427831
Step: 3800, Loss: 1.7565655538014002
Step: 4000, Loss: 1.8188678196498327
Step: 4200, Loss: 1.4719816957201277
Step: 4400, Loss: 1.4189972026007516
Step: 4600, Loss: 0.6867345826966422
Step: 4800, Loss: 0.6138326355389186
Step: 5000, Loss: 0.5908999613353184
Step: 5200, Loss: 0.579599814755576
Step: 5400, Loss: -0.8964726499148777
Step: 5600, Loss: -4.22784035546439
Step: 5800, Loss: 1.8623723132269723
Step: 6000, Loss: -0.17913252328123366
Step: 6200, Loss: 1.2232166869299752
Step: 6400, Loss: 1.1680303982325964
Step: 6600, Loss: -0.5765694592680249
Step: 6800, Loss: 0.5931433950151715
Step: 7000, Loss: 0.12497492773192269
Step: 7200, Loss: 0.5957097922052655
Step: 7400, Loss: 0.33551327671323505
Step: 7600, Loss: 0.5243289640971592
Step: 7800, Loss: 0.797236042363303
Step: 8000, Loss: 0.5341930559703282
Step: 8200, Loss: 1.1995042221886771
Step: 8400, Loss: -0.5231874521289553
Step: 8600, Loss: -0.42040516648973736
Step: 8800, Loss: 1.384656548500061
Step: 9000, Loss: 1.4223246574401855
Step: 9200, Loss: 0.2646511915538992
Step: 9400, Loss: -0.046253203813518794
Step: 9600, Loss: 0.738983656678881
Step: 9800, Loss: 1.1247712458883012
Step: 9999, Loss: -0.44179755449295044