import jax
import numpy as np
from etils import epath
import orbax.checkpoint as ocp
import tensorstore as ts
import collections
import operator
import asyncio
# Example pytree to checkpoint: a nested dict holding one large and one small
# array, a second top-level array, and a plain Python int.
state = dict(
    a=dict(
        x=np.arange(2 ** 24),
        y=np.arange(1024),
    ),
    b=np.ones(8),
    c=42,
)
default_param_name = 'a.x'
default_path = epath.Path('/tmp/checkpoint')

# Remove any leftover directory from a previous run so the save below writes
# into a fresh location.
if default_path.exists():
    default_path.rmtree()

with ocp.StandardCheckpointer() as checkpointer:
    checkpointer.save(default_path, state)
The following cell computes the actual size of the checkpoint on disk.
path = ""  # @param {type:"string"}
# An explicitly supplied path takes precedence; otherwise fall back to the
# example checkpoint. (`default_path or ...` would always pick `default_path`,
# since a Path object is always truthy.)
path = epath.Path(path) if path else default_path


async def disk_usage(path: epath.Path) -> int:
    """Returns the size of the checkpoint on disk, in bytes.

    Note: this uses recursion because Orbax checkpoint directories are never
    more than a few levels deep.

    Args:
        path: The path to the checkpoint directory.

    Returns:
        The summed size, in bytes, of every non-directory entry beneath `path`.
    """

    async def helper(entry: epath.Path) -> int:
        if entry.is_dir():
            return await disk_usage(entry)
        # Stat the entry itself, not the top-level directory being walked.
        stat = await ocp.path.async_utils.async_stat(entry)
        return stat.length

    # Stat all children concurrently; an empty directory sums to 0.
    sizes = await asyncio.gather(*(helper(entry) for entry in path.iterdir()))
    return sum(sizes)


print('{0:0.3f} GB'.format(float(asyncio.run(disk_usage(path))) / 1e9))
Users sometimes find that a checkpoint's size on disk is much larger or smaller than they would expect based on the model itself. Computing the implied size of the checkpoint from its own metadata, and cross-referencing it against the actual on-disk size, can provide some insight.
The actual on-disk size is typically somewhat smaller than the implied size, due to compression.
path = ""  # @param {type:"string"}
# An explicitly supplied path takes precedence; otherwise use the example
# checkpoint. (A Path object is always truthy, so `default_path or ...`
# would never use the user's input.)
path = epath.Path(path) if path else default_path
with ocp.StandardCheckpointer() as ckptr:
    metadata = ckptr.metadata(path)

# Number of leaves seen per dtype; filled in as a side effect of
# `get_arr_bytes` below.
size_counts = collections.defaultdict(int)


def get_arr_bytes(meta) -> int:
    """Returns the implied byte size of one metadata leaf and tallies its dtype.

    Args:
        meta: A leaf of the checkpoint metadata tree; its `shape` and `dtype`
            attributes are read.

    Returns:
        Number of elements times the dtype's item size, as an int.
    """
    dtype = meta.dtype
    size_counts[dtype] += 1
    # np.prod(()) is the float 1.0 for scalar leaves; keep the result integral.
    return int(np.prod(meta.shape)) * np.dtype(dtype).itemsize


# The initializer makes an empty metadata tree sum to 0 instead of raising.
total_bytes = jax.tree.reduce(
    operator.add, jax.tree.map(get_arr_bytes, metadata), initializer=0
)
print('{0:0.3f} GB'.format(float(total_bytes) / 1e9))
print()
print('leaf dtype counts:')
for dtype, count in size_counts.items():
    print(f'{dtype}: {count}')
Inspecting the tree structure of the checkpoint is crucial: it lets you verify that the checkpoint contains the expected parameters, as well as the array metadata associated with each parameter.
The following can be useful when debugging errors in which the loading code searched for a particular parameter that was not found. Several things could be going wrong here. The cells below list the parameters recorded in the checkpoint metadata and check that they agree with the arrays actually stored by TensorStore.
path = ""  # @param {type:"string"}
# An explicitly supplied path takes precedence; otherwise use the example
# checkpoint (a Path object is always truthy).
path = epath.Path(path) if path else default_path
with ocp.StandardCheckpointer() as ckptr:
    metadata = ckptr.metadata(path)
# Flatten the metadata tree into dotted parameter names, e.g. ('a', 'x') -> 'a.x'.
metadata_contents = ['.'.join(k) for k in ocp.tree.to_flat_dict(metadata)]
# Here are the parameters present in the checkpoint tree.
for p in metadata_contents:
    print(p)

# Note: instead of "file", use:
# - "gfile" on Google-internal filesystems.
# - "gs" on GCS (do not repeat the "gs://" prefix)
ts_contents = ts.KvStore.open(
    {"driver": "ocdbt", "base": f"file://{path.as_posix()}"}
).result().list().result()
ts_contents = [p.decode("utf-8") for p in ts_contents]
# Each stored array has a "<param_name>/.zarray" entry; recover the parameter
# name by stripping that suffix.
_ZARRAY_SUFFIX = '/.zarray'
ts_contents = [
    p.removesuffix(_ZARRAY_SUFFIX) for p in ts_contents if p.endswith(_ZARRAY_SUFFIX)
]

# We can check that the parameters tracked by the metadata file are the same
# as those tracked by TensorStore. If there is a discrepancy, there may be a
# deeper underlying problem, so report exactly which names differ.
if sorted(metadata_contents) != sorted(ts_contents):
    raise AssertionError(
        'Checkpoint metadata and TensorStore contents disagree.\n'
        f'  only in metadata: {sorted(set(metadata_contents) - set(ts_contents))}\n'
        f'  only in TensorStore: {sorted(set(ts_contents) - set(metadata_contents))}'
    )
path = ""  # @param {type:"string"}
# The `param_name` can be obtained by inspecting tree metadata (see above).
param_name = ""  # @param {type:"string"}
# User-supplied values take precedence over the example defaults. The original
# `default_x or x` form always chose the (non-empty / truthy) default.
path = epath.Path(path) if path else default_path
param_name = param_name or default_param_name
with ocp.StandardCheckpointer() as ckptr:
    metadata = ckptr.metadata(path)
# Map dotted parameter names (e.g. 'a.x') to their metadata leaves.
flat_metadata = {
    '.'.join(k): v for k, v in ocp.tree.to_flat_dict(metadata).items()
}
try:
    value_metadata = flat_metadata[param_name]
except KeyError:
    raise KeyError(
        f'{param_name!r} not found in checkpoint; '
        f'available parameters: {sorted(flat_metadata)}'
    ) from None
print(f'shape: {value_metadata.shape}')
print(f'dtype: {value_metadata.dtype}')
It can often be helpful to check the raw value of a particular parameter as saved in the checkpoint. This establishes whether the parameter was saved correctly and rules out the possibility that saving went wrong for that parameter, or that the checkpoint has been corrupted. If the saved value is correct, debugging can be confined to the restoration path.
CAUTION: The read below loads the entire array into memory, which for very large arrays could cause an out-of-memory (OOM) error. To load only a slice of the array, index into the TensorStore object `t` before reading, e.g. `t[:2, :4].read().result()`.
# Read the raw stored value of `param_name` directly through TensorStore,
# bypassing Orbax restoration. `path` and `param_name` are not defined here;
# they come from earlier in this file.
ParamInfo = ocp.type_handlers.ParamInfo
# TensorStore context: limits concurrent file I/O to 128 operations and caps
# the OCDBT cache pool at 100,000,000 bytes (~100 MB).
ts_context = ts.Context({
    'file_io_concurrency': {'limit': 128},
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
})
# Describe where the parameter lives. The OCDBT / non-zarr3 flags must match how
# the checkpoint was written — NOTE(review): confirm if saving used non-default
# options.
info = ParamInfo(name=param_name, path=path / param_name, parent_dir=path, is_ocdbt_checkpoint=True, use_zarr3=False)
# Build the TensorStore JSON spec for reading this parameter, then open it.
tspec = ocp.type_handlers.get_json_tspec_read(info, use_ocdbt=True)
t = ts.open(ts.Spec(tspec), open=True, context=ts_context).result()
# Reads the whole array into memory; index `t` first to read only a slice.
arr = t.read().result()
print(arr)