"""Orbax checkpointing walkthrough.

Covers saving/restoring a single PyTree with ``StandardCheckpointer``, saving
several named items in one checkpoint with ``CompositeCheckpointHandler``, and
managing step-indexed checkpoints with ``CheckpointManager`` (including
sharded ``jax.Array`` state). Values are printed so the script is informative
when run from the command line.
"""

import jax
import numpy as np
import orbax.checkpoint as ocp

# --- Single PyTree with StandardCheckpointer --------------------------------

path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
my_tree = {
    'a': np.arange(8),
    'b': {
        'c': 42,
        'd': np.arange(16),
    },
}
# Restore takes a template of abstract leaves; derive it from the tree that
# is about to be saved so the restored leaves match it.
abstract_my_tree = jax.tree_util.tree_map(
    ocp.utils.to_shape_dtype_struct, my_tree
)

checkpointer = ocp.StandardCheckpointer()
# 'checkpoint_name' must not already exist.
checkpointer.save(path / 'checkpoint_name', my_tree)
# NOTE(review): recent Orbax releases make StandardCheckpointer save
# asynchronously; confirm the save has completed before relying on restore.
restored_tree = checkpointer.restore(
    path / 'checkpoint_name',
    args=ocp.args.StandardRestore(abstract_my_tree),
)
print(restored_tree)
print(checkpointer.metadata(path / 'checkpoint_name'))
# Release this checkpointer before the name is rebound below.
checkpointer.close()

# --- Several named items in one checkpoint ----------------------------------

metadata = {
    'version': 1.0,
    'lang': 'en',
}
# Items are addressed by the names registered on the composite handler.
checkpointer = ocp.Checkpointer(
    ocp.CompositeCheckpointHandler('state', 'metadata')
)
checkpointer.save(
    path / 'composite_checkpoint',
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(my_tree),
        metadata=ocp.args.JsonSave(metadata),
    ),
)
restored = checkpointer.restore(path / 'composite_checkpoint')
print(restored.state)
print(restored.metadata)
# Show what was written to disk for the composite checkpoint.
print(list((path / 'composite_checkpoint').iterdir()))

# --- CheckpointManager inputs -----------------------------------------------

path = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint_manager')
state = {
    'a': np.arange(8),
    'b': np.arange(16),
}
extra_params = [42, 43]
# Keeps a maximum of 3 checkpoints, and only saves every other step.
options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=2)
# Every save bundles two named items: a PyTree and a JSON-serialisable list.
mngr = ocp.CheckpointManager(
    path, options=options, item_names=('state', 'extra_params')
)
for step in range(11):  # [0, 1, ..., 10]
    mngr.save(
        step,
        args=ocp.args.Composite(
            state=ocp.args.StandardSave(state),
            extra_params=ocp.args.JsonSave(extra_params),
        ),
    )
# Block until outstanding saves are finished before reading anything back.
mngr.wait_until_finished()
restored = mngr.restore(10)
restored_state, restored_extra_params = restored.state, restored.extra_params
print(restored_state, restored_extra_params)
print(mngr.all_steps())
print(mngr.latest_step())
print(mngr.should_save(11))
# Release the first manager before `mngr` is rebound to a second one below.
mngr.close()

import jax

# --- CheckpointManager with sharded arrays ----------------------------------

path = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint_manager_sharded')
# One mesh axis named 'model' spanning every available device.
mesh = jax.sharding.Mesh(jax.devices(), ('model',))
# Split each array along its first dimension over the 'model' axis.
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('model'))


def create_sharded_array(x):
    """Place ``x`` on devices using the module-level ``sharding`` at call time."""
    return jax.device_put(x, sharding)


train_state = {
    'a': np.arange(16),
    'b': np.ones(16),
}
train_state = jax.tree_util.tree_map(create_sharded_array, train_state)
print(jax.tree_util.tree_map(lambda x: x.sharding, train_state))

num_steps = 10
options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=2)
mngr = ocp.CheckpointManager(path, options=options)


@jax.jit
def train_fn(state):
    """Toy training step: add one to every leaf of ``state``."""
    return jax.tree_util.tree_map(lambda x: x + 1, state)


for step in range(num_steps):
    train_state = train_fn(train_state)
    mngr.save(step, args=ocp.args.StandardSave(train_state))
mngr.wait_until_finished()
print(mngr.restore(mngr.latest_step()))

# Restore into a different layout: build a zeroed target replicated on every
# device (PartitionSpec(None)) and use it as the restore template.
train_state = jax.tree_util.tree_map(np.zeros_like, train_state)
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))
train_state = jax.tree_util.tree_map(create_sharded_array, train_state)
# NOTE(review): assumes to_shape_dtype_struct carries each array's sharding
# into the template; the printout below shows the shardings actually restored.
abstract_train_state = jax.tree_util.tree_map(
    ocp.utils.to_shape_dtype_struct, train_state
)
restored = mngr.restore(
    mngr.latest_step(),
    args=ocp.args.StandardRestore(abstract_train_state),
)
print(restored)
print(jax.tree_util.tree_map(lambda x: x.sharding, restored))
# NOTE(review): this manager is left open in case later code uses it; call
# mngr.close() once it is no longer needed.