It is important to understand how Orbax structures checkpoints on disk, particularly if you ever need to debug at the checkpoint level, or if you wish to work with specific pieces of a larger checkpoint.
First, some setup:
from etils import epath
import jax
import numpy as np
import orbax.checkpoint as ocp
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec('model'),
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
state = {
    'a': np.arange(16),
    'b': np.ones(16),
}
state = jax.tree_util.tree_map(create_sharded_array, state)
abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)
state['c'] = np.arange(4)
state['d'] = 5
state['e'] = 'foo'
state
def print_directory(directory: epath.PathLike, level: int = 0):
  """Prints a directory tree for debugging purposes."""
  directory = epath.Path(directory)
  assert directory.is_dir()
  level_str = '..' * level
  if level == 0:
    print(f'Printing directory tree: {directory}/')
  else:
    print(f'{level_str}{directory.name}/')
  level_str = '..' * (level + 1)
  for p in directory.iterdir():
    if p.is_dir():
      print_directory(p, level=level + 1)
    else:
      print(f'{level_str}{p.name}')
We will start by creating a checkpoint for step 0, consisting of two items:
state and custom_data.
path = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint')
global_metadata = {'global_property': 'foo'}
with ocp.CheckpointManager(
    path, item_names=('state', 'custom_data'), metadata=global_metadata
) as mngr:
  mngr.save(
      0,
      args=ocp.args.Composite(
          state=ocp.args.PyTreeSave(state),
          custom_data=ocp.args.JsonSave({'lang': 'en', 'version': 1.2}),
      ),
  )
print_directory(path)
Let's understand each of these pieces separately.
The "root directory" is understood to be the directory provided when creating a
CheckpointManager. It represents the parent directory where all "sequential"
checkpoints will reside (see below). In the above example, this corresponds to
/tmp/checkpoint/.
Within the root directory, aside from the sequential checkpoints, there may also
be a metadata subdirectory (if metadata was provided when configuring the
CheckpointManager).
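We can confirm this layout by listing the root directory's immediate children:
sorted(p.name for p in path.iterdir())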
With the term "sequential checkpoint", we refer to a checkpoint that represents
a particular step in a longer sequence. Typically, in Orbax, this is simply
denoted by a directory named with an integer value (0/ in the above example).
However, options are available to customize the default format.
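For instance, here is a sketch of such customization (it assumes
ocp.step.standard_name_format and the step_name_format option are available in
your Orbax version), giving step directories a prefix and fixed-width
numbering:
# Sketch: produce step directories named like 'ckpt_00000000/'.
options = ocp.CheckpointManagerOptions(
    step_name_format=ocp.step.standard_name_format(
        step_prefix='ckpt', step_format_fixed_length=8
    )
)
# Pass `options` when constructing the CheckpointManager.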
The sequential checkpoint has a top-level _CHECKPOINT_METADATA file that
stores basic information such as the creation timestamp and other fields.
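Because it is a small text file (JSON-encoded in current Orbax versions,
though this is an implementation detail), it can be inspected directly:
import json
# Assumes the _CHECKPOINT_METADATA file is JSON-encoded.
json.loads((path / '0' / '_CHECKPOINT_METADATA').read_text())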
Within a sequential checkpoint directory, we have subdirectories corresponding
to "items". An "item" represents a logically distinct unit of a larger
checkpoint, so these are naturally represented in separate subdirectories. In
the above example, the items are state and custom_data.
This representation makes composition easier if you want to combine the
dataset from one checkpoint with the state from another, for instance. It also
prevents collisions if you use the same CheckpointHandler to save both state
and embeddings, for instance.
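This separation also means that items can be restored independently of one
another. As a minimal sketch (reopening a manager on the same root directory
and using the same args API as above), we can restore just the custom_data
item:
# Sketch: restore a single item from the composite checkpoint at step 0.
with ocp.CheckpointManager(
    path, item_names=('state', 'custom_data')
) as restore_mngr:
  restored = restore_mngr.restore(
      0,
      args=ocp.args.Composite(custom_data=ocp.args.JsonRestore()),
  )
restored.custom_data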
Below this level, the format is no longer universally standard, because each
CheckpointHandler customizes its own file format.
Because the state item was saved with ocp.args.PyTreeSave (the same would
apply if saved with ocp.args.StandardSave), it takes the following form:
print_directory(path / '0' / 'state')
The _METADATA file provides a complete description of the PyTree structure,
including custom and empty nodes.
The tree is represented as a flattened dictionary whose keys are tuples;
successive tuple elements denote successive levels of nesting. For example,
for the dict {'a': {'b': [1, 2]}}, the metadata file would contain two
entries with keys ('a', 'b', '0') and ('a', 'b', '1').
Keys at each level of nesting also encode what type they are: i.e. whether they are a dict key or a sequential key.
Finally, metadata about the value type is stored (e.g. jax.Array,
np.ndarray, etc.) in order to allow for later reconstruction without
explicitly requiring the object type to be provided.
import json
json.loads((path / '0' / 'state' / '_METADATA').read_text())
The _sharding file stores information about the shardings originally used when
saving jax.Arrays in the tree. It isn't really human-readable though. To get
information about shardings, use the metadata APIs.
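For example, here is a sketch of one way to do this (the exact metadata
classes returned can vary across Orbax versions):
# Sketch: read per-leaf metadata and extract shardings where present.
# Non-array leaves (e.g. the string 'e') may have no sharding attribute.
meta = ocp.PyTreeCheckpointer().metadata(path / '0' / 'state')
jax.tree_util.tree_map(lambda m: getattr(m, 'sharding', None), meta)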
Beyond these metadata files, which are directly managed by Orbax, we also have a
manifest.ocdbt file managed by the TensorStore library. Actual array data is
stored within the d/ subdirectory. Since these files are opaque to human
readers, we will not go into detail on their structure.
Finally, you'll notice the presence of the directory ocdbt.process_0/, which
also has a manifest.ocdbt and its own d/ subdirectory. One such folder
exists for every process on which the checkpoint was saved. This exists because
each process first writes its own data independently to its corresponding
subdirectory.
When all processes have finished, Orbax runs a finalization pass to cheaply merge the metadatas from all per-process subdirectories into a global view (note that this still references data in the original subdirectories).
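We can print one of these per-process directories to see this structure:
print_directory(path / '0' / 'state' / 'ocdbt.process_0')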
Sometimes, it is helpful to work directly with the TensorStore API to debug individual parameters in a checkpoint.
from etils import epath
import jax
from jax.experimental.array_serialization import serialization
import tensorstore as ts

ts_context = ts.Context(
    {
        # Provide a cache pool for B-tree nodes to avoid repeated reads.
        # 100MB limit.
        'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    },
    parent=serialization.TS_CONTEXT,
)
To read using TensorStore, we need to construct a TensorStore Spec. For this, we can use Orbax APIs. The spec points to a base path, as well as a particular parameter name (a in this case). It contains further options related to the checkpoint format.
ParamInfo = ocp.type_handlers.ParamInfo
state_dir = path / '0' / 'state'
param_name = 'a'
param_path = state_dir / param_name
info = ParamInfo(
    name=param_name,
    path=param_path,
    parent_dir=state_dir,
    is_ocdbt_checkpoint=True,
    # The checkpoint above was written with the default zarr (v2) format,
    # so zarr3 must be disabled for the generated spec to match.
    use_zarr3=False,
)
tspec = ocp.type_handlers.get_json_tspec_read(info, use_ocdbt=True)
tspec
We can verify which keys are present in the checkpoint, which matches information we gathered earlier from the Orbax metadata API.
ts.KvStore.open(
    {'driver': 'ocdbt', 'base': 'file:///tmp/checkpoint/0/state/'}
).result().list().result()
Finally, we can directly restore the array using TensorStore.
tspec = {
    'driver': 'zarr',
    'kvstore': {
        'driver': 'ocdbt',
        'base': 'file:///tmp/checkpoint/0/state/',
        'path': 'a',
    },
}
t = ts.open(ts.Spec(tspec), open=True, context=ts_context).result()
result = t.read().result()
result
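As a sanity check, the restored values should match the original array a,
which was created as np.arange(16) above:
np.testing.assert_array_equal(np.asarray(result), np.arange(16))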