# Minimal checkpointed training loop built on Orbax's CheckpointManager.
#
# The script resumes from the most recent checkpoint under /tmp/mydir/ when
# one exists, runs the remaining steps, asks the manager to save after every
# step, and shuts down cleanly when the manager reports a preemption.
import sys

import orbax.checkpoint as ocp

# The manager is configured with save_interval_steps=4; whether a given
# `save` call actually writes a checkpoint is decided by the manager.
mngr = ocp.CheckpointManager(
    '/tmp/mydir/',
    ocp.PyTreeCheckpointer(),
    ocp.CheckpointManagerOptions(save_interval_steps=4),
)


def train_step(s):
    """Return the training state unchanged.

    Placeholder for a real update function: takes the current state and
    returns the state to use for the next step.
    """
    return s


state = {'a': 1, 'b': 2}
start_step = 0
num_steps = 12

# Query the manager once and reuse the answer for both the restore call and
# the resume point.
latest_step = mngr.latest_step()
if latest_step is not None:
    state = mngr.restore(latest_step)
    # The checkpoint for step N is written *after* train_step ran for N, so
    # the restored state already includes that step; continue from N + 1.
    start_step = latest_step + 1

# Single training loop. Running a second identical loop over the same range
# would apply every training step twice.
for step in range(start_step, num_steps):
    state = train_step(state)
    mngr.save(step, state)
    if mngr.reached_preemption(step):
        # Let any in-flight save finish before the process goes away.
        mngr.wait_until_finished()
        # sys.exit rather than the site-provided `exit` helper, which is
        # meant for interactive sessions and is absent under `python -S`.
        sys.exit()