As of orbax-checkpoint-0.5.0, several new APIs have been introduced at multiple levels. The most significant change is to how users interact with CheckpointManager. This page shows a side-by-side comparison of the old and new APIs.
The legacy APIs are deprecated and will stop working after May 1st, 2024. Please make sure you are using the new style by then.
CheckpointManager.save(...) is now async by default. If your code depends on a previous save having completed, call wait_until_finished() first. Alternatively, async saving can be disabled via the CheckpointManagerOptions.enable_async_checkpointing option.
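A small conceptual model shows why wait_until_finished() matters. The sketch below is plain Python, not Orbax code: ToyAsyncManager and its methods are hypothetical. Its save() hands the write to a background thread and returns immediately, so a read issued right after save() may not see the new step until wait_until_finished() has joined the writer.

```python
import threading
import time


class ToyAsyncManager:
    """Hypothetical stand-in for an async checkpoint manager (not Orbax code)."""

    def __init__(self):
        self._storage = {}   # step -> saved object, playing the role of "disk"
        self._thread = None  # background writer for the most recent save

    def save(self, step, obj, delay=0.05):
        # In this toy, finish any in-flight save first so writes stay ordered.
        self.wait_until_finished()
        snapshot = dict(obj)  # copy now so later mutation of obj is not saved

        def _write():
            time.sleep(delay)  # simulate slow I/O
            self._storage[step] = snapshot

        self._thread = threading.Thread(target=_write)
        self._thread.start()  # returns immediately: the save is asynchronous

    def wait_until_finished(self):
        # Block until the background write (if any) has completed.
        if self._thread is not None:
            self._thread.join()
            self._thread = None

    def all_steps(self):
        return sorted(self._storage)


mngr = ToyAsyncManager()
mngr.save(0, {'a': 0})
print(mngr.all_steps())  # may still be empty: the write can be in flight
mngr.wait_until_finished()
print(mngr.all_steps())  # [0] once the background write has completed
```

Turning async off, as enable_async_checkpointing allows in Orbax, would correspond to doing the write inline in save(), trading training-loop latency for simpler ordering.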
For further information on how to use the new API, see the introductory tutorial and the API Overview.
import orbax.checkpoint as ocp
# Dummy PyTrees for simplicity.
# In reality, this would be a tree of np.ndarray or jax.Array.
pytree = {'a': 0}
# In reality, this would be a tree of jax.ShapeDtypeStruct (metadata
# for restoration).
abstract_pytree = {'a': 0}
extra_metadata = {'version': 1.0}
# Legacy (deprecated) API.
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty('/tmp/ckpt1/'),
    ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),
    options=options,
)
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)
mngr.save(0, pytree)
mngr.wait_until_finished()
mngr.restore(
    0,
    items=abstract_pytree,
    restore_kwargs={'restore_args': restore_args},
)
# New API.
options = ocp.CheckpointManagerOptions()
with ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty('/tmp/ckpt2/'),
    options=options,
) as mngr:
    mngr.save(0, args=ocp.args.StandardSave(pytree))
    mngr.wait_until_finished()
    # After providing `args` during an initial `save` or `restore` call, the
    # `CheckpointManager` instance records the type so that you do not need to
    # specify it again. If the `CheckpointManager` instance has not previously
    # been given an `ocp.args.CheckpointArgs` instance for a particular item,
    # that item cannot be restored without specifying the argument at restore
    # time.
    # In many cases, you can restore exactly as saved without specifying
    # additional arguments.
    mngr.restore(0)
    # To customize properties like sharding or dtype, provide the abstract
    # target PyTree; its properties are used to set the properties of the
    # restored arrays.
    mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
Important notes:
- Pass args=... to both save and restore! Otherwise you will get the legacy API. This is only necessary until the legacy API is removed.
- The value of args is a subclass of CheckpointArgs, found in the ocp.args module. These classes communicate the logic you wish to use to save and restore your object. For a typical PyTree consisting of arrays, use StandardSave/StandardRestore.
Next, let's explore scenarios in which restore() and item_metadata() calls raise errors because no CheckpointHandler has been specified for an item name.
CheckpointManager(..., item_handlers=...) can be used to resolve these scenarios.
# Unmapped CheckpointHandlers on a new CheckpointManager instance.
new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
new_mngr.restore(0) # Raises error due to unmapped CheckpointHandler
new_mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
new_mngr.close()
# item_handlers can be used as an alternative to restore(..., args=...).
with ocp.CheckpointManager(
'/tmp/ckpt2/',
options=options,
item_handlers=ocp.StandardCheckpointHandler()
) as new_mngr:
    print(new_mngr.restore(0))
NOTE:
Unlike restore(..., args=...), CheckpointManager.item_metadata(step) does not accept an args argument.
So item_handlers is the only way to specify handlers for item_metadata(step) calls.
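The handler-resolution rules described in this section can be modeled in a few lines of plain Python. This is a hypothetical sketch, not Orbax's implementation: ToyManager, ToyHandler, ToySave, ToyRestore and the module-level DISK dict are invented for illustration, and the model assumes a single item per manager. It mirrors three behaviors stated above: the handler implied by the args type is recorded on first use, a fresh manager given neither args nor an item handler cannot restore, and item_metadata() has no args parameter, so it relies on a handler supplied up front or recorded earlier.

```python
class ToySave:
    """Hypothetical args type standing in for a save-args class."""

    def __init__(self, item):
        self.item = item


class ToyRestore:
    """Hypothetical args type standing in for a restore-args class."""

    def __init__(self, item=None):
        self.item = item  # restore target; ignored by this toy


class ToyHandler:
    """Hypothetical handler that 'serializes' a dict by copying it."""

    def save(self, item):
        return dict(item)

    def restore(self, stored):
        return dict(stored)

    def metadata(self, stored):
        return {k: type(v).__name__ for k, v in stored.items()}


# Shared "disk" so a second manager can see the first one's checkpoints.
DISK = {}

# Maps each args type to the handler class that understands it.
ARGS_TO_HANDLER = {ToySave: ToyHandler, ToyRestore: ToyHandler}


class ToyManager:
    def __init__(self, item_handler=None):
        # item_handler plays the role of CheckpointManager(..., item_handlers=...).
        self._handler = item_handler

    def _resolve(self, args):
        if args is not None and self._handler is None:
            # Record the handler implied by the args type for later calls.
            self._handler = ARGS_TO_HANDLER[type(args)]()
        if self._handler is None:
            raise ValueError('No handler mapped: pass args=... or item_handler=...')
        return self._handler

    def save(self, step, args):
        DISK[step] = self._resolve(args).save(args.item)

    def restore(self, step, args=None):
        return self._resolve(args).restore(DISK[step])

    def item_metadata(self, step):
        # No args parameter: only a recorded or constructor-provided handler works.
        return self._resolve(None).metadata(DISK[step])


mngr = ToyManager()
mngr.save(0, ToySave({'a': 0}))
print(mngr.restore(0))  # {'a': 0}: handler was recorded during save

fresh = ToyManager()
try:
    fresh.restore(0)
except ValueError as e:
    print('error:', e)
print(fresh.restore(0, ToyRestore()))  # {'a': 0}: works once args are given
print(ToyManager(item_handler=ToyHandler()).item_metadata(0))  # {'a': 'int'}
```

Because item_metadata() takes no args in this model, a fresh manager can only answer metadata queries if a handler was supplied at construction time, which is the role item_handlers plays.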
# item_handlers becomes even more critical with item_metadata() calls.
new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandler
new_mngr = ocp.CheckpointManager(
'/tmp/ckpt2/',
options=options,
item_handlers=ocp.StandardCheckpointHandler(),
)
new_mngr.item_metadata(0)
new_mngr.close()
# Legacy (deprecated) API.
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty('/tmp/ckpt3/'),
    {
        'state': ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),
        'extra_metadata': ocp.Checkpointer(ocp.JsonCheckpointHandler()),
    },
    options=options,
)
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)
mngr.save(0, {'state': pytree, 'extra_metadata': extra_metadata})
mngr.wait_until_finished()
mngr.restore(
    0,
    items={'state': abstract_pytree, 'extra_metadata': None},
    restore_kwargs={
        'state': {'restore_args': restore_args},
        'extra_metadata': None,
    },
)
# New API.
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty('/tmp/ckpt4/'),
    # `item_names` defines an up-front contract about what items the
    # CheckpointManager will be dealing with.
    item_names=('state', 'extra_metadata'),
    options=options,
)
mngr.save(
    0,
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(pytree),
        extra_metadata=ocp.args.JsonSave(extra_metadata),
    ),
)
mngr.wait_until_finished()
# Restore as saved.
mngr.restore(0)
# Restore with customization. Restore a subset of items.
mngr.restore(
    0,
    args=ocp.args.Composite(state=ocp.args.StandardRestore(abstract_pytree)),
)
mngr.close()
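A Composite can be thought of as a mapping from item name to per-item arguments. The plain-Python sketch below is hypothetical: Composite here is a dict subclass standing in for ocp.args.Composite, ToyMultiItemManager is invented, and for brevity each value is the item's payload rather than a per-item args object. It models item_names as an up-front contract by rejecting undeclared names, and restores either every item as saved or only the named subset.

```python
import copy


class Composite(dict):
    """Hypothetical stand-in for a composite args object: item name -> value."""

    def __init__(self, **items):
        super().__init__(items)


class ToyMultiItemManager:
    def __init__(self, item_names):
        # item_names is the up-front contract: only these items are accepted.
        self._item_names = tuple(item_names)
        self._disk = {}  # step -> {item name: saved payload}

    def _check(self, composite):
        unknown = set(composite) - set(self._item_names)
        if unknown:
            raise KeyError(f'Items not declared in item_names: {sorted(unknown)}')

    def save(self, step, args):
        self._check(args)
        # Store an independent copy of every item's payload.
        self._disk[step] = {name: copy.deepcopy(p) for name, p in args.items()}

    def restore(self, step, args=None):
        saved = self._disk[step]
        if args is None:
            return dict(saved)  # restore every item exactly as saved
        self._check(args)
        # Only the names select items; values are ignored by this toy.
        return {name: saved[name] for name in args}


mngr = ToyMultiItemManager(item_names=('state', 'extra_metadata'))
mngr.save(0, Composite(state={'a': 0}, extra_metadata={'version': 1.0}))
print(mngr.restore(0))  # {'state': {'a': 0}, 'extra_metadata': {'version': 1.0}}
print(mngr.restore(0, Composite(state=None)))  # {'state': {'a': 0}}
```

In the real API the per-item values (for example a StandardRestore wrapping an abstract PyTree) also carry restore customization, which this toy deliberately leaves out.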
As with the single-item use case described above, let's explore scenarios in which restore() and item_metadata() calls raise errors because no CheckpointHandler has been specified for an item name.
CheckpointManager(..., item_handlers=...) can be used to resolve these scenarios.
# Unmapped CheckpointHandlers on a new CheckpointManager instance.
new_mngr = ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_names=('state', 'extra_metadata'),
)
new_mngr.restore(0)  # Raises error due to unmapped CheckpointHandlers
new_mngr.restore(
    0,
    args=ocp.args.Composite(
        state=ocp.args.StandardRestore(abstract_pytree),
        extra_metadata=ocp.args.JsonRestore(),
    ),
)
new_mngr.close()
# item_handlers can be used as an alternative to restore(..., args=...).
with ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_handlers={
        'state': ocp.StandardCheckpointHandler(),
        'extra_metadata': ocp.JsonCheckpointHandler(),
    },
) as new_mngr:
    print(new_mngr.restore(0))
NOTE:
Unlike restore(..., args=...), CheckpointManager.item_metadata(step) does not accept an args argument.
So item_handlers is the only way to specify handlers for item_metadata(step) calls.
# item_handlers becomes even more critical with item_metadata() calls.
with ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_names=('state', 'extra_metadata'),
) as new_mngr:
    new_mngr.item_metadata(0)  # Raises error due to unmapped CheckpointHandlers
with ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_handlers={
        'state': ocp.StandardCheckpointHandler(),
        'extra_metadata': ocp.JsonCheckpointHandler(),
    },
) as new_mngr:
    print(new_mngr.item_metadata(0))