"""Walkthrough: save a checkpoint with Orbax, then inspect it on disk.

The script builds a small PyTree, saves it with `ocp.CheckpointManager`,
prints the resulting directory layout, and finally reads one parameter back
directly through TensorStore. Bare expressions such as `state` or `tspec` are
kept for notebook display; they have no effect when run as a plain script.
"""

import json

from etils import epath
import jax
import numpy as np
import orbax.checkpoint as ocp
import tensorstore as ts

# ---------------------------------------------------------------------------
# Build an example state.
# ---------------------------------------------------------------------------

# One-dimensional mesh over all available devices, partitioned along 'model'.
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec('model'),
)


def create_sharded_array(x):
    """Puts `x` on devices using the module-level `sharding`."""
    return jax.device_put(x, sharding)


state = {
    'a': np.arange(16),
    'b': np.ones(16),
}
state = jax.tree_util.tree_map(create_sharded_array, state)
# Computed before the extra leaves below are added, so it only describes
# 'a' and 'b'.
abstract_state = jax.tree_util.tree_map(
    ocp.utils.to_shape_dtype_struct, state
)

# Additional leaves that are not passed through `create_sharded_array`.
state['c'] = np.arange(4)
state['d'] = 5
state['e'] = 'foo'
state


def print_directory(directory: epath.PathLike, level: int = 0):
    """Prints a directory tree for debugging purposes.

    Args:
        directory: Root of the tree to print. Must be an existing directory.
        level: Current nesting depth; 0 for the top-level call. Each level
            adds one '..' to the prefix of the printed entries.
    """
    directory = epath.Path(directory)
    assert directory.is_dir()
    level_str = '..' * level
    if level == 0:
        print(f'Printing directory tree: {directory}/')
    else:
        print(f'{level_str}{directory.name}/')

    # Children are printed one level deeper than the directory itself.
    level_str = '..' * (level + 1)
    for p in directory.iterdir():
        if p.is_dir():
            print_directory(p, level=level + 1)
        else:
            print(f'{level_str}{p.name}')


# ---------------------------------------------------------------------------
# Save a checkpoint.
# ---------------------------------------------------------------------------

path = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint')
global_metadata = {'global_property': 'foo'}

with ocp.CheckpointManager(
    path, item_names=('state', 'custom_data'), metadata=global_metadata
) as mngr:
    mngr.save(
        0,
        args=ocp.args.Composite(
            state=ocp.args.PyTreeSave(state),
            custom_data=ocp.args.JsonSave({'lang': 'en', 'version': 1.2}),
        ),
    )

# Inspect the layout only after the manager has been closed.
print_directory(path)
print_directory(path / '0' / 'state')

json.loads((path / '0' / 'state' / '_METADATA').read_text())

# ---------------------------------------------------------------------------
# Read a parameter back directly with TensorStore.
# ---------------------------------------------------------------------------

ts_context = ts.Context(
    {
        # Provide cache pool for B-tree nodes to avoid repeated reads.
        # 100MB limit.
        'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    },
    parent=jax.experimental.array_serialization.serialization.TS_CONTEXT,
)

ParamInfo = ocp.type_handlers.ParamInfo

state_dir = path / '0' / 'state'
param_name = 'a'
param_path = state_dir / param_name
# `path` must point at the parameter itself, not at the checkpoint root.
# NOTE(review): `use_zarr3=True` is requested here, while the hand-written
# spec further down uses the 'zarr' driver -- confirm which format the save
# above actually produced.
info = ParamInfo(
    name=param_name,
    path=param_path,
    parent_dir=state_dir,
    is_ocdbt_checkpoint=True,
    use_zarr3=True,
)
tspec = ocp.type_handlers.get_json_tspec_read(info, use_ocdbt=True)
tspec

# Derived from `state_dir` so it stays in sync with `path`; evaluates to
# 'file:///tmp/checkpoint/0/state/' for the directory used above.
kvstore_base = f'file://{state_dir.as_posix()}/'

# List every key stored in the OCDBT database of the state directory.
ts.KvStore.open({'driver': 'ocdbt', 'base': kvstore_base}).result().list().result()

# Hand-written equivalent of the spec above, used to read the array.
tspec = {
    'driver': 'zarr',
    'kvstore': {
        'driver': 'ocdbt',
        'base': kvstore_base,
        'path': param_name,
    },
}

t = ts.open(ts.Spec(tspec), open=True, context=ts_context).result()
result = t.read().result()
result