import numpy as np
import orbax.checkpoint as ocp
from etils import epath


def train_step(step, state):
  """Stand-in for one training step: advances the counter, returns state as-is."""
  # A real step would update the state values here.
  return step + 1, state


# Toy pytree standing in for real training state: one layer holding an 8x8
# float32 kernel and a length-8 float32 bias, all ones.
train_state = {
    'layer0': {
        'kernel': np.ones(shape=(8, 8), dtype=np.float32),
        'bias': np.ones(shape=(8,), dtype=np.float32),
    }
}

# --- Synchronous save -------------------------------------------------------
# Not the preferred approach; the asynchronous variants below are recommended
# instead.
path = epath.Path('/tmp/sync_checkpoint')
ckptr = ocp.Checkpointer(ocp.StandardCheckpointHandler())
ckptr.save(path, args=ocp.args.StandardSave(train_state))

# --- Asynchronous save with AsyncCheckpointer --------------------------------
path = epath.Path('/tmp/async_checkpoint')
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(path, args=ocp.args.StandardSave(train_state))
# ... other work can be done here before waiting on the save ...
ckptr.wait_until_finished()

# --- Periodic saves with CheckpointManager -----------------------------------
path = epath.Path('/tmp/async_checkpoint_manager')
ckpt_mngr = ocp.CheckpointManager(path)

num_steps = 5
step = 0
while step < num_steps:
  # Save the state for the current step, then advance to the next one.
  ckpt_mngr.save(step, args=ocp.args.StandardSave(train_state))
  step, train_state = train_step(step, train_state)

# Wait for any outstanding saves before querying the manager.
ckpt_mngr.wait_until_finished()

# Query the steps known to the manager (the result is not stored; in a
# notebook it would be displayed as the cell output).
ckpt_mngr.all_steps()

# Options object with asynchronous checkpointing switched off. It is only
# constructed here for illustration and is not passed to any manager.
ocp.CheckpointManagerOptions(enable_async_checkpointing=False)