# Walkthrough script: save and restore a sharded PyTree with Orbax, first via
# StandardCheckpointHandler ('/tmp/basic/') and then via
# PyTreeCheckpointHandler ('/tmp/advanced/').
#
# Bare expressions such as `restored` or `restored['a'].sharding` have no
# effect when run as a script; presumably they are notebook cell outputs and
# are kept so the script stays line-for-line comparable with its origin.
import numpy as np
import orbax.checkpoint as ocp
import jax

# Shard every array along its single axis across all available devices.
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec(
        'model',
    ),
)


def create_sharded_array(x):
  """Place `x` on devices using the module-level `sharding` at call time."""
  return jax.device_put(x, sharding)


state = {
    'a': np.arange(16),
    'b': np.ones(16),
}
state = jax.tree_util.tree_map(create_sharded_array, state)
abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)

path = ocp.test_utils.erase_and_create_empty('/tmp/basic/')
# Make sure to use async for improved performance!
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(path / '1', args=ocp.args.StandardSave(state))
# The save runs in the background; block until it is complete before reading.
ckptr.wait_until_finished()

restored = ckptr.restore(path / '1', args=ocp.args.StandardRestore(abstract_state))
restored
restored['a'].sharding

# A concrete tree is passed here in place of the abstract one.
ckptr.restore(path / '1', args=ocp.args.StandardRestore(state))


def set_restore_dtype(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  """Return a copy of `x` whose dtype is int16; `x` itself is left untouched.

  A new struct is built instead of assigning to `x.dtype`: `tree_map` keeps
  whatever object the function returns, so mutating and returning `x` would
  also change the leaves of the tree being mapped (`abstract_state`), which
  later sections still read.
  """
  return jax.ShapeDtypeStruct(
      shape=x.shape, dtype=np.int16, sharding=x.sharding
  )


cast_dtype_abstract_state = jax.tree_util.tree_map(
    set_restore_dtype, abstract_state)

ckptr.restore(
    path / '1',
    args=ocp.args.StandardRestore(cast_dtype_abstract_state),
)

# Same dtypes and shardings as the saved state, but with different lengths.
different_shape_abstract_state = {
    'a': jax.ShapeDtypeStruct(
        shape=(8,),
        dtype=abstract_state['a'].dtype,
        sharding=abstract_state['a'].sharding
    ),
    'b': jax.ShapeDtypeStruct(
        shape=(32,),
        dtype=abstract_state['b'].dtype,
        sharding=abstract_state['b'].sharding
    ),
}

ckptr.restore(
    path / '1',
    args=ocp.args.StandardRestore(different_shape_abstract_state),
)

# Restore without any target tree.
restored = ckptr.restore(path / '1')
restored['a'].sharding

# Rebind `sharding` to a spec with no partitioned axes; the restore examples
# below use this new value.
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('x',)),
    jax.sharding.PartitionSpec(),
)


def set_sharding(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  """Return a copy of `x` that carries the current module-level `sharding`.

  As with `set_restore_dtype`, a new struct is created so that the leaves of
  the mapped tree are not modified in place.
  """
  return jax.ShapeDtypeStruct(
      shape=x.shape, dtype=x.dtype, sharding=sharding
  )


change_sharding_abstract_state = jax.tree_util.tree_map(
    set_sharding, abstract_state)

restored = ckptr.restore(
    path / '1',
    args=ocp.args.StandardRestore(change_sharding_abstract_state),
)
restored['a'].sharding

path = ocp.test_utils.erase_and_create_empty('/tmp/advanced/')
# Make sure to use async for improved performance!
ckptr = ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler())
ckptr.save(path / '1', args=ocp.args.PyTreeSave(state))
ckptr.wait_until_finished()

restored = ckptr.restore(path / '1')
restored['a'].dtype
restored['b'].dtype

ckptr.save(
    path / '2',
    args=ocp.args.PyTreeSave(
        state,
        save_args={
            # We must set one ocp.SaveArgs per leaf.
            'a': ocp.SaveArgs(dtype=np.int16),
            'b': ocp.SaveArgs()
        }
    ),
)
ckptr.wait_until_finished()

restored = ckptr.restore(path / '2')
restored['a'].dtype
restored['b'].dtype

ckptr.restore(
    path / '2',
    args=ocp.args.PyTreeRestore(
        restore_args={
            # RestoreArgs is the parent class for ArrayRestoreArgs.
            # We must set one RestoreArgs per leaf.
            'a': ocp.RestoreArgs(restore_type=np.ndarray),
            'b': ocp.ArrayRestoreArgs(dtype=np.int16, sharding=sharding)
        }
    ),
)

ckptr.restore(
    path / '2',
    args=ocp.args.PyTreeRestore(
        # `item` serves as a guide to what the result tree structure should look
        # like.
        item={
            # Value doesn't really matter here, as long as it's not None.
            'c': ...,
            # Can add in extra keys.
            'd': np.arange(8)
        },
        # `restore_args` must be relative to the result tree, not the
        # checkpoint.
        restore_args={
            'c': ocp.RestoreArgs(restore_type=np.ndarray),
        },
        transforms={
            'c': ocp.Transform(original_key='a')
        },
    ),
)