This page is relevant if your model state contains custom leaves in a PyTree, or doesn't use PyTree at all.
If your model uses PyTree but has custom leaves, read the TypeHandler section to see how to register the custom type with PyTreeCheckpointHandler.
If your model doesn't use PyTree or if you want to implement different serialization/deserialization logic, skip to the CheckpointHandler section.
If you're running this guide in a notebook, make sure to run this cell first.
import asyncio
from concurrent import futures
from dataclasses import dataclass
import functools
import os
import time
from typing import Any, List, Optional, Sequence
from etils import epath
import numpy as np
import orbax.checkpoint as ocp
ParamInfo = ocp.type_handlers.ParamInfo
Metadata = ocp.metadata.Metadata
PyTreeCheckpointHandler walks through the input PyTree and uses registered TypeHandlers to serialize/deserialize the leaves. If your custom model state is stored within the leaves of a PyTree, implement a TypeHandler and use it with PyTreeCheckpointHandler.
Standard TypeHandlers
Orbax includes pre-defined TypeHandlers for saving certain types:
- ArrayHandler: jax.Array
- NumpyHandler: np.ndarray
- ScalarHandler: int, float
- StringHandler: str

These default implementations all use Tensorstore to serialize and deserialize data, except for StringHandler, which serializes to JSON.
To implement a custom TypeHandler, we must define the async serialize and deserialize methods (the section "Async vs Non-Async" lists reasons why these methods should be asynchronous). The new TypeHandler is then registered so that the PyTreeCheckpointHandler knows to use this handler when there is a MyState leaf in the PyTree.
The inputs to the TypeHandler are batched to allow for performance optimizations in certain cases. PyTreeCheckpointHandler groups all leaves of the same type and dispatches each group as a single batch (one batch per type).
The example below defines a TypeHandler for a custom dataclass that stores multiple numpy arrays.
@dataclass
class MyState:
    """Custom PyTree leaf type that stores two NumPy arrays."""

    # `np.ndarray` is the array type; `np.array` is only the factory function.
    a: np.ndarray
    b: np.ndarray


# Make sure to only run this cell once, otherwise a new `MyState` dataclass will
# be created which could mess up Python issubclass/isinstance checks.
Here is a possible TypeHandler implementation for MyState:
class MyStateHandler(ocp.type_handlers.TypeHandler):
    """Serializes MyState to the numpy npz format.

    Each `MyState` leaf is written to `my_state.npz` inside the leaf's own
    directory (`ParamInfo.path`) and read back from the same file.
    """

    def __init__(self):
        # Single worker thread that performs the blocking npz writes.
        self._executor = futures.ThreadPoolExecutor(max_workers=1)

    def typestr(self) -> str:
        """Returns the string identifying the type handled by this class."""
        return 'MyState'

    async def serialize(
        self,
        values: Sequence[MyState],
        infos: Sequence[ParamInfo],
        args: Optional[Sequence[ocp.SaveArgs]] = None,
    ) -> List[futures.Future]:
        """Schedules one npz write per value on the handler's executor.

        Args:
          values: The `MyState` leaves to save.
          infos: One `ParamInfo` per value; `info.path` is the target directory.
          args: Ignored in this example.

        Returns:
          One future per value, each resolving once its file has been written.
        """
        del args  # Unused in this example.
        # Named `write_futures` so the `futures` module is not shadowed locally.
        write_futures = []
        for value, info in zip(values, infos):
            # make sure the per-key directory is present as OCDBT doesn't create one
            info.path.mkdir(exist_ok=True)
            write_futures.append(
                self._executor.submit(_write_state, value, info.path)
            )
        return write_futures

    async def deserialize(
        self,
        infos: Sequence[ParamInfo],
        args: Optional[Sequence[ocp.RestoreArgs]] = None,
    ) -> Sequence[MyState]:
        """Reads one `MyState` per entry of `infos`.

        Args:
          infos: One `ParamInfo` per leaf; `info.path` is the source directory.
          args: Ignored in this example.

        Returns:
          The restored `MyState` objects, in the same order as `infos`.
        """
        del args  # Unused in this example.
        # `_from_state` is a coroutine function, so its coroutines are awaited
        # directly; routing the call through the thread pool would only create
        # the coroutine object on another thread without doing any of the work.
        return await asyncio.gather(
            *(_from_state(info.path) for info in infos)
        )

    async def metadata(self, infos: Sequence[ParamInfo]) -> Sequence[Metadata]:
        """Returns a default `Metadata` (name and directory only) per leaf."""
        # This method is explained in a separate section.
        return [Metadata(name=info.name, directory=info.path) for info in infos]
def _write_state(state: MyState, path: epath.Path) -> epath.Path:
    """Writes `state` to `my_state.npz` under `path`.

    Args:
      state: The object whose `a` and `b` arrays are saved.
      path: Directory that receives the npz file.

    Returns:
      The path of the written file (a path object, not a `str`).
    """
    npz_path = path / 'my_state.npz'
    np.savez(npz_path, a=state.a, b=state.b)
    return npz_path
async def _from_state(path: epath.Path) -> MyState:
    """Loads the `MyState` stored in `my_state.npz` under `path`.

    Args:
      path: Directory containing the npz file.

    Returns:
      A `MyState` built from the archive's `a` and `b` arrays.
    """
    # `np.load` on an npz archive returns a lazily-read NpzFile that keeps the
    # file open; the context manager closes it once both arrays are read.
    with np.load(path / 'my_state.npz') as data:
        return MyState(a=data['a'], b=data['b'])
# Associate MyState with its handler so that PyTreeCheckpointHandler uses it
# whenever a MyState leaf appears in a PyTree, then confirm the registration.
ocp.type_handlers.register_type_handler(MyState, MyStateHandler(), override=True)
assert ocp.type_handlers.has_type_handler(MyState)
Here is MyStateHandler in action:
path = epath.Path('/tmp/my_checkpoints/')
# Start from an empty directory: Checkpointer.save fails when the target path
# already exists, unless `force=True` is passed.
if path.exists():
    path.rmtree()
path.mkdir()

# A PyTree mixing a plain dict of arrays with a custom `MyState` leaf.
my_tree = {
    'state': {'a': np.array([1, 2, 3]), 'b': np.array([4, 5, 6])},
    'my_state': MyState(a=np.array([10, 20, 30]), b=np.array([40, 50, 60])),
}

checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
checkpointer.save(path / 'my_tree', my_tree)
# List what the save produced at each level of the checkpoint directory.
!echo "Files in path:" $(ls /tmp/my_checkpoints)
!echo "Files in 'my_tree':" $(ls /tmp/my_checkpoints/my_tree)
!echo "Files in 'my_tree/my_state':" $(ls /tmp/my_checkpoints/my_tree/my_state)
# Restore the saved tree; as the last expression of the cell, the result is
# displayed by the notebook.
checkpointer.restore(path / 'my_tree')
The metadata() method is used for inspecting existing checkpoints and is generally implemented to be less costly than a full restore. Some example use cases are determining whether the restored values can fit in the available memory, getting the checkpointed PyTree structure to extract specific subtrees, or validating whether the shapes and dtypes of the values match with your model data.
In the previous example, MyStateHandler returned the default Metadata() object since the TypeHandler interface requires it. However, we recommend completing this implementation especially if the custom type targets general users.
# 'my_state' returns a default Metadata object, because MyStateHandler.metadata
# only fills in the name and directory of each leaf.
checkpointer.metadata(path / 'my_tree')
Example implementation of MyStateHandler.metadata:
# Define a metadata class.
class MyStateMetadata(Metadata):
    """Metadata of a saved MyState: the base fields plus both array shapes.

    Args:
      a_shape: Shape of the saved `a` array.
      b_shape: Shape of the saved `b` array.
      **kwargs: Forwarded unchanged to the `Metadata` base class.
    """

    def __init__(
        self,
        # A shape is a sequence of ints; `np.shape` is a function, not a type.
        a_shape: Sequence[int],
        b_shape: Sequence[int],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.a_shape = a_shape
        self.b_shape = b_shape
class MyStateHandlerWithMetdata(MyStateHandler):
    """MyStateHandler whose `metadata()` reports the saved array shapes."""

    async def metadata(
        self, infos: Sequence[ParamInfo]
    ) -> Sequence[MyStateMetadata]:
        """Returns one `MyStateMetadata` per entry of `infos`, in order.

        Args:
          infos: One `ParamInfo` per saved `MyState` leaf.
        """
        # `_read_metadata` is a coroutine function, so its coroutines are
        # gathered directly; submitting the call to the thread pool would only
        # create the coroutine object there without doing any of the work.
        return await asyncio.gather(*(_read_metadata(info) for info in infos))
async def _read_metadata(info: ParamInfo) -> MyStateMetadata:
    """Builds the metadata of the `MyState` saved under `info.path`.

    Args:
      info: Describes the saved leaf; its `name` and `path` are used.

    Returns:
      A `MyStateMetadata` carrying the shapes of the saved `a` and `b` arrays.
    """
    # This function reads the entire state, but can be more optimally defined
    # by reading the header from the npz file. Another option is collectively
    # gathering all of the metadata info during serialization, and writing it to
    # a file. Since metadata is generally pretty small, it's better to write
    # to a single file rather than one for each value.
    result = await _from_state(info.path)
    return MyStateMetadata(
        a_shape=result.a.shape,
        b_shape=result.b.shape,
        # Use the leaf's own name, as MyStateHandler.metadata does, instead of
        # a hard-coded key that is only right for the example tree.
        name=info.name,
        directory=info.path,
    )
# Replace the handler registered for MyState with the metadata-aware variant.
ocp.type_handlers.register_type_handler(MyState, MyStateHandlerWithMetdata(), override=True)
Now check the metadata, the PyTree should now contain MyStateMetadata.
# Build a fresh checkpointer after the re-registration and display the
# checkpoint metadata (the last expression of the cell is shown by the notebook).
checkpointer = ocp.PyTreeCheckpointer()
checkpointer.metadata(path / 'my_tree')
In this example, we didn't need to re-save the checkpoint using the newly registered MyStateHandlerWithMetdata TypeHandler, because the class doesn't write new files into the checkpoint.
If your state is not stored within a PyTree, or if you'd like to customize more aspects of checkpointing, implement CheckpointHandler. CheckpointHandlers operate on the entire object so you have a lot of flexibility on how to save and restore the object.
As of orbax-checkpoint-0.5.0, CheckpointHandler API has changed. This page shows a side-by-side comparison of the old and new APIs.
The legacy APIs are deprecated. Please ensure you are using the new style.
Example
Serializing the same dataclass used in the TypeHandler example:
# NOTE: this re-declares `MyState`, creating a new class object distinct from
# the one defined in the TypeHandler example.
@dataclass
class MyState:
    """Model state made of two NumPy arrays."""

    # `np.ndarray` is the array type; `np.array` is only the factory function.
    a: np.ndarray
    b: np.ndarray


# Sample state used by the CheckpointHandler examples.
state = MyState(a=np.array([1.0, 1.5]), b=np.array([3, 4, 5]))
import glob
import json
class LegacyMyStateCheckpointHandler(ocp.CheckpointHandler):
    """Legacy-style CheckpointHandler that saves a MyState as npz or JSON."""

    def save(
        self,
        directory: epath.Path,
        item: MyState,
        # You can define any argument here:
        use_npz=True,
        **kwargs,
    ):
        """Writes `item` into `directory`.

        Args:
          directory: Destination directory.
          item: The state to save.
          use_npz: Write `my_state.npz` when true, `my_state.json` otherwise.
          **kwargs: Accepted and ignored.
        """
        if use_npz:
            np.savez(directory / 'my_state.npz', a=item.a, b=item.b)
        else:
            with open(os.path.join(directory, 'my_state.json'), 'w') as f:
                # Serialize the `item` argument, not a module-level variable.
                f.write(json.dumps(dict(a=item.a.tolist(), b=item.b.tolist())))

    def restore(
        self,
        directory: epath.Path,
        item: Optional[Any] = None,
        # You can define any argument here as well.
        restore_as_dict=False,
        **kwargs,
    ) -> Any:
        """Reads the state saved in `directory`.

        Args:
          directory: Directory previously passed to `save`.
          item: Accepted and ignored.
          restore_as_dict: Return a dict with keys `a` and `b` instead of a
            `MyState`.
          **kwargs: Accepted and ignored.

        Returns:
          The restored `MyState`, or a dict when `restore_as_dict` is true.
        """
        npz_file = directory / 'my_state.npz'
        # Pick the format by checking which file exists. Comparing a globbed
        # full path against the bare file name never matches, which made the
        # npz branch unreachable.
        if npz_file.exists():
            # The context manager closes the archive once both arrays are read.
            with np.load(npz_file) as npz:
                data = {'a': npz['a'], 'b': npz['b']}
        else:
            with open(os.path.join(directory, 'my_state.json'), 'r') as f:
                data = json.load(f)
            data['a'] = np.array(data['a'])
            data['b'] = np.array(data['b'])
        if restore_as_dict:
            return dict(a=data['a'], b=data['b'])
        return MyState(a=data['a'], b=data['b'])

    def metadata(self, directory: epath.Path) -> Optional[Any]:
        """Returns metadata about the saved item."""
        # In this example, the State is restored entirely, but this can be
        # optimized. For example, by writing a `metadata` file in `self.save()`,
        # and reading the file in this method.
        result = self.restore(directory)
        return MyStateMetadata(
            a_shape=result.a.shape,
            b_shape=result.b.shape,
            name='my_state',
            directory=directory / 'my_state',
        )
import glob
import json
class MyStateCheckpointHandler(ocp.CheckpointHandler):
    """New-style CheckpointHandler that saves a MyState as npz or JSON."""

    def save(
        self,
        directory: epath.Path,
        args: 'MyStateSave',
    ):
        """Writes `args.item` into `directory`.

        Args:
          directory: Destination directory.
          args: Carries the state (`item`) and the format switch (`use_npz`):
            `my_state.npz` when true, `my_state.json` otherwise.
        """
        if args.use_npz:
            np.savez(directory / 'my_state.npz', a=args.item.a, b=args.item.b)
        else:
            with open(os.path.join(directory, 'my_state.json'), 'w') as f:
                f.write(
                    json.dumps(dict(a=args.item.a.tolist(), b=args.item.b.tolist()))
                )

    def restore(
        self,
        directory: epath.Path,
        args: 'MyStateRestore',
    ) -> Any:
        """Reads the state saved in `directory`.

        Args:
          directory: Directory previously passed to `save`.
          args: `args.restore_as_dict` selects a dict result instead of a
            `MyState`.

        Returns:
          The restored `MyState`, or a dict with keys `a` and `b`.
        """
        npz_file = directory / 'my_state.npz'
        # Pick the format by checking which file exists. Comparing a globbed
        # full path against the bare file name never matches, which made the
        # npz branch unreachable.
        if npz_file.exists():
            # The context manager closes the archive once both arrays are read.
            with np.load(npz_file) as npz:
                data = {'a': npz['a'], 'b': npz['b']}
        else:
            with open(os.path.join(directory, 'my_state.json'), 'r') as f:
                data = json.load(f)
            data['a'] = np.array(data['a'])
            data['b'] = np.array(data['b'])
        if args.restore_as_dict:
            return dict(a=data['a'], b=data['b'])
        return MyState(a=data['a'], b=data['b'])

    def metadata(self, directory: epath.Path) -> Optional[Any]:
        """Returns metadata about the saved item."""
        # In this example, the State is restored entirely, but this can be
        # optimized. For example, by writing a `metadata` file in `self.save()`,
        # and reading the file in this method.
        result = self.restore(directory, args=MyStateRestore())
        return MyStateMetadata(
            a_shape=result.a.shape,
            b_shape=result.b.shape,
            name='my_state',
            directory=directory / 'my_state',
        )
# Registered as the save-args type of MyStateCheckpointHandler.
@ocp.args.register_with_handler(MyStateCheckpointHandler, for_save=True)
@dataclass
class MyStateSave(ocp.args.CheckpointArgs):
    """Arguments read by `MyStateCheckpointHandler.save`."""

    # The state to write.
    item: MyState
    # Write `my_state.npz` when true, `my_state.json` otherwise.
    use_npz: bool = True
# Registered as the restore-args type of MyStateCheckpointHandler.
@ocp.args.register_with_handler(MyStateCheckpointHandler, for_restore=True)
@dataclass
class MyStateRestore(ocp.args.CheckpointArgs):
    """Arguments read by `MyStateCheckpointHandler.restore`."""

    # Return a dict with keys `a` and `b` instead of a `MyState` when true.
    restore_as_dict: bool = False
These classes can be passed to create a new Checkpointer, which can be used to save or restore a new checkpoint.
legacy_checkpointer = ocp.Checkpointer(LegacyMyStateCheckpointHandler())

# Recreate the example directory so the save starts from an empty location.
legacy_path2 = epath.Path('/tmp/legacy-checkpoint-handler-example/')
if legacy_path2.exists():
    legacy_path2.rmtree()
legacy_path2.mkdir()

# Legacy style: the item and handler options are passed directly to save().
legacy_checkpointer.save(legacy_path2 / 'state', state, use_npz=False)
# List the files produced by the legacy save.
!echo "Files in legacy checkpoint path:" $(ls /tmp/legacy-checkpoint-handler-example/)
!echo "Files in legacy 'state' directory:" $(ls /tmp/legacy-checkpoint-handler-example/state)
# Restore as a MyState, restore as a dict, then read the metadata.
print('restored state: ', legacy_checkpointer.restore(legacy_path2 / 'state'))
print('restored state as dict: ', legacy_checkpointer.restore(legacy_path2 / 'state', restore_as_dict=True))
print('metadata:', legacy_checkpointer.metadata(legacy_path2 / 'state'))
checkpointer = ocp.Checkpointer(MyStateCheckpointHandler())

# Recreate the example directory so the save starts from an empty location.
path2 = epath.Path('/tmp/checkpoint-handler-example/')
if path2.exists():
    path2.rmtree()
path2.mkdir()

# New style: the item and handler options travel inside a MyStateSave object.
checkpointer.save(path2 / 'state', args=MyStateSave(item=state, use_npz=False))
# List the files produced by the new-style save.
!echo "Files in checkpoint path:" $(ls /tmp/checkpoint-handler-example/)
!echo "Files in 'state' directory:" $(ls /tmp/checkpoint-handler-example/state)
# Restore as a MyState, restore as a dict, then read the metadata.
print('restored state: ', checkpointer.restore(path2 / 'state', args=MyStateRestore()))
print('restored state as dict: ', checkpointer.restore(path2 / 'state', args=MyStateRestore(restore_as_dict=True)))
print('metadata:',checkpointer.metadata(path2 / 'state'))
Asynchronous checkpointing allows training to proceed during the I/O, which prevents expensive computational resources from stalling during the CPU writes. When possible, we highly recommend implementing async handlers.
Async saving can be implemented by copying data to the corresponding worker CPU (if necessary), then parallelizing the writing tasks (e.g. by using the await keyword).
TypeHandler deserialization should be defined using async to allow multiple objects to be deserialized at a time.
The AsyncCheckpointHandler interface adds a new async_save abstract method, and should be used with AsyncCheckpointer to write checkpoints asynchronously.
Note that in the new style, AsyncCheckpointHandler's save() and async_save() methods work on args instead of the legacy item etc arguments. Also, the args type needs to be registered against the AsyncCheckpointHandler concrete class.
Example
class MyStateAsyncCheckpointHandler(ocp.AsyncCheckpointHandler, MyStateCheckpointHandler):
    """Async variant of MyStateCheckpointHandler that saves on a worker thread."""

    def __init__(self):
        # One background thread runs the blocking save.
        self._executor = futures.ThreadPoolExecutor(max_workers=1)

    def save(self, directory: epath.Path, args: MyStateSave):
        """Blocking save: sleeps briefly, then delegates to the parent class."""
        time.sleep(0.5)  # Artificially inflate the time spent in this method.
        super().save(directory, args)

    async def async_save(self, directory: epath.Path, args: MyStateSave):
        """Submits `save` to the executor and returns its future in a list."""
        pending_save = self._executor.submit(self.save, directory, args)
        return [pending_save]

    def close(self):
        """Shuts down the worker thread pool."""
        self._executor.shutdown()
# Point MyStateSave and MyStateRestore at MyStateAsyncCheckpointHandler.
# NOTE: Doing so overwrites their earlier registration with
# MyStateCheckpointHandler. This is only for illustrating the example and
# should be avoided in real world systems.
ocp.args.register_with_handler(
    MyStateAsyncCheckpointHandler, for_save=True
)(MyStateSave)
ocp.args.register_with_handler(
    MyStateAsyncCheckpointHandler, for_restore=True
)(MyStateRestore)
async_checkpointer = ocp.AsyncCheckpointer(MyStateAsyncCheckpointHandler())

# Recreate the example directory so the save starts from an empty location.
path3 = epath.Path('/tmp/checkpoint-handler-async/')
if path3.exists():
    path3.rmtree()
path3.mkdir()

# Start the save with the default MyStateSave options.
async_checkpointer.save(path3 / 'async-state', args=MyStateSave(item=state))
# List the directory right after save() was called, before waiting for the
# write to finish.
!echo "directory contents: "; ls /tmp/checkpoint-handler-async/
After the write is complete, the tmp folder is renamed to just async-state.
# Block until the background save has finished, then close the checkpointer.
async_checkpointer.wait_until_finished()
async_checkpointer.close()
# List the directory and the finished checkpoint's contents.
!ls /tmp/checkpoint-handler-async/
!ls /tmp/checkpoint-handler-async/async-state