A PyTree is the most common way of representing a training state in JAX. While Orbax is designed to be as generic as possible, and provides customization options for all manner of checkpointable objects, PyTrees naturally have pride of place. Furthermore, the standard object used to represent large, sharded arrays is the jax.Array. This, too, has extensive first-class support.
CheckpointHandler Support

There are essentially two options provided by Orbax for working with PyTrees.

- StandardCheckpointHandler: applicable in the majority of use cases.
- PyTreeCheckpointHandler: useful when advanced customization is desired.

import numpy as np
import orbax.checkpoint as ocp
import jax
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec(
        'model',
    ),
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
state = {
    'a': np.arange(16),
    'b': np.ones(16),
}
state = jax.tree_util.tree_map(create_sharded_array, state)
abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)
Let's use StandardCheckpointHandler to work with PyTrees of jax.Array.
path = ocp.test_utils.erase_and_create_empty('/tmp/basic/')
# Make sure to use async for improved performance!
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(path / '1', args=ocp.args.StandardSave(state))
ckptr.wait_until_finished()
We specify the abstract_state in order to restore with the given dtypes, shapes, and shardings for each leaf.
restored = ckptr.restore(path / '1', args=ocp.args.StandardRestore(abstract_state))
restored
restored['a'].sharding
You can do exactly the same with a "concrete" target rather than an "abstract" target. However, this requires fully initializing the target train state before restoring from the checkpoint, which is inefficient. It is better practice to initialize only the metadata (either by manually creating jax.ShapeDtypeStructs or by using jax.eval_shape).
ckptr.restore(path / '1', args=ocp.args.StandardRestore(state))
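The two ways of building an abstract target mentioned above can be sketched as follows. This is a minimal sketch, not part of the original example: `init_state` is a hypothetical initializer standing in for real model setup, and `example_sharding` simply mirrors the one-axis mesh used earlier. The result of `jax.eval_shape` may not carry the sharding you want, so the sketch attaches one explicitly.

```python
import jax
import jax.numpy as jnp


def init_state():
  # Hypothetical initializer standing in for real model/optimizer setup.
  return {
      'a': jnp.arange(16, dtype=jnp.int32),
      'b': jnp.ones(16, dtype=jnp.float32),
  }


# Option 1: trace the initializer abstractly; only shapes and dtypes are computed.
abstract_from_eval = jax.eval_shape(init_state)

# Option 2: spell out the metadata by hand.
abstract_manual = {
    'a': jax.ShapeDtypeStruct(shape=(16,), dtype=jnp.int32),
    'b': jax.ShapeDtypeStruct(shape=(16,), dtype=jnp.float32),
}

# Attach an explicit sharding to every leaf (assumed layout, for illustration).
example_sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec('model'),
)
abstract_with_sharding = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=example_sharding),
    abstract_from_eval,
)
```

Either tree could then be passed to `ocp.args.StandardRestore` in place of the concrete `state`.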
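The dtype of the restored arrays can be changed by specifying a different dtype in the abstract target.
def set_restore_dtype(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  # Build a new struct instead of mutating the leaf, so abstract_state is unchanged.
  return jax.ShapeDtypeStruct(x.shape, np.int16, sharding=x.sharding)

cast_dtype_abstract_state = jax.tree_util.tree_map(
    set_restore_dtype, abstract_state)
ckptr.restore(
    path / '1',
    args=ocp.args.StandardRestore(cast_dtype_abstract_state),
)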
different_shape_abstract_state = {
    'a': jax.ShapeDtypeStruct(
        shape=(8,),
        dtype=abstract_state['a'].dtype,
        sharding=abstract_state['a'].sharding
    ),
    'b': jax.ShapeDtypeStruct(
        shape=(32,),
        dtype=abstract_state['b'].dtype,
        sharding=abstract_state['b'].sharding
    ),
}
Ordinarily, specifying a target array with a different shape than in the checkpoint results in an error.
try:
  ckptr.restore(
      path / '1',
      args=ocp.args.StandardRestore(different_shape_abstract_state),
  )
except Exception as e:
  # Print the error instead of halting, so the shape mismatch is visible.
  print(e)
We can pad or truncate arrays as they are loaded by specifying strict=False.
ckptr.restore(
    path / '1',
    args=ocp.args.StandardRestore(different_shape_abstract_state, strict=False),
)
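To make the pad-or-truncate behavior concrete, here is a NumPy-only sketch of the idea. It is not Orbax's implementation: `pad_or_truncate` is a hypothetical helper, and zero-filling the padded region is an assumption that the text above does not state.

```python
import numpy as np


def pad_or_truncate(arr: np.ndarray, target_shape: tuple) -> np.ndarray:
  """Hypothetical helper: fit `arr` to `target_shape`, dimension by dimension."""
  # Truncate any dimension that is longer than the target.
  slices = tuple(slice(0, min(s, t)) for s, t in zip(arr.shape, target_shape))
  truncated = arr[slices]
  # Pad any dimension that is shorter than the target (zeros are an assumption).
  pad_width = [(0, t - s) for s, t in zip(truncated.shape, target_shape)]
  return np.pad(truncated, pad_width)


saved_a = np.arange(16)
saved_b = np.ones(16)
fitted_a = pad_or_truncate(saved_a, (8,))   # Keeps the first 8 elements.
fitted_b = pad_or_truncate(saved_b, (32,))  # Appends 16 padding elements.
```

Because truncation silently discards data, it is worth checking that a shape mismatch is intended before setting strict=False.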
NOTE: Sharding is often a particularly sharp edge.
Sharding commonly needs to be changed when loading a checkpoint saved on one topology onto a different topology.
If changing topologies, you MUST specify sharding when restoring.
Unless you are loading on the exact same topology, Orbax does not make any decisions about shardings on your behalf. If the topology is exactly the same, however, you can omit the sharding when loading, as demonstrated below:
restored = ckptr.restore(path / '1')
restored['a'].sharding
In the example below, we alter the sharding while loading.
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('x',)),
    jax.sharding.PartitionSpec(),
)
def set_sharding(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  # Build a new struct instead of mutating the leaf, so abstract_state is unchanged.
  return jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding)

change_sharding_abstract_state = jax.tree_util.tree_map(
    set_sharding, abstract_state)
restored = ckptr.restore(
    path / '1',
    args=ocp.args.StandardRestore(change_sharding_abstract_state),
)
restored['a'].sharding
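The example above applies one sharding to every leaf, but each leaf of the abstract target carries its own sharding. The sketch below builds a target with a different layout per leaf. The layout in `partition_specs` and the dtypes in `shapes_and_dtypes` are illustrative assumptions, not taken from the checkpoint, and the sketch assumes the number of devices divides 16.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), ('x',))

# Hypothetical per-leaf layout: shard 'a' along 'x', fully replicate 'b'.
partition_specs = {
    'a': PartitionSpec('x'),
    'b': PartitionSpec(),
}
# Assumed shapes and dtypes for the two leaves.
shapes_and_dtypes = {
    'a': ((16,), np.int32),
    'b': ((16,), np.float32),
}

per_leaf_abstract_state = {
    name: jax.ShapeDtypeStruct(
        shape, dtype, sharding=NamedSharding(mesh, partition_specs[name])
    )
    for name, (shape, dtype) in shapes_and_dtypes.items()
}
```

A tree like `per_leaf_abstract_state` would be passed to `ocp.args.StandardRestore` in the same way as `change_sharding_abstract_state`.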
There are some advanced options that StandardCheckpointHandler does not provide. Additional options can be specified using PyTreeCheckpointHandler instead.
For example, PyTreeCheckpointHandler can be used to customize the on-disk type used to save individual arrays. First, let's save and restore as normal.
path = ocp.test_utils.erase_and_create_empty('/tmp/advanced/')
# Make sure to use async for improved performance!
ckptr = ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler())
ckptr.save(path / '1', args=ocp.args.PyTreeSave(state))
ckptr.wait_until_finished()
restored = ckptr.restore(path / '1')
restored['a'].dtype
restored['b'].dtype
Now, let's set the dtype of the array when saving.
ckptr.save(
    path / '2',
    args=ocp.args.PyTreeSave(
        state,
        save_args={
            # We must set one ocp.SaveArgs per leaf.
            'a': ocp.SaveArgs(dtype=np.dtype(np.int16)),
            'b': ocp.SaveArgs()
        }
    ),
)
ckptr.wait_until_finished()
restored = ckptr.restore(path / '2')
restored['a'].dtype
restored['b'].dtype
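Casting at save time changes what is stored on disk, so values that do not fit the narrower dtype cannot be recovered on restore. Before choosing a dtype for ocp.SaveArgs, a round-trip check along these lines may help. `is_lossless_cast` is a hypothetical helper written for this sketch and is not part of Orbax.

```python
import numpy as np


def is_lossless_cast(arr: np.ndarray, dtype) -> bool:
  """Returns True if casting to `dtype` and back reproduces `arr` exactly."""
  round_tripped = arr.astype(dtype).astype(arr.dtype)
  return bool(np.array_equal(round_tripped, arr))


small_ints = np.arange(16)       # Every value fits in int16.
large_ints = np.array([70_000])  # Exceeds the int16 range.
fractions = np.array([0.5])      # Not representable as an integer.

fits_small = is_lossless_cast(small_ints, np.int16)
fits_large = is_lossless_cast(large_ints, np.int16)
fits_fraction = is_lossless_cast(fractions, np.int16)
```

In the example above, 'a' holds the integers 0 through 15, so storing it as int16 loses nothing.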
Options similar to the above are available, where we can customize shape, dtype, and sharding when restoring.
ckptr.restore(
    path / '2',
    args=ocp.args.PyTreeRestore(
        restore_args={
            # RestoreArgs is the parent class for ArrayRestoreArgs.
            # We must set one RestoreArgs per leaf.
            'a': ocp.RestoreArgs(restore_type=np.ndarray),
            'b': ocp.ArrayRestoreArgs(dtype=np.dtype(np.int16), sharding=sharding)
        }
    ),
)
Note that "a" was restored as np.ndarray rather than jax.Array.
PyTreeCheckpointHandler also provides options to perform transformations when restoring. This is useful when your target tree has a different structure than your checkpoint tree. This allows you to avoid loading some keys or rename other keys. Full details are available at the Transformations page.
ckptr.restore(
    path / '2',
    args=ocp.args.PyTreeRestore(
        # `item` serves as a guide to what the result tree structure should look
        # like.
        item={
            # Value doesn't really matter here, as long as it's not None.
            'c': ...,
            # Can add in extra keys.
            'd': np.arange(8)
        },
        # `restore_args` must be relative to the result tree, not the
        # checkpoint.
        restore_args={
            'c': ocp.RestoreArgs(restore_type=np.ndarray),
        },
        transforms={
            'c': ocp.Transform(original_key='a')
        },
    ),
)
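As a mental model for the call above, the renaming can be sketched on plain dictionaries. This is only a conceptual illustration and not how Orbax implements transformations: `sketch_rename_restore` is a hypothetical function, and the rule that keys missing from the checkpoint keep their value from `item` is an assumption inferred from the "Can add in extra keys" comment.

```python
import numpy as np


def sketch_rename_restore(checkpoint: dict, item: dict, original_keys: dict) -> dict:
  """Hypothetical model of a rename transform on flat dictionaries."""
  result = {}
  for key, placeholder in item.items():
    # Look up the checkpoint key this result key should be read from.
    source_key = original_keys.get(key, key)
    if source_key in checkpoint:
      result[key] = checkpoint[source_key]
    else:
      # Assumed behavior: keys absent from the checkpoint keep the item value.
      result[key] = placeholder
  return result


checkpoint_tree = {'a': np.arange(16), 'b': np.ones(16)}
result_tree = sketch_rename_restore(
    checkpoint_tree,
    item={'c': ..., 'd': np.arange(8)},
    original_keys={'c': 'a'},
)
```

In this model 'b' is never read because it does not appear in `item`, which is how a transformation can skip loading part of a checkpoint.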