"""Side-by-side examples of the legacy and `args`-based CheckpointManager APIs.

The walkthrough covers:

* a single-item checkpoint saved/restored with the legacy API (a
  `Checkpointer` plus `items=` / `restore_kwargs=`) and with the `args`-based
  API (`ocp.args.StandardSave` / `ocp.args.StandardRestore`);
* a multi-item checkpoint saved/restored with the legacy API (a dict of
  `Checkpointer`s) and with the `args`-based API (`item_names` plus
  `ocp.args.Composite`);
* reopening a checkpoint directory with a fresh `CheckpointManager` that has no
  `CheckpointHandler` mapped for an item, and resolving that with either
  `args=` at restore time or `item_handlers=` at construction time.

Checkpoints are written under /tmp/ckpt1/ ... /tmp/ckpt4/; each directory is
created via `ocp.test_utils.erase_and_create_empty` before first use.
"""

import orbax.checkpoint as ocp

# Dummy PyTrees for simplicity.
# In reality, this would be a tree of np.ndarray or jax.Array.
pytree = {'a': 0}
# In reality, this would be a tree of jax.ShapeDtypeStruct (metadata for
# restoration).
abstract_pytree = {'a': 0}
extra_metadata = {'version': 1.0}

# --- Single item, legacy API -------------------------------------------------
# The handler is fixed up front through a `Checkpointer`; restoration options
# are passed through `items=` and `restore_kwargs=`.
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty('/tmp/ckpt1/'),
    ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),
    options=options,
)

restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)
mngr.save(0, pytree)
mngr.wait_until_finished()
mngr.restore(
    0, items=abstract_pytree, restore_kwargs={'restore_args': restore_args}
)
# `mngr` is rebound by the `with` statement below; close this manager first so
# it is not leaked.
mngr.close()

# --- Single item, `args`-based API -------------------------------------------
options = ocp.CheckpointManagerOptions()
with ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty('/tmp/ckpt2/'),
    options=options,
) as mngr:
    mngr.save(0, args=ocp.args.StandardSave(pytree))
    mngr.wait_until_finished()

    # After providing `args` during an initial `save` or `restore` call, the
    # `CheckpointManager` instance records the type so that you do not need to
    # specify it again. If the `CheckpointManager` instance has not been given
    # an `ocp.args.CheckpointArgs` instance for a particular item on a previous
    # occasion, that item cannot be restored without specifying the argument
    # at restore time.

    # In many cases, you can restore exactly as saved without specifying
    # additional arguments.
    mngr.restore(0)
    # If customization of properties like sharding or dtype is desired, just
    # provide the abstract target PyTree, whose properties will be used to set
    # the properties of the restored arrays.
    mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))

# Unmapped CheckpointHandlers on a new CheckpointManager instance.
# A fresh manager on the same directory has never seen `args` for the item, so
# it has no CheckpointHandler mapped yet.
new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
try:
    new_mngr.restore(0)  # Raises error due to unmapped CheckpointHandler
except Exception as e:  # pylint: disable=broad-exception-caught
    # The precise exception type is an Orbax implementation detail; report it
    # and continue so the rest of the walkthrough still runs.
    print(f'Expected error (unmapped CheckpointHandler): {e!r}')
new_mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
new_mngr.close()

# item_handlers can be used as an alternative to restore(..., args=...).
with ocp.CheckpointManager(
    '/tmp/ckpt2/',
    options=options,
    item_handlers=ocp.StandardCheckpointHandler(),
) as new_mngr:
    print(new_mngr.restore(0))

# item_handlers becomes even more critical with item_metadata() calls.
# `with` guarantees this manager is closed even though the call below fails.
with ocp.CheckpointManager('/tmp/ckpt2/', options=options) as new_mngr:
    try:
        new_mngr.item_metadata(0)  # Raises error due to unmapped handler
    except Exception as e:  # pylint: disable=broad-exception-caught
        print(f'Expected error (unmapped CheckpointHandler): {e!r}')

with ocp.CheckpointManager(
    '/tmp/ckpt2/',
    options=options,
    item_handlers=ocp.StandardCheckpointHandler(),
) as new_mngr:
    print(new_mngr.item_metadata(0))

# --- Multiple items, legacy API ----------------------------------------------
# One `Checkpointer` per item name, passed as a dict.
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty('/tmp/ckpt3/'),
    {
        'state': ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),
        'extra_metadata': ocp.Checkpointer(ocp.JsonCheckpointHandler()),
    },
    options=options,
)

restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)
mngr.save(0, {'state': pytree, 'extra_metadata': extra_metadata})
mngr.wait_until_finished()
mngr.restore(
    0,
    items={'state': abstract_pytree, 'extra_metadata': None},
    restore_kwargs={
        'state': {'restore_args': restore_args},
        'extra_metadata': None,
    },
)
# `mngr` is rebound below; close this manager first so it is not leaked.
mngr.close()

# --- Multiple items, `args`-based API ----------------------------------------
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty('/tmp/ckpt4/'),
    # `item_names` defines an up-front contract about what items the
    # CheckpointManager will be dealing with.
    item_names=('state', 'extra_metadata'),
    options=options,
)

mngr.save(
    0,
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(pytree),
        extra_metadata=ocp.args.JsonSave(extra_metadata),
    ),
)
mngr.wait_until_finished()
# Restore as saved
mngr.restore(0)
# Restore with customization. Restore a subset of items.
mngr.restore(
    0,
    args=ocp.args.Composite(state=ocp.args.StandardRestore(abstract_pytree)),
)
mngr.close()

# Unmapped CheckpointHandlers on a new CheckpointManager instance.
# `item_names` alone declares which items exist, but not how to read them.
new_mngr = ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_names=('state', 'extra_metadata'),
)
try:
    new_mngr.restore(0)  # Raises error due to unmapped CheckpointHandlers
except Exception as e:  # pylint: disable=broad-exception-caught
    # The precise exception type is an Orbax implementation detail; report it
    # and continue so the rest of the walkthrough still runs.
    print(f'Expected error (unmapped CheckpointHandlers): {e!r}')
new_mngr.restore(
    0,
    args=ocp.args.Composite(
        state=ocp.args.StandardRestore(abstract_pytree),
        extra_metadata=ocp.args.JsonRestore(),
    ),
)
new_mngr.close()

# item_handlers can be used as an alternative to restore(..., args=...).
with ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_handlers={
        'state': ocp.StandardCheckpointHandler(),
        'extra_metadata': ocp.JsonCheckpointHandler(),
    },
) as new_mngr:
    print(new_mngr.restore(0))

# item_handlers becomes even more critical with item_metadata() calls.
with ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_names=('state', 'extra_metadata'),
) as new_mngr:
    try:
        new_mngr.item_metadata(0)  # Raises error due to unmapped handlers
    except Exception as e:  # pylint: disable=broad-exception-caught
        print(f'Expected error (unmapped CheckpointHandlers): {e!r}')

with ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_handlers={
        'state': ocp.StandardCheckpointHandler(),
        'extra_metadata': ocp.JsonCheckpointHandler(),
    },
) as new_mngr:
    print(new_mngr.item_metadata(0))