import asyncio from concurrent import futures from dataclasses import dataclass import functools import os import time from typing import Any, List, Optional, Sequence from etils import epath import numpy as np import orbax.checkpoint as ocp ParamInfo = ocp.pytree_checkpoint_handler.ParamInfo Metadata = ocp.metadata.Metadata @dataclass class MyState: a: np.array b: np.array # Make sure to only run this cell once, otherwise a new `MyState` dataclass will # be created which could mess up Python issubclass/isinstance checks. class MyStateHandler(ocp.pytree_checkpoint_handler.TypeHandler): """Serializes MyState to the numpy npz format.""" def __init__(self): self._executor = futures.ThreadPoolExecutor(max_workers=1) def typestr(self) -> str: return 'MyState' async def serialize( self, values: Sequence[MyState], infos: Sequence[ParamInfo], args: Optional[Sequence[ocp.SaveArgs]], ) -> List[futures.Future]: del args # Unused in this example. futures = [] for value, info in zip(values, infos): # make sure the per-key directory is present as OCDBT doesn't create one info.path.mkdir(exist_ok=True) futures.append( self._executor.submit( functools.partial(_write_state, value, info.path) ) ) return futures async def deserialize( self, infos: Sequence[ParamInfo], args: Optional[Sequence[ocp.RestoreArgs]] = None, ) -> MyState: del args # Unused in this example. futures = [] for info in infos: futures.append( await asyncio.get_event_loop().run_in_executor( self._executor, functools.partial(_from_state, info.path) ) ) return await asyncio.gather(*futures) async def metadata(self, infos: Sequence[ParamInfo]) -> Sequence[Metadata]: # This method is explained in a separate section. return [Metadata(name=info.name, directory=info.path) for info in infos] def _write_state(state: MyState, path: epath.Path) -> str: path = path / 'my_state.npz' np.savez(path, a=state.a, b=state.b) return path async def _from_state(path: epath.Path) -> MyState: data = np.load(path / 'my_state.npz') return MyState(a=data['a'], b=data['b']) ocp.type_handlers.register_type_handler( MyState, MyStateHandler(), override=True ) assert ocp.type_handlers.has_type_handler(MyState) my_tree = { 'state': {'a': np.array([1, 2, 3]), 'b': np.array([4, 5, 6])}, 'my_state': MyState(a=np.array([10, 20, 30]), b=np.array([40, 50, 60])), } checkpointer = ocp.Checkpointer( ocp.PyTreeCheckpointHandler() ) path = epath.Path('/tmp/my_checkpoints/') # Clear older checkpoints from directory. # Checkpointer.save will fail if path already exists, unless `force=True` if path.exists(): path.rmtree() path.mkdir() checkpointer.save(path / 'my_tree', my_tree) !echo "Files in path:" $(ls /tmp/my_checkpoints) !echo "Files in 'my_tree':" $(ls /tmp/my_checkpoints/my_tree) !echo "Files in 'my_tree/my_state':" $(ls /tmp/my_checkpoints/my_tree/my_state) checkpointer.restore(path / 'my_tree') # 'my_state' returns a default Metadata object. checkpointer.metadata(path / 'my_tree') # Define a metadata class. @dataclass class MyStateMetadata(Metadata): a_shape: np.shape b_shape: np.shape name: str = 'my_state' class MyStateHandlerWithMetdata(MyStateHandler): async def metadata( self, infos: Sequence[ParamInfo] ) -> ocp.value_metadata.Metadata: metadata = [] for info in infos: metadata.append( await asyncio.get_event_loop().run_in_executor( self._executor, functools.partial(_read_metadata, info) ) ) return await asyncio.gather(*metadata) async def _read_metadata(info: ParamInfo) -> MyStateMetadata: # This function reads the entire state, but can be more optimally defined # by reading the header from the npz file. Another option is collectively # gathering all of the metadata info during serialization, and writing it to # a file. Since metadata is generally pretty small, it's better to write # to a single file rather than one for each value. result = await _from_state(info.path) return MyStateMetadata( a_shape=result.a.shape, b_shape=result.b.shape, directory=info.path, ) ocp.type_handlers.register_type_handler( MyState, MyStateHandlerWithMetdata(), override=True ) checkpointer = ocp.PyTreeCheckpointer() checkpointer.metadata(path / 'my_tree') state = MyState(a=np.array([1.0, 1.5]), b=np.array([3, 4, 5])) import glob import json class LegacyMyStateCheckpointHandler(ocp.CheckpointHandler): def save( self, directory: epath.Path, item: MyState, # You can define any argument here: use_npz=True, **kwargs, ): if use_npz: np.savez(directory / 'my_state.npz', a=item.a, b=item.b) else: with open(os.path.join(directory, 'my_state.json'), 'w') as f: f.write(json.dumps(dict(a=state.a.tolist(), b=state.b.tolist()))) def restore( self, directory: epath.Path, item: Optional[Any] = None, # You can define any argument here as well. restore_as_dict=False, **kwargs, ) -> Any: state_file = glob.glob(os.fspath(directory / '*.*'))[0] if state_file == 'my_state.npz': data = np.load(directory / 'my_state.npz') else: with open(state_file, 'r') as f: data = json.load(f) data['a'] = np.array(data['a']) data['b'] = np.array(data['b']) if restore_as_dict: return dict(a=data['a'], b=data['b']) return MyState(a=data['a'], b=data['b']) def metadata(self, directory: epath.Path) -> Optional[Any]: """Returns metadata about the saved item.""" # In this example, the State is restored entirely, but this can be # optimized. For example, but writing a `metadata` file in `self.save()`, # and reading the file in this method. result = self.restore(directory) return MyStateMetadata( a_shape=result.a.shape, b_shape=result.b.shape, directory=directory / 'my_state', ) import glob import json class MyStateCheckpointHandler(ocp.CheckpointHandler): def save( self, directory: epath.Path, args: 'MyStateSave', ): if args.use_npz: np.savez(directory / 'my_state.npz', a=args.item.a, b=args.item.b) else: with open(os.path.join(directory, 'my_state.json'), 'w') as f: f.write( json.dumps(dict(a=args.item.a.tolist(), b=args.item.b.tolist())) ) def restore( self, directory: epath.Path, args: 'MyStateRestore', ) -> Any: state_file = glob.glob(os.fspath(directory / '*.*'))[0] if state_file == 'my_state.npz': data = np.load(directory / 'my_state.npz') else: with open(state_file, 'r') as f: data = json.load(f) data['a'] = np.array(data['a']) data['b'] = np.array(data['b']) if args.restore_as_dict: return dict(a=data['a'], b=data['b']) return MyState(a=data['a'], b=data['b']) def metadata(self, directory: epath.Path) -> Optional[Any]: """Returns metadata about the saved item.""" # In this example, the State is restored entirely, but this can be # optimized. For example, but writing a `metadata` file in `self.save()`, # and reading the file in this method. result = self.restore(directory, args=MyStateRestore()) return MyStateMetadata( a_shape=result.a.shape, b_shape=result.b.shape, directory=directory / 'my_state', ) @ocp.args.register_with_handler(MyStateCheckpointHandler, for_save=True) @dataclass class MyStateSave(ocp.args.CheckpointArgs): item: MyState use_npz: bool = True @ocp.args.register_with_handler(MyStateCheckpointHandler, for_restore=True) @dataclass class MyStateRestore(ocp.args.CheckpointArgs): restore_as_dict: bool = False legacy_path2 = epath.Path('/tmp/legacy-checkpoint-handler-example/') legacy_checkpointer = ocp.Checkpointer(LegacyMyStateCheckpointHandler()) if legacy_path2.exists(): legacy_path2.rmtree() legacy_path2.mkdir() legacy_checkpointer.save(legacy_path2 / 'state', state, use_npz=False) !echo "Files in legacy checkpoint path:" $(ls /tmp/legacy-checkpoint-handler-example/) !echo "Files in legacy 'state' directory:" $(ls /tmp/legacy-checkpoint-handler-example/state) print('restored state: ', legacy_checkpointer.restore(legacy_path2 / 'state')) print('restored state as dict: ', legacy_checkpointer.restore(legacy_path2 / 'state', restore_as_dict=True)) print('metadata:', legacy_checkpointer.metadata(legacy_path2 / 'state')) path2 = epath.Path('/tmp/checkpoint-handler-example/') checkpointer = ocp.Checkpointer(MyStateCheckpointHandler()) if path2.exists(): path2.rmtree() path2.mkdir() checkpointer.save(path2 / 'state', args=MyStateSave(item=state, use_npz=False)) !echo "Files in checkpoint path:" $(ls /tmp/checkpoint-handler-example/) !echo "Files in 'state' directory:" $(ls /tmp/checkpoint-handler-example/state) print('restored state: ', checkpointer.restore(path2 / 'state', args=MyStateRestore())) print('restored state as dict: ', checkpointer.restore(path2 / 'state', args=MyStateRestore(restore_as_dict=True))) print('metadata:',checkpointer.metadata(path2 / 'state')) class MyStateAsyncCheckpointHandler(ocp.AsyncCheckpointHandler, MyStateCheckpointHandler): def __init__(self): self._executor = futures.ThreadPoolExecutor(max_workers=1) def save(self, directory: epath.Path, args: MyStateSave): time.sleep(.5) # Artificially inflate the time spent in this method. super().save(directory, args) async def async_save(self, directory: epath.Path, args: MyStateSave): return [self._executor.submit(functools.partial( self.save, directory, args))] def close(self): self._executor.shutdown() # Register MyStateAsyncCheckpointHandler for MyStateSave and MyStateRestore. # NOTE: This registration will overwrite the previous one with MyStateCheckpointHandler. # It is just for illustrating this example and should be avoided in real world systems. ocp.args.register_with_handler(MyStateAsyncCheckpointHandler, for_save=True)(MyStateSave) ocp.args.register_with_handler(MyStateAsyncCheckpointHandler, for_restore=True)(MyStateRestore) path3 = epath.Path('/tmp/checkpoint-handler-async/') if path3.exists(): path3.rmtree() path3.mkdir() async_checkpointer = ocp.AsyncCheckpointer(MyStateAsyncCheckpointHandler()) async_checkpointer.save(path3 / 'async-state', args=MyStateSave(item=state)) !echo "directory contents: "; ls /tmp/checkpoint-handler-async/ async_checkpointer.wait_until_finished() async_checkpointer.close() !ls /tmp/checkpoint-handler-async/ !ls /tmp/checkpoint-handler-async/async-state