#!/usr/bin/env python
# coding: utf-8

# # Computing second-order sensitivities
#
# This example demonstrates how to compute the Hessian of a differential
# equation solve.
#
# This example is available as a Jupyter notebook
# [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/hessian.ipynb).

import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5


def vector_field(t, y, args):
    """Lotka–Volterra predator–prey dynamics.

    Args:
        t: Current time. Unused: the system does not depend on time explicitly.
        y: Tuple ``(prey, predator)`` of current population values.
        args: Tuple ``(α, β, γ, δ)``. These are the prey growth rate, the
            predation rate, the predator death rate, and the rate at which
            predation grows the predator population.

    Returns:
        Tuple ``(d_prey, d_predator)`` of time derivatives, with the same
        structure as ``y``.
    """
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return d_prey, d_predator


@jax.jit
@jax.hessian
def run(y0):
    """Hessian of the final prey population with respect to the initial state.

    The undecorated body maps the initial condition ``y0 = (prey0, predator0)``
    to the prey population at ``t1``. ``jax.hessian`` turns that into a
    function returning all second derivatives as a nested tuple,
    ``H[i][j] = d^2 prey(t1) / (d y0[i] d y0[j])``. ``jax.jit`` then compiles
    the resulting function.
    """
    term = ODETerm(vector_field)
    # `scan_kind="bounded"` selects a loop-over-stages implementation that
    # supports higher-order autodiff. The default is optimised for first-order
    # reverse mode only, whereas `jax.hessian` applies forward mode on top of
    # reverse mode.
    solver = Tsit5(scan_kind="bounded")
    t0 = 0
    t1 = 140
    dt0 = 0.1
    args = (0.1, 0.02, 0.4, 0.02)  # (α, β, γ, δ)
    sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args)
    # `sol.ys` mirrors the (prey, predator) structure of `y0` with a single
    # saved time, so each component unpacks to exactly one value.
    ((prey,), _) = sol.ys
    return prey


y0 = (jnp.array(10.0), jnp.array(10.0))

if __name__ == "__main__":
    # A bare `run(y0)` is only displayed inside a notebook; print it so the
    # result is visible when this file is run as a script.
    print(run(y0))

# Note the use of the `scan_kind` argument to `Tsit5`. By default, Diffrax
# internally uses constructs that are optimised specifically for first-order
# reverse-mode autodifferentiation. This argument switches to a different
# implementation that is compatible with higher-order autodiff. (In this case,
# it applies to the loop over stages in the Runge–Kutta solver.)
#
# Similarly, if using `saveat=SaveAt(ts=...)` (or a handful of other esoteric
# cases), then you will need to pass `adjoint=DirectAdjoint()`. (In this case,
# it applies to the loop over saving output.)