This example trains a Latent ODE.
In this case, it's on a simple dataset of decaying oscillators. That is, 2-dimensional time series that look like:
[ASCII plot: the two components of a sample time series, oscillating and gradually decaying towards zero.]
The model is trained to generate samples that look like this.
What's really nice about this example is that we will take the underlying data to be irregularly sampled. We will have different observation times for different batch elements.
Most differential equation libraries will struggle with this, as they usually mandate that the differential equation be solved over the same timespan for all batch elements. Working around this can involve programming complexity, like outputting at lots and lots of times (the union of all the observation times in the batch), or mathematical complexity, like reparameterising the differential equation.
However, Diffrax is capable of handling this without such issues! You can vmap over different integration times for different batch elements.
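For instance, here is a minimal standalone sketch of that idea (separate from the full example below; the vector field and times are made up purely for illustration):

import diffrax
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    return -y  # simple exponential decay, for illustration only


def solve(ts, y0):
    # Solve from the first to the last observation time, saving the solution at
    # every observation time for this batch element.
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        ts[0],
        ts[-1],
        0.1,
        y0,
        saveat=diffrax.SaveAt(ts=ts),
    )
    return sol.ys


# Two batch elements, observed over completely different timespans.
batch_ts = jnp.stack([jnp.linspace(0.0, 1.0, 8), jnp.linspace(2.0, 5.0, 8)])
batch_y0 = jnp.array([[1.0], [2.0]])
batch_ys = jax.vmap(solve)(batch_ts, batch_y0)  # shape (2, 8, 1)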
Reference:
@incollection{rubanova2019latent,
    title={{L}atent {O}rdinary {D}ifferential {E}quations for {I}rregularly-{S}ampled {T}ime {S}eries},
    author={Rubanova, Yulia and Chen, Ricky T. Q. and Duvenaud, David K.},
    booktitle={Advances in Neural Information Processing Systems},
    publisher={Curran Associates, Inc.},
    year={2019},
}
This example is available as a Jupyter notebook here.
import time
import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax
matplotlib.rcParams.update({"font.size": 30})
The vector field. Note its overall structure of scalar * tanh(mlp(y)), which is a good structure for Latent ODEs. (Here the tanh is the final activation of self.mlp.)
class Func(eqx.Module):
scale: jnp.ndarray
mlp: eqx.nn.MLP
def __call__(self, t, y, args):
return self.scale * self.mlp(y)
Wrap up the differential equation solve into a model.
class LatentODE(eqx.Module):
func: Func
rnn_cell: eqx.nn.GRUCell
hidden_to_latent: eqx.nn.Linear
latent_to_hidden: eqx.nn.MLP
hidden_to_data: eqx.nn.Linear
hidden_size: int
latent_size: int
def __init__(
self, *, data_size, hidden_size, latent_size, width_size, depth, key, **kwargs
):
super().__init__(**kwargs)
mkey, gkey, hlkey, lhkey, hdkey = jr.split(key, 5)
scale = jnp.ones(())
mlp = eqx.nn.MLP(
in_size=hidden_size,
out_size=hidden_size,
width_size=width_size,
depth=depth,
activation=jnn.softplus,
final_activation=jnn.tanh,
key=mkey,
)
self.func = Func(scale, mlp)
self.rnn_cell = eqx.nn.GRUCell(data_size + 1, hidden_size, key=gkey)
self.hidden_to_latent = eqx.nn.Linear(hidden_size, 2 * latent_size, key=hlkey)
self.latent_to_hidden = eqx.nn.MLP(
latent_size, hidden_size, width_size=width_size, depth=depth, key=lhkey
)
self.hidden_to_data = eqx.nn.Linear(hidden_size, data_size, key=hdkey)
self.hidden_size = hidden_size
self.latent_size = latent_size
# Encoder of the VAE
def _latent(self, ts, ys, key):
data = jnp.concatenate([ts[:, None], ys], axis=1)
hidden = jnp.zeros((self.hidden_size,))
for data_i in reversed(data):
hidden = self.rnn_cell(data_i, hidden)
context = self.hidden_to_latent(hidden)
mean, logstd = context[: self.latent_size], context[self.latent_size :]
std = jnp.exp(logstd)
latent = mean + jr.normal(key, (self.latent_size,)) * std
return latent, mean, std
# Decoder of the VAE
def _sample(self, ts, latent):
dt0 = 0.4 # selected as a reasonable choice for this problem
y0 = self.latent_to_hidden(latent)
sol = diffrax.diffeqsolve(
diffrax.ODETerm(self.func),
diffrax.Tsit5(),
ts[0],
ts[-1],
dt0,
y0,
saveat=diffrax.SaveAt(ts=ts),
)
return jax.vmap(self.hidden_to_data)(sol.ys)
@staticmethod
def _loss(ys, pred_ys, mean, std):
# -log p_θ with Gaussian p_θ
reconstruction_loss = 0.5 * jnp.sum((ys - pred_ys) ** 2)
# KL(N(mean, std^2) || N(0, 1))
variational_loss = 0.5 * jnp.sum(mean**2 + std**2 - 2 * jnp.log(std) - 1)
return reconstruction_loss + variational_loss
# Run both encoder and decoder during training.
def train(self, ts, ys, *, key):
latent, mean, std = self._latent(ts, ys, key)
pred_ys = self._sample(ts, latent)
return self._loss(ys, pred_ys, mean, std)
# Run just the decoder during inference.
def sample(self, ts, *, key):
latent = jr.normal(key, (self.latent_size,))
return self._sample(ts, latent)
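As a quick hypothetical sanity check (not part of the original example), constructing a small model and drawing one generated sample produces an array of shape (len(ts), data_size):

# Hypothetical shape check; the names here are throwaway and not used below.
_model = LatentODE(
    data_size=2, hidden_size=16, latent_size=16, width_size=16, depth=2, key=jr.PRNGKey(0)
)
_ts = jnp.linspace(0.0, 5.0, 50)
_sample = _model.sample(_ts, key=jr.PRNGKey(1))  # shape (50, 2)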
Toy dataset of decaying oscillators.
By way of illustration we set this up as a differential equation and solve it using Diffrax as well. (Despite this being an autonomous linear ODE, for which a closed-form solution is actually available; see the sketch after get_data below.)
def get_data(dataset_size, *, key):
ykey, tkey1, tkey2 = jr.split(key, 3)
y0 = jr.normal(ykey, (dataset_size, 2))
t0 = 0
t1 = 2 + jr.uniform(tkey1, (dataset_size,))
ts = jr.uniform(tkey2, (dataset_size, 20)) * (t1[:, None] - t0) + t0
ts = jnp.sort(ts)
dt0 = 0.1
def func(t, y, args):
return jnp.array([[-0.1, 1.3], [-1, -0.1]]) @ y
def solve(ts, y0):
sol = diffrax.diffeqsolve(
diffrax.ODETerm(func),
diffrax.Tsit5(),
ts[0],
ts[-1],
dt0,
y0,
saveat=diffrax.SaveAt(ts=ts),
)
return sol.ys
ys = jax.vmap(solve)(ts, y0)
return ts, ys
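Since this is the linear system dy/dt = A y with A = [[-0.1, 1.3], [-1, -0.1]], the closed-form solution mentioned above is y(t) = expm(t A) y0. For reference, here is a sketch of that closed form (illustrative only; the example itself uses the numerical solve above):

from jax.scipy.linalg import expm

_A = jnp.array([[-0.1, 1.3], [-1, -0.1]])


def closed_form(t, y0):
    # Exact solution of dy/dt = _A @ y with y(0) = y0.
    return expm(t * _A) @ y0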
def dataloader(arrays, batch_size, *, key):
    # An infinite generator of minibatches: the dataset is reshuffled at the
    # start of every pass, and batches are yielded forever.
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while start < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size
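For clarity, here is a quick illustration (not in the original example) of what this generator yields: each batch is a tuple of arrays whose leading dimension is the batch size.

# Illustrative only: peek at one batch from the infinite generator.
_ts, _ys = get_data(64, key=jr.PRNGKey(0))
_ts_i, _ys_i = next(dataloader((_ts, _ys), batch_size=16, key=jr.PRNGKey(1)))
print(_ts_i.shape, _ys_i.shape)  # (16, 20) and (16, 20, 2)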
The main entry point. Try running main() to train a model.
def main(
dataset_size=10000,
batch_size=256,
lr=1e-2,
steps=250,
save_every=50,
hidden_size=16,
latent_size=16,
width_size=16,
depth=2,
seed=5678,
):
key = jr.PRNGKey(seed)
data_key, model_key, loader_key, train_key, sample_key = jr.split(key, 5)
ts, ys = get_data(dataset_size, key=data_key)
model = LatentODE(
data_size=ys.shape[-1],
hidden_size=hidden_size,
latent_size=latent_size,
width_size=width_size,
depth=depth,
key=model_key,
)
@eqx.filter_value_and_grad
def loss(model, ts_i, ys_i, key_i):
batch_size, _ = ts_i.shape
key_i = jr.split(key_i, batch_size)
loss = jax.vmap(model.train)(ts_i, ys_i, key=key_i)
return jnp.mean(loss)
@eqx.filter_jit
def make_step(model, opt_state, ts_i, ys_i, key_i):
value, grads = loss(model, ts_i, ys_i, key_i)
key_i = jr.split(key_i, 1)[0]
updates, opt_state = optim.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
return value, model, opt_state, key_i
optim = optax.adam(lr)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
# Plot results
num_plots = 1 + (steps - 1) // save_every
if ((steps - 1) % save_every) != 0:
num_plots += 1
fig, axs = plt.subplots(1, num_plots, figsize=(num_plots * 8, 8))
axs[0].set_ylabel("x")
axs = iter(axs)
for step, (ts_i, ys_i) in zip(
range(steps), dataloader((ts, ys), batch_size, key=loader_key)
):
start = time.time()
value, model, opt_state, train_key = make_step(
model, opt_state, ts_i, ys_i, train_key
)
end = time.time()
print(f"Step: {step}, Loss: {value}, Computation time: {end - start}")
if (step % save_every) == 0 or step == steps - 1:
ax = next(axs)
# Sample over a longer time interval than we trained on. The model will be
# sufficiently good that it will correctly extrapolate!
sample_t = jnp.linspace(0, 12, 300)
sample_y = model.sample(sample_t, key=sample_key)
sample_t = np.asarray(sample_t)
sample_y = np.asarray(sample_y)
ax.plot(sample_t, sample_y[:, 0])
ax.plot(sample_t, sample_y[:, 1])
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel("t")
plt.savefig("latent_ode.png")
plt.show()
main()
Step: 0, Loss: 19.934764862060547, Computation time: 27.07537531852722 Step: 1, Loss: 17.945302963256836, Computation time: 0.1743943691253662 Step: 2, Loss: 16.862319946289062, Computation time: 0.16676902770996094 Step: 3, Loss: 17.838266372680664, Computation time: 0.1676805019378662 Step: 4, Loss: 15.913865089416504, Computation time: 0.16959643363952637 Step: 5, Loss: 15.387907028198242, Computation time: 0.16565966606140137 Step: 6, Loss: 16.50263214111328, Computation time: 0.16969871520996094 Step: 7, Loss: 17.307086944580078, Computation time: 0.17042207717895508 Step: 8, Loss: 15.414609909057617, Computation time: 0.16952204704284668 Step: 9, Loss: 16.912670135498047, Computation time: 0.16579079627990723 Step: 10, Loss: 17.230003356933594, Computation time: 0.16723251342773438 Step: 11, Loss: 18.290681838989258, Computation time: 0.16434955596923828 Step: 12, Loss: 15.541263580322266, Computation time: 0.16330623626708984 Step: 13, Loss: 15.520601272583008, Computation time: 0.16518783569335938 Step: 14, Loss: 14.719974517822266, Computation time: 0.16350150108337402 Step: 15, Loss: 15.513769149780273, Computation time: 0.16359448432922363 Step: 16, Loss: 16.30827522277832, Computation time: 0.1634058952331543 Step: 17, Loss: 14.704435348510742, Computation time: 0.16392016410827637 Step: 18, Loss: 14.534599304199219, Computation time: 0.16302919387817383 Step: 19, Loss: 14.99282455444336, Computation time: 0.1640028953552246 Step: 20, Loss: 15.04023551940918, Computation time: 0.16433429718017578 Step: 21, Loss: 15.750327110290527, Computation time: 0.16364169120788574 Step: 22, Loss: 14.745054244995117, Computation time: 0.163421630859375 Step: 23, Loss: 15.654170989990234, Computation time: 0.16426348686218262 Step: 24, Loss: 14.102017402648926, Computation time: 0.16342639923095703 Step: 25, Loss: 13.730924606323242, Computation time: 0.16349434852600098 Step: 26, Loss: 14.454326629638672, Computation time: 0.162459135055542 Step: 27, Loss: 16.074562072753906, Computation time: 0.16372108459472656 Step: 28, Loss: 14.457178115844727, Computation time: 0.16365718841552734 Step: 29, Loss: 14.899832725524902, Computation time: 0.16407418251037598 Step: 30, Loss: 14.21741771697998, Computation time: 0.16400694847106934 Step: 31, Loss: 12.896212577819824, Computation time: 0.16325831413269043 Step: 32, Loss: 13.572277069091797, Computation time: 0.16397356986999512 Step: 33, Loss: 14.58654499053955, Computation time: 0.1686105728149414 Step: 34, Loss: 14.236112594604492, Computation time: 0.1673274040222168 Step: 35, Loss: 13.96904182434082, Computation time: 0.16666364669799805 Step: 36, Loss: 13.717779159545898, Computation time: 0.16426467895507812 Step: 37, Loss: 13.212942123413086, Computation time: 0.16362261772155762 Step: 38, Loss: 13.356792449951172, Computation time: 0.16526198387145996 Step: 39, Loss: 13.750845909118652, Computation time: 26.91799235343933 Step: 40, Loss: 15.398611068725586, Computation time: 0.1675868034362793 Step: 41, Loss: 11.830371856689453, Computation time: 0.16466093063354492 Step: 42, Loss: 12.59495735168457, Computation time: 0.16176891326904297 Step: 43, Loss: 13.213092803955078, Computation time: 0.16349530220031738 Step: 44, Loss: 12.40422534942627, Computation time: 0.16125273704528809 Step: 45, Loss: 13.30964469909668, Computation time: 0.16145730018615723 Step: 46, Loss: 12.55689811706543, Computation time: 0.16156625747680664 Step: 47, Loss: 11.785927772521973, Computation time: 0.1622486114501953 Step: 48, Loss: 11.325067520141602, 
Computation time: 0.16244864463806152 Step: 49, Loss: 11.61506462097168, Computation time: 0.1624457836151123 Step: 50, Loss: 10.890422821044922, Computation time: 0.16366934776306152 Step: 51, Loss: 13.305912017822266, Computation time: 0.16304707527160645 Step: 52, Loss: 11.54366397857666, Computation time: 0.16243696212768555 Step: 53, Loss: 11.796025276184082, Computation time: 0.16330742835998535 Step: 54, Loss: 12.504520416259766, Computation time: 0.16342830657958984 Step: 55, Loss: 11.736138343811035, Computation time: 0.16159415245056152 Step: 56, Loss: 11.351236343383789, Computation time: 0.16047382354736328 Step: 57, Loss: 11.916851997375488, Computation time: 0.16179728507995605 Step: 58, Loss: 11.83980655670166, Computation time: 0.16157770156860352 Step: 59, Loss: 11.1612548828125, Computation time: 0.16280055046081543 Step: 60, Loss: 11.311992645263672, Computation time: 0.1631929874420166 Step: 61, Loss: 11.657142639160156, Computation time: 0.16200017929077148 Step: 62, Loss: 10.814916610717773, Computation time: 0.16182494163513184 Step: 63, Loss: 10.638484001159668, Computation time: 0.16114020347595215 Step: 64, Loss: 9.871231079101562, Computation time: 0.16211938858032227 Step: 65, Loss: 10.842245101928711, Computation time: 0.16185402870178223 Step: 66, Loss: 11.241954803466797, Computation time: 0.16134214401245117 Step: 67, Loss: 10.528236389160156, Computation time: 0.16387319564819336 Step: 68, Loss: 10.252235412597656, Computation time: 0.16159725189208984 Step: 69, Loss: 10.343666076660156, Computation time: 0.16295313835144043 Step: 70, Loss: 9.838155746459961, Computation time: 0.16141152381896973 Step: 71, Loss: 10.129756927490234, Computation time: 0.16135191917419434 Step: 72, Loss: 10.172172546386719, Computation time: 0.16157221794128418 Step: 73, Loss: 9.98276424407959, Computation time: 0.16115164756774902 Step: 74, Loss: 9.925966262817383, Computation time: 0.16163945198059082 Step: 75, Loss: 9.98451042175293, Computation time: 0.16181254386901855 Step: 76, Loss: 10.033723831176758, Computation time: 0.1613597869873047 Step: 77, Loss: 9.620193481445312, Computation time: 0.1607823371887207 Step: 78, Loss: 9.448945045471191, Computation time: 0.1607818603515625 Step: 79, Loss: 7.9748687744140625, Computation time: 0.1488492488861084 Step: 80, Loss: 9.215356826782227, Computation time: 0.16275405883789062 Step: 81, Loss: 9.691690444946289, Computation time: 0.1624891757965088 Step: 82, Loss: 8.748353958129883, Computation time: 0.16045212745666504 Step: 83, Loss: 8.528343200683594, Computation time: 0.16178536415100098 Step: 84, Loss: 8.34644889831543, Computation time: 0.16109156608581543 Step: 85, Loss: 9.200542449951172, Computation time: 0.16094589233398438 Step: 86, Loss: 8.57141399383545, Computation time: 0.1619279384613037 Step: 87, Loss: 7.508444786071777, Computation time: 0.1600663661956787 Step: 88, Loss: 7.279205322265625, Computation time: 0.16137266159057617 Step: 89, Loss: 7.090503215789795, Computation time: 0.16118311882019043 Step: 90, Loss: 7.453930377960205, Computation time: 0.16112112998962402 Step: 91, Loss: 7.0916032791137695, Computation time: 0.16120529174804688 Step: 92, Loss: 7.136333465576172, Computation time: 0.16111302375793457 Step: 93, Loss: 7.14594841003418, Computation time: 0.16206598281860352 Step: 94, Loss: 6.871617317199707, Computation time: 0.19673919677734375 Step: 95, Loss: 7.352797031402588, Computation time: 0.16296100616455078 Step: 96, Loss: 6.726633548736572, Computation time: 0.16156458854675293 Step: 
97, Loss: 6.9557905197143555, Computation time: 0.16250896453857422 Step: 98, Loss: 7.102599143981934, Computation time: 0.1620466709136963 Step: 99, Loss: 7.049860954284668, Computation time: 0.16131353378295898 Step: 100, Loss: 6.750383377075195, Computation time: 0.16186952590942383 Step: 101, Loss: 7.038060188293457, Computation time: 0.16181278228759766 Step: 102, Loss: 7.034355640411377, Computation time: 0.16237926483154297 Step: 103, Loss: 6.82716178894043, Computation time: 0.16185402870178223 Step: 104, Loss: 6.787952423095703, Computation time: 0.16224908828735352 Step: 105, Loss: 6.880023002624512, Computation time: 0.16243886947631836 Step: 106, Loss: 6.616780757904053, Computation time: 0.1620333194732666 Step: 107, Loss: 6.402748107910156, Computation time: 0.16213607788085938 Step: 108, Loss: 6.7207746505737305, Computation time: 0.16174864768981934 Step: 109, Loss: 5.961440563201904, Computation time: 0.16174983978271484 Step: 110, Loss: 6.086441993713379, Computation time: 0.16232728958129883 Step: 111, Loss: 5.67965030670166, Computation time: 0.1625194549560547 Step: 112, Loss: 5.820930480957031, Computation time: 0.1604611873626709 Step: 113, Loss: 6.119414329528809, Computation time: 0.16963505744934082 Step: 114, Loss: 6.096449851989746, Computation time: 0.16268205642700195 Step: 115, Loss: 5.988513469696045, Computation time: 0.1606006622314453 Step: 116, Loss: 6.118512153625488, Computation time: 0.16241216659545898 Step: 117, Loss: 5.241769790649414, Computation time: 0.16131067276000977 Step: 118, Loss: 6.166355609893799, Computation time: 0.16092491149902344 Step: 119, Loss: 6.842771530151367, Computation time: 0.1441802978515625 Step: 120, Loss: 6.375185489654541, Computation time: 0.16277027130126953 Step: 121, Loss: 5.80587100982666, Computation time: 0.1614992618560791 Step: 122, Loss: 5.733676433563232, Computation time: 0.16245174407958984 Step: 123, Loss: 5.918340682983398, Computation time: 0.16118144989013672 Step: 124, Loss: 5.5885467529296875, Computation time: 0.16121363639831543 Step: 125, Loss: 5.8133063316345215, Computation time: 0.16047954559326172 Step: 126, Loss: 5.448032379150391, Computation time: 0.1612851619720459 Step: 127, Loss: 5.919766902923584, Computation time: 0.16178321838378906 Step: 128, Loss: 5.811756610870361, Computation time: 0.16073966026306152 Step: 129, Loss: 5.2886857986450195, Computation time: 0.16239547729492188 Step: 130, Loss: 5.062446594238281, Computation time: 0.1623084545135498 Step: 131, Loss: 5.370600700378418, Computation time: 0.16302895545959473 Step: 132, Loss: 5.032846450805664, Computation time: 0.16185617446899414 Step: 133, Loss: 5.3186492919921875, Computation time: 0.16357207298278809 Step: 134, Loss: 4.988264083862305, Computation time: 0.16092920303344727 Step: 135, Loss: 5.364264488220215, Computation time: 0.16193294525146484 Step: 136, Loss: 5.038562774658203, Computation time: 0.16143488883972168 Step: 137, Loss: 5.195552825927734, Computation time: 0.16141676902770996 Step: 138, Loss: 4.877957344055176, Computation time: 0.16106271743774414 Step: 139, Loss: 4.971206188201904, Computation time: 0.15976953506469727 Step: 140, Loss: 4.850249767303467, Computation time: 0.16672515869140625 Step: 141, Loss: 5.053151607513428, Computation time: 0.16182613372802734 Step: 142, Loss: 4.553808212280273, Computation time: 0.16060352325439453 Step: 143, Loss: 4.6004109382629395, Computation time: 0.16107678413391113 Step: 144, Loss: 4.889383316040039, Computation time: 0.1608583927154541 Step: 145, Loss: 
4.736492156982422, Computation time: 0.16157317161560059 Step: 146, Loss: 4.708489894866943, Computation time: 0.16304683685302734 Step: 147, Loss: 4.679104804992676, Computation time: 0.1609785556793213 Step: 148, Loss: 4.689470291137695, Computation time: 0.16070127487182617 Step: 149, Loss: 4.528751850128174, Computation time: 0.16136622428894043 Step: 150, Loss: 4.48677396774292, Computation time: 0.1604769229888916 Step: 151, Loss: 4.637646675109863, Computation time: 0.16101288795471191 Step: 152, Loss: 4.762913703918457, Computation time: 0.16133403778076172 Step: 153, Loss: 4.44551944732666, Computation time: 0.1619107723236084 Step: 154, Loss: 4.5776472091674805, Computation time: 0.1616075038909912 Step: 155, Loss: 4.562440395355225, Computation time: 0.16150236129760742 Step: 156, Loss: 4.409887313842773, Computation time: 0.16173315048217773 Step: 157, Loss: 4.46767520904541, Computation time: 0.16112399101257324 Step: 158, Loss: 4.25125789642334, Computation time: 0.16138744354248047 Step: 159, Loss: 4.785336971282959, Computation time: 0.1468524932861328 Step: 160, Loss: 5.054254055023193, Computation time: 0.16128849983215332 Step: 161, Loss: 4.8799567222595215, Computation time: 0.1611628532409668 Step: 162, Loss: 4.688265800476074, Computation time: 0.16042160987854004 Step: 163, Loss: 4.51352596282959, Computation time: 0.1602628231048584 Step: 164, Loss: 4.331615447998047, Computation time: 0.1609640121459961 Step: 165, Loss: 4.137004852294922, Computation time: 0.16290068626403809 Step: 166, Loss: 4.654952049255371, Computation time: 0.16114187240600586 Step: 167, Loss: 4.4677629470825195, Computation time: 0.16231393814086914 Step: 168, Loss: 4.510952949523926, Computation time: 0.16344356536865234 Step: 169, Loss: 4.258943557739258, Computation time: 0.16016602516174316 Step: 170, Loss: 4.283701419830322, Computation time: 0.1614704132080078 Step: 171, Loss: 4.368310451507568, Computation time: 0.1617722511291504 Step: 172, Loss: 4.095067024230957, Computation time: 0.16355204582214355 Step: 173, Loss: 4.290921211242676, Computation time: 0.16144156455993652 Step: 174, Loss: 4.135052680969238, Computation time: 0.16065239906311035 Step: 175, Loss: 4.188730239868164, Computation time: 0.16092491149902344 Step: 176, Loss: 3.9966931343078613, Computation time: 0.16103458404541016 Step: 177, Loss: 4.127541542053223, Computation time: 0.16103053092956543 Step: 178, Loss: 4.2538557052612305, Computation time: 0.1615607738494873 Step: 179, Loss: 4.453568458557129, Computation time: 0.1603102684020996 Step: 180, Loss: 4.0408525466918945, Computation time: 0.16083049774169922 Step: 181, Loss: 4.516185760498047, Computation time: 0.1609797477722168 Step: 182, Loss: 4.250395774841309, Computation time: 0.1612706184387207 Step: 183, Loss: 4.046529769897461, Computation time: 0.16176581382751465 Step: 184, Loss: 4.198785781860352, Computation time: 0.16283583641052246 Step: 185, Loss: 3.9407706260681152, Computation time: 0.16234254837036133 Step: 186, Loss: 4.026411056518555, Computation time: 0.1624460220336914 Step: 187, Loss: 4.224530220031738, Computation time: 0.16072320938110352 Step: 188, Loss: 4.028736591339111, Computation time: 0.16074919700622559 Step: 189, Loss: 3.837322950363159, Computation time: 0.16036534309387207 Step: 190, Loss: 4.123674392700195, Computation time: 0.16191387176513672 Step: 191, Loss: 3.9622178077697754, Computation time: 0.16129708290100098 Step: 192, Loss: 3.969315528869629, Computation time: 0.16092944145202637 Step: 193, Loss: 
3.7825825214385986, Computation time: 0.16073131561279297 Step: 194, Loss: 3.9199018478393555, Computation time: 0.16074514389038086 Step: 195, Loss: 4.052471160888672, Computation time: 0.16427040100097656 Step: 196, Loss: 3.7691221237182617, Computation time: 0.16066265106201172 Step: 197, Loss: 3.937032699584961, Computation time: 0.16099143028259277 Step: 198, Loss: 4.042672634124756, Computation time: 0.16167831420898438 Step: 199, Loss: 3.7281570434570312, Computation time: 0.14007043838500977 Step: 200, Loss: 4.159261226654053, Computation time: 0.16143798828125 Step: 201, Loss: 4.408998489379883, Computation time: 0.16060853004455566 Step: 202, Loss: 4.1045427322387695, Computation time: 0.16067767143249512 Step: 203, Loss: 4.352884292602539, Computation time: 0.1615588665008545 Step: 204, Loss: 4.170437335968018, Computation time: 0.16057705879211426 Step: 205, Loss: 3.970756769180298, Computation time: 0.1603851318359375 Step: 206, Loss: 4.299739837646484, Computation time: 0.16051793098449707 Step: 207, Loss: 4.127477645874023, Computation time: 0.16169023513793945 Step: 208, Loss: 4.360357761383057, Computation time: 0.1614537239074707 Step: 209, Loss: 3.9281232357025146, Computation time: 0.16314291954040527 Step: 210, Loss: 3.9255576133728027, Computation time: 0.16143369674682617 Step: 211, Loss: 4.089841842651367, Computation time: 0.162628173828125 Step: 212, Loss: 4.131923675537109, Computation time: 0.1637284755706787 Step: 213, Loss: 4.047548294067383, Computation time: 0.16175484657287598 Step: 214, Loss: 4.078159809112549, Computation time: 0.1614534854888916 Step: 215, Loss: 4.092671871185303, Computation time: 0.16064238548278809 Step: 216, Loss: 4.069928169250488, Computation time: 0.16089081764221191 Step: 217, Loss: 3.7901744842529297, Computation time: 0.16229534149169922 Step: 218, Loss: 4.05171012878418, Computation time: 0.16241717338562012 Step: 219, Loss: 4.072657585144043, Computation time: 0.16231489181518555 Step: 220, Loss: 4.119385719299316, Computation time: 0.16376709938049316 Step: 221, Loss: 3.946767568588257, Computation time: 0.16153383255004883 Step: 222, Loss: 3.8579845428466797, Computation time: 0.16051745414733887 Step: 223, Loss: 3.955892324447632, Computation time: 0.16411495208740234 Step: 224, Loss: 4.090612411499023, Computation time: 0.16119980812072754 Step: 225, Loss: 3.871494770050049, Computation time: 0.1633768081665039 Step: 226, Loss: 4.001490116119385, Computation time: 0.1612398624420166 Step: 227, Loss: 3.856689453125, Computation time: 0.16136479377746582 Step: 228, Loss: 3.854506254196167, Computation time: 0.16175079345703125 Step: 229, Loss: 3.920146942138672, Computation time: 0.16027593612670898 Step: 230, Loss: 3.8486571311950684, Computation time: 0.16107869148254395 Step: 231, Loss: 4.150424003601074, Computation time: 0.161329984664917 Step: 232, Loss: 4.034335613250732, Computation time: 0.16145658493041992 Step: 233, Loss: 3.862642288208008, Computation time: 0.16074752807617188 Step: 234, Loss: 3.879786491394043, Computation time: 0.16097068786621094 Step: 235, Loss: 3.9150876998901367, Computation time: 0.1610715389251709 Step: 236, Loss: 3.6582045555114746, Computation time: 0.16137981414794922 Step: 237, Loss: 4.022642612457275, Computation time: 0.16101980209350586 Step: 238, Loss: 3.920273780822754, Computation time: 0.16168999671936035 Step: 239, Loss: 4.942720890045166, Computation time: 0.139939546585083 Step: 240, Loss: 3.820035457611084, Computation time: 0.16096997261047363 Step: 241, Loss: 
4.027595520019531, Computation time: 0.1608715057373047 Step: 242, Loss: 3.9767158031463623, Computation time: 0.16132664680480957 Step: 243, Loss: 3.927661895751953, Computation time: 0.16009283065795898 Step: 244, Loss: 4.054908275604248, Computation time: 0.16004633903503418 Step: 245, Loss: 4.072584629058838, Computation time: 0.1604931354522705 Step: 246, Loss: 4.165594100952148, Computation time: 0.16080093383789062 Step: 247, Loss: 3.9277215003967285, Computation time: 0.16055607795715332 Step: 248, Loss: 4.001946449279785, Computation time: 0.1610417366027832 Step: 249, Loss: 3.9720990657806396, Computation time: 0.16184639930725098