This tutorial explores a simplified version of a generative model, the Variational Autoencoder (VAE), trained on the scikit-learn digits dataset, and expands on what we learned in Getting started with JAX. Along the way, you'll learn more about how JAX's JIT compilation (jax.jit) actually works and what that means for debugging JAX programs, as we identify what can go wrong during model training.
If you are new to JAX for AI, check out the first tutorial, which explains how to build a simple neural network with Flax and Optax, and covers JAX's key features, including the NumPy-style interface with jax.numpy, JAX transformations for JIT compilation with jax.jit, automatic vectorization with jax.vmap, and automatic differentiation with jax.grad.
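As a quick refresher, here is a minimal sketch of those four features in action (not code from the earlier tutorial):

import jax
import jax.numpy as jnp

@jax.jit                      # just-in-time compilation
def sum_of_squares(x):
  return jnp.sum(x ** 2)      # NumPy-style interface

grads = jax.grad(sum_of_squares)(jnp.arange(3.0))     # automatic differentiation
batched = jax.vmap(sum_of_squares)(jnp.ones((4, 3)))  # automatic vectorization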
As before, this example uses the well-known, small and self-contained scikit-learn digits dataset:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import jax.numpy as jnp
digits = load_digits()
splits = train_test_split(digits.images, random_state=0)
images_train, images_test = map(jnp.asarray, splits)
print(f"{images_train.shape=}")
print(f"{images_test.shape=}")
images_train.shape=(1347, 8, 8)
images_test.shape=(450, 8, 8)
The dataset comprises 1797 images of hand-written digits, each represented by an 8x8 pixel grid, together with their corresponding labels. For visualization of this data, refer to loading the data in the previous tutorial.
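If you'd like a quick look at the data without switching tutorials, a minimal visualization sketch (assuming only matplotlib) is:

import matplotlib.pyplot as plt

# Show the first ten digits with their labels.
fig, axes = plt.subplots(1, 10, figsize=(8, 1),
                         subplot_kw={'xticks': [], 'yticks': []})
for i, ax in enumerate(axes):
  ax.imshow(digits.images[i], cmap='binary')
  ax.set_title(digits.target[i], fontsize=8)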
Previously, we learned how to use Flax NNX to create a simple feed-forward neural network trained for classification with an architecture that looked roughly like this:
import jax
import jax.numpy as jnp
from flax import nnx
class SimpleNN(nnx.Module):

  def __init__(self, n_features=64, n_hidden=100, n_targets=10, *, rngs: nnx.Rngs):
    self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
    self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)
    self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    x = nnx.selu(self.layer1(x))
    x = nnx.selu(self.layer2(x))
    return self.layer3(x)
This kind of network has one output per class, and the loss function is designed such that, once the model is trained, the output corresponding to the correct class returns the strongest signal, predicting the correct label in upwards of 95% of cases.
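Concretely, a prediction from such a classifier is just the index of the largest output. A sketch for shape-checking only; the model below is untrained:

clf = SimpleNN(rngs=nnx.Rngs(0))                # untrained, for illustration only
logits = clf(images_train.reshape(-1, 64))      # flatten 8x8 -> 64 features; shape (n, 10)
predicted_labels = jnp.argmax(logits, axis=-1)  # the strongest output wins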
To create a VAE with Flax NNX, we will use similar building blocks - subclassing flax.nnx.Module, stacking flax.nnx.Linear layers, and adding a rectified linear unit activation function (flax.nnx.relu). A VAE maps the input data into the parameters of a probability distribution (mean, std), and the output is a small probabilistic model representing the data.
Note that while the classic VAE is generally based on convolutional layers, this example uses linear layers for simplicity.
The sub-network that produces this probabilistic encoding is the Encoder:
class Encoder(nnx.Module):
  def __init__(self, input_size: int, intermediate_size: int, output_size: int,
               *, rngs: nnx.Rngs):
    self.rngs = rngs
    self.linear = nnx.Linear(input_size, intermediate_size, rngs=rngs)
    self.linear_mean = nnx.Linear(intermediate_size, output_size, rngs=rngs)
    self.linear_std = nnx.Linear(intermediate_size, output_size, rngs=rngs)

  def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
    x = self.linear(x)
    x = jax.nn.relu(x)

    mean = self.linear_mean(x)
    std = jnp.exp(self.linear_std(x))

    key = self.rngs.noise()
    z = mean + std * jax.random.normal(key, mean.shape)
    return z, mean, std
The idea here is that mean and std define a low-dimensional probability distribution over a latent space, and that z is a draw from this latent space that represents the training data.
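The sampling line in Encoder.__call__ is the standard reparameterization trick: rather than sampling z directly from $\mathcal{N}(\mu, \sigma^2)$, we sample fixed noise and shift and scale it,

$$z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),$$

which keeps z differentiable with respect to $\mu$ (mean) and $\sigma$ (std), so gradients can flow through the sampling step during training.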
To ensure that this latent distribution faithfully represents the actual data, define a Decoder that maps back to the input space as follows:
class Decoder(nnx.Module):
  def __init__(self, input_size: int, intermediate_size: int, output_size: int,
               *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(input_size, intermediate_size, rngs=rngs)
    self.linear2 = nnx.Linear(intermediate_size, output_size, rngs=rngs)

  def __call__(self, z: jax.Array) -> jax.Array:
    z = self.linear1(z)
    z = jax.nn.relu(z)
    logits = self.linear2(z)
    return logits
Now, define the VAE model (again by subclassing flax.nnx.Module) by combining Encoder and Decoder in a single network (VAE).
The model returns both the reconstructed image and the internal latent space model:
class VAE(nnx.Module):
  def __init__(
      self,
      image_shape: tuple[int, int],
      hidden_size: int,
      latent_size: int,
      *,
      rngs: nnx.Rngs
  ):
    self.image_shape = image_shape
    self.latent_size = latent_size
    input_size = image_shape[0] * image_shape[1]
    self.encoder = Encoder(input_size, hidden_size, latent_size, rngs=rngs)
    self.decoder = Decoder(latent_size, hidden_size, input_size, rngs=rngs)

  def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
    x = jax.vmap(jax.numpy.ravel)(x)  # flatten
    z, mean, std = self.encoder(x)
    logits = self.decoder(z)
    logits = jnp.reshape(logits, (-1, *self.image_shape))
    return logits, mean, std
Next, we need to define the loss function. There are two properties of the model that we want to ensure:

- the logits output faithfully reconstructs the input image; and
- mean and std faithfully represent the "true" latent distribution.

The VAE uses a loss function based on the evidence lower bound (ELBO) to quantify these two goals in a single loss value:
def vae_loss(model: VAE, x: jax.Array):
  logits, mean, std = model(x)

  kl_loss = jnp.mean(0.5 * jnp.mean(
      -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
  reconstruction_loss = jnp.mean(
      optax.sigmoid_binary_cross_entropy(logits, x)
  )
  return reconstruction_loss + 0.1 * kl_loss
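For reference, the kl_loss expression is the closed-form KL divergence between the encoder's diagonal Gaussian $\mathcal{N}(\mu, \sigma^2)$ and a standard normal prior,

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\big) = \frac{1}{2} \sum_j \left( \sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2 \right),$$

which matches the code term by term, except that the code takes a mean rather than a sum over the latent dimensions, which only rescales the term.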
Now all that's left is to:

- Instantiate the VAE model.
- Select optax.adam (the Adam optimizer in our example), and instantiate the optimizer with flax.nnx.Optimizer for setting the train step.
- Define the train_step using flax.nnx.value_and_grad to compute the gradients, and update the model's parameters using the optimizer.
- Use the flax.nnx.jit transformation decorator to trace the train_step function for just-in-time compilation.

import optax
model = VAE(
    image_shape=(8, 8),
    hidden_size=32,
    latent_size=8,
    rngs=nnx.Rngs(0, noise=1),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

@nnx.jit
def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array):
  loss, grads = nnx.value_and_grad(vae_loss)(model, x)
  optimizer.update(grads)
  return loss

for epoch in range(2001):
  loss = train_step(model, optimizer, images_train)

  if epoch % 500 == 0:
    print(f'Epoch {epoch} loss: {loss}')
Epoch 0 loss: 16745235.0
Epoch 500 loss: nan
Epoch 1000 loss: nan
Epoch 1500 loss: nan
Epoch 2000 loss: nan
Notice in the output that something has gone wrong - the loss value has become NaN after some number of iterations.
Despite our best efforts, the VAE model is producing NaNs. What can we do?
JAX offers a number of debugging approaches for situations like this, outlined in JAX's Debugging runtime values guide. (There is also the Introduction to debugging tutorial you may find useful.)
In this case, we can use the jax.debug_nans configuration to check where the NaN value is arising.
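As a toy illustration of the mechanism (separate from our model), jax.debug_nans re-runs the offending computation without optimizations and raises a FloatingPointError where the NaN first appears:

with jax.debug_nans(True):
  try:
    jax.jit(jnp.log)(-1.0)  # log of a negative number produces NaN
  except FloatingPointError as err:
    print(err)

Applying the same setting to our training loop: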
model = VAE(
    image_shape=(8, 8),
    hidden_size=32,
    latent_size=8,
    rngs=nnx.Rngs(0, noise=1),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

with jax.debug_nans(True):
  for epoch in range(2001):
    train_step(model, optimizer, images_train)
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in _check_special(name, dtype, buf)
    325   if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
--> 326     raise FloatingPointError(f"invalid value (nan) encountered in {name}")

FloatingPointError: invalid value (nan) encountered in pjit

During handling of the above exception, another exception occurred:

[... skipping JAX-internal dispatch frames and IPython/asyncio/tornado event-loop frames ...]

<ipython-input-8-0e49237a86d4> in <cell line: 10>()
     11 for epoch in range(2001):
---> 12   train_step(model, optimizer, images_train)

<ipython-input-7-b5b28eeeadf6> in train_step()
     13 def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array):
---> 14   loss, grads = nnx.value_and_grad(vae_loss)(model, x)
     15   optimizer.update(grads)

<ipython-input-6-305266f603a1> in vae_loss()
      1 def vae_loss(model: VAE, x: jax.Array):
----> 2   logits, mean, std = model(x)

<ipython-input-5-f5ec22b83e57> in __call__()
     17   x = jax.vmap(jax.numpy.ravel)(x)  # flatten
---> 18   z, mean, std = self.encoder(x)
     19   logits = self.decoder(z)

<ipython-input-3-05c99264f49e> in __call__()
     13   mean = self.linear_mean(x)
---> 14   std = jnp.exp(self.linear_std(x))

/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/nn/linear.py in __call__()
    380   )
--> 381   y = self.dot_general(
    382     inputs,

JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in jit(dot_general). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. If you see this error, consider opening a bug report at https://github.com/google/jax.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

The above exception was the direct cause of the following exception:

FloatingPointError: invalid value (nan) encountered in jit(dot_general). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. If you see this error, consider opening a bug report at https://github.com/google/jax.
The output here is complicated, because the function we're evaluating is complicated. The key to "deciphering" this traceback is to look for the places where the traceback touches our implementation.
In particular here, the output above indicates that NaN values arise during the gradient update:
<ipython-input-9-b5b28eeeadf6> in train_step()
     14   loss, grads = nnx.value_and_grad(vae_loss)(model, x)
---> 15   optimizer.update(grads)
     16   return loss
and further down from this, the details of the gradient update step where the NaN is arising:
/usr/local/lib/python3.10/dist-packages/optax/tree_utils/_tree_math.py in <lambda>()
    280     lambda g, t: (
--> 281         (1 - decay) * (g**order) + decay * t if g is not None else None
    282     ),
This suggests that the gradients are producing values that lead to NaN during the model update. Typically, this comes about when the gradient itself diverges.
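For context, the highlighted lambda is Adam's moment accumulator in optax: with order equal to 1 or 2, it computes the exponential moving averages

$$m_t = (1 - \beta_1)\, g_t + \beta_1\, m_{t-1}, \qquad v_t = (1 - \beta_2)\, g_t^2 + \beta_2\, v_{t-1},$$

so a single NaN or Inf entry in the gradient g immediately poisons both accumulators, and with them every subsequent update.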
A diverging gradient means that something may be amiss with the loss function. Previously, we saw loss=NaN at iteration 500. Let's print the progress up to this point:
model = VAE(
    image_shape=(8, 8),
    hidden_size=32,
    latent_size=8,
    rngs=nnx.Rngs(0, noise=1),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

for epoch in range(501):
  loss = train_step(model, optimizer, images_train)

  if epoch % 50 == 0:
    print(f'Epoch {epoch} loss: {loss}')
Epoch 0 loss: 16745235.0
Epoch 50 loss: 19.595727920532227
Epoch 100 loss: -13.440512657165527
Epoch 150 loss: -145.24871826171875
Epoch 200 loss: -683.0828247070312
Epoch 250 loss: -2291.444091796875
Epoch 300 loss: -6880.775390625
It looks like the loss value is decreasing toward negative infinity until the point where the values are no longer well-represented by floating point math.
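Concretely, float32 saturates to -inf just past its finite range (about $-3.4 \times 10^{38}$), and arithmetic on infinities then yields NaN. A quick illustration, separate from the tutorial's code:

x = jnp.array(-3e38, dtype=jnp.float32)  # near the float32 minimum
print(x * 10)           # -inf: the product overflows the float32 range
print(x * 10 - x * 10)  # nan: -inf minus -inf is undefined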
At this point, we may wish to inspect the values within the loss function itself to see where the diverging loss might be coming from.
In typical Python programs we can do this by inserting either a print statement or a breakpoint in the loss function. This may look something like this:
def vae_loss(model: VAE, x: jax.Array):
  logits, mean, std = model(x)

  kl_loss = jnp.mean(0.5 * jnp.mean(
      -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
  reconstruction_loss = jnp.mean(
      optax.sigmoid_binary_cross_entropy(logits, x)
  )

  print("kl loss", kl_loss)
  print("reconstruction loss", reconstruction_loss)

  return reconstruction_loss + 0.1 * kl_loss
model = VAE(
    image_shape=(8, 8),
    hidden_size=32,
    latent_size=8,
    rngs=nnx.Rngs(0, noise=1),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

train_step(model, optimizer, images_train)
kl loss Traced<ShapedArray(float32[])>with<JVPTrace(level=3/0)> with
primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=2/0)> with
pval = (ShapedArray(float32[]), None)
recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7b9be70b1b20>, in_tracers=(Traced<ShapedArray(float32[1347]):JaxprTrace(level=2/0)>,), out_tracer_refs=[<weakref at 0x7b9be6aecb30; to 'JaxprTracer' at 0x7b9be6aecae0>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[1347]. let
b:f32[] = reduce_sum[axes=(0,)] a
c:f32[] = div b 1347.0
in (c,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None,), 'out_layouts': (None,), 'resource_env': None, 'donated_invars': (False,), 'name': '_mean', 'keep_unused': False, 'inline': True}, effects=set(), source_info=<jax._src.source_info_util.SourceInfo object at 0x7b9be6ae38b0>, ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={})
reconstruction loss Traced<ShapedArray(float32[])>with<JVPTrace(level=3/0)> with
primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=2/0)> with
pval = (ShapedArray(float32[]), None)
recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7b9be70b2040>, in_tracers=(Traced<ShapedArray(float32[1347,8,8]):JaxprTrace(level=2/0)>,), out_tracer_refs=[<weakref at 0x7b9be6aed850; to 'JaxprTracer' at 0x7b9be6aed800>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[1347,8,8]. let
b:f32[] = reduce_sum[axes=(0, 1, 2)] a
c:f32[] = div b 86208.0
in (c,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None,), 'out_layouts': (None,), 'resource_env': None, 'donated_invars': (False,), 'name': '_mean', 'keep_unused': False, 'inline': True}, effects=set(), source_info=<jax._src.source_info_util.SourceInfo object at 0x7b9be6af89a0>, ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={})
Array(16745235., dtype=float32)
But here, rather than the values themselves, we're getting some kind of Traced object. You'll encounter this frequently when inspecting the progress of JAX programs: tracers are the mechanism JAX uses to implement transformations like jax.jit and jax.grad, and you can read more about them in JAX Key Concepts: Tracing.
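A tiny standalone example shows the same behavior: inside a jitted function, a Python print executes once at trace time and sees a tracer rather than data:

@jax.jit
def f(x):
  print("traced value:", x)  # runs at trace time; x is a tracer, not an array
  return x + 1

f(jnp.arange(3))  # prints something like: traced value: Traced<ShapedArray(int32[3])>...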
In this example, the workaround is to use another tool from the Debugging runtime values guide: namely jax.debug.print, which allows us to print runtime values even when they're traced:
def vae_loss(model: VAE, x: jax.Array):
  logits, mean, std = model(x)

  kl_loss = jnp.mean(0.5 * jnp.mean(
      -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
  reconstruction_loss = jnp.mean(
      optax.sigmoid_binary_cross_entropy(logits, x)
  )

  jax.debug.print("kl_loss: {}", kl_loss)
  jax.debug.print("reconstruction_loss: {}", reconstruction_loss)

  return reconstruction_loss + 0.1 * kl_loss

model = VAE(
    image_shape=(8, 8),
    hidden_size=32,
    latent_size=8,
    rngs=nnx.Rngs(0, noise=1),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

for i in range(5):
  train_step(model, optimizer, images_train)
kl_loss: 167451888.0
reconstruction_loss: 44.51668167114258
kl_loss: 21651530.0
reconstruction_loss: 6.270397186279297
kl_loss: 4448844.5
reconstruction_loss: -14.727174758911133
kl_loss: 1285240.625
Let's iterate a few hundred more times (we'll use the IPython %%capture magic to avoid printing all the output on the first several hundred iterations) and then do one more run to print these intermediate values:
%%capture
for i in range(300):
  train_step(model, optimizer, images_train)

loss = train_step(model, optimizer, images_train)
kl_loss: 2462.782470703125
reconstruction_loss: -8067.7255859375
The output above suggests that the large negative value is coming from the reconstruction_loss term. Let's return to this and inspect what it's actually doing:
reconstruction_loss = jnp.mean(
    optax.sigmoid_binary_cross_entropy(logits, x)
)
This is the binary cross entropy described at optax.sigmoid_binary_cross_entropy. Based on the Optax documentation, the first input should be a logit, and the second input is assumed to be a binary label (i.e. a 0 or 1) – but in the current implementation x comes from images_train, which is not a binary label!
print(images_train[0])
[[ 0.  3. 13. 16.  9.  0.  0.  0.]
 [ 0. 10. 15. 13. 15.  2.  0.  0.]
 [ 0. 15.  4.  4. 16.  1.  0.  0.]
 [ 0.  0.  0.  5. 16.  2.  0.  0.]
 [ 0.  0.  1. 14. 13.  0.  0.  0.]
 [ 0.  0. 10. 16.  5.  0.  0.  0.]
 [ 0.  4. 16. 13.  8. 10.  9.  1.]
 [ 0.  2. 16. 16. 14. 12.  9.  1.]]
This is likely the source of the issue: we forgot to normalize the input images to the range (0, 1)!
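We can check that non-binary labels are what drive the loss unboundedly negative: with labels outside [0, 1], optax.sigmoid_binary_cross_entropy can return arbitrarily large negative values. A quick demonstration, using a pixel-like value from our data:

logits = jnp.array([10.0])
print(optax.sigmoid_binary_cross_entropy(logits, jnp.array([1.0])))   # ~0: a valid binary label
print(optax.sigmoid_binary_cross_entropy(logits, jnp.array([16.0])))  # large and negative: invalid label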
Let's fix this by binarizing the inputs, and then run the training loop again (redefining the loss function to remove the debug statements):
images_normed = (digits.images / 16) > 0.5
splits = train_test_split(images_normed, random_state=0)
images_train, images_test = map(jnp.asarray, splits)

def vae_loss(model: VAE, x: jax.Array):
  logits, mean, std = model(x)

  kl_loss = jnp.mean(0.5 * jnp.mean(
      -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))
  reconstruction_loss = jnp.mean(
      optax.sigmoid_binary_cross_entropy(logits, x)
  )
  return reconstruction_loss + 0.1 * kl_loss

model = VAE(
    image_shape=(8, 8),
    hidden_size=32,
    latent_size=8,
    rngs=nnx.Rngs(0, noise=1),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

for epoch in range(2001):
  loss = train_step(model, optimizer, images_train)

  if epoch % 500 == 0:
    print(f'Epoch {epoch} loss: {loss}')
Epoch 0 loss: 0.7710005640983582
Epoch 500 loss: 0.3110124468803406
Epoch 1000 loss: 0.2782602906227112
Epoch 1500 loss: 0.26861754059791565
Epoch 2000 loss: 0.26275068521499634
The loss values are now "behaving" without showing NaNs.
We have successfully debugged the initial NaN problem, which was not in the VAE model but rather in the input data.
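A cheap guard against this class of bug in future experiments is to assert the assumption the loss function makes about its inputs (an illustrative sketch, not part of the original tutorial):

# Fail fast if inputs are outside the range sigmoid cross entropy assumes.
assert images_train.min() >= 0 and images_train.max() <= 1, (
    "inputs must lie in [0, 1] for sigmoid_binary_cross_entropy")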
Now that we have a trained VAE model, let's explore what it can be used for.
First, let's pass the test data through the model to compute the latent representation and the reconstruction for each input.
Pass the logits through a sigmoid function to recover predicted images in the input space:
logits, mean, std = model(images_test)
images_pred = jax.nn.sigmoid(logits)
Let's visualize several of these inputs and outputs:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(2, 10, figsize=(6, 1.5),
                       subplot_kw={'xticks': [], 'yticks': []},
                       gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(10):
  ax[0, i].imshow(images_test[i], cmap='binary', interpolation='gaussian')
  ax[1, i].imshow(images_pred[i], cmap='binary', interpolation='gaussian')
The top row here are the input images, and the bottom row are what the model "thinks" these images look like, given their latent space representation. There's not perfect fidelity, but the essential features are recovered.
We can go a step further and generate a set of new images from scratch by sampling randomly from the latent space. Let's generate 36 new digits this way:
import numpy as np
# generate new images by sampling the latent space
z = np.random.normal(scale=1.5, size=(36, model.latent_size))
logits = model.decoder(z).reshape(-1, 8, 8)
images_gen = nnx.sigmoid(logits)

fig, ax = plt.subplots(6, 6, figsize=(4, 4),
                       subplot_kw={'xticks': [], 'yticks': []},
                       gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(36):
  ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian')
Another possibility is to use the model to interpolate between two entries in the training set through the latent space.
Here's an interpolation between a digit 9 and a digit 3:
z, _, _ = model.encoder(images_train.reshape(-1, 64))
zrange = jnp.linspace(z[2], z[9], 10)

logits = model.decoder(zrange).reshape(-1, 8, 8)
images_gen = nnx.sigmoid(logits)

fig, ax = plt.subplots(1, 10, figsize=(8, 1),
                       subplot_kw={'xticks': [], 'yticks': []},
                       gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(10):
  ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian')
This tutorial offered an example of defining and training a generative model - a simplified VAE - and demonstrated approaches to debugging JAX programs using the jax.debug_nans configuration and the jax.debug.print function.
You can learn more about debugging on the JAX documentation site in Debugging runtime values and Introduction to debugging.