# The two strings below are illustrative directory layouts carried over from
# the original notebook text; the second assignment overwrites the first and
# neither is used by the code that follows.
f = """ path/to/my/checkpoint/dir/ 0/ state/ layer0.param0/ .zarray 0.0 0.1 1.0 1.1 layer1.param0/ .zarray 0.0 ... / ... 1/ ... 2/ ... Note: in this case, `0.0`, `0.1`, etc. provides an indication of how the array was sharded when originally saved. """
f = """ path/to/my/checkpoint/dir/ 0/ state/ checkpoint # legacy msgpack file, stores tree structure tree_metadata # (maybe) new proto file, stores tree structure d/ # array data stored here 012b2c6e5c9d2a16c240a59d5f0f35c0 056e0816bdc5496a86251e58a0ec202b ... manifest.0000000000000001 ... manifest.ocdbt / ... 1/ ... 2/ ... """
import jax
import tempfile
import subprocess
import os
from etils import epath
import orbax.checkpoint as ocp

# Initialize PyTreeCheckpointHandler with `use_ocdbt=True`.
# This option already defaults to True, so it's optional to pass it in.
ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=True))

# Set up checkpoint data: two float32 vectors of 8K and 16K elements.
array_len = 8 * 1024
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
pytree = {
    'a': jax.random.normal(subkey, (array_len,), dtype=jax.numpy.float32),  # 32KB
    'b': jax.random.normal(subkey, (array_len * 2,), dtype=jax.numpy.float32),  # 64KB
}

# Create save_args to customize the chunk_byte_size of every leaf.
save_args = jax.tree_util.tree_map(
    lambda x: ocp.SaveArgs(
        chunk_byte_size=1024,  # 1KB
    ),
    pytree,
)

# First save: default OCDBT data-file sizing.
temp_dir = tempfile.TemporaryDirectory()
mgr = ocp.CheckpointManager(
    epath.Path(temp_dir.name),
    item_handlers=ocp.PyTreeCheckpointHandler(use_zarr3=True),  # make sure zarr3 is enabled
)
mgr.save(
    0,
    args=ocp.args.PyTreeSave(
        pytree,
        save_args=save_args,
    ),
)
mgr.close()


def print_directory_file_size(dir: epath.Path) -> None:  # noqa: A002 (name kept for compatibility)
    """Print the name and size of every regular file directly inside ``dir``.

    Entries are listed in name order so the output is reproducible;
    subdirectories are skipped.

    Args:
        dir: Directory to list. Only its immediate children are inspected.
    """
    print(f"dir={dir}:")
    # List the directory that was passed in (the previous version read the
    # global ``data_dir`` instead, ignoring this argument).
    for entry in sorted(dir.iterdir(), key=lambda p: p.name):
        if entry.is_file():
            print(f"file={entry.name}, size={entry.stat().length}")


# Continue from above example, examine the data file sizes.
# NOTE(review): assumes the item is saved under 'default' and OCDBT writes its
# data files under 'ocdbt.process_0/d' for process 0 — confirm against the
# installed Orbax version.
data_dir = epath.Path(temp_dir.name) / '0' / 'default' / 'ocdbt.process_0' / 'd'
print_directory_file_size(data_dir)

# Second save: ask OCDBT to pack chunks into larger data files.
temp_dir = tempfile.TemporaryDirectory()
mgr = ocp.CheckpointManager(
    epath.Path(temp_dir.name),
    item_handlers=ocp.PyTreeCheckpointHandler(use_zarr3=True),
)
mgr.save(
    0,
    args=ocp.args.PyTreeSave(
        pytree,
        save_args=save_args,
        ocdbt_target_data_file_size=10 * 1024,  # 10 KB, should be much larger than chunk_byte_size
    ),
)
mgr.close()

data_dir = epath.Path(temp_dir.name) / '0' / 'default' / 'ocdbt.process_0' / 'd'
print_directory_file_size(data_dir)