# Setup
import orbax.checkpoint as ocp
import numpy as np

# NOTE: these examples were exported from a notebook, where the value of the
# last expression in a cell is displayed automatically. In a plain script that
# value would be discarded, so each result is printed explicitly.

# Example: Migrate original tree into the new_tree, which has the same
# nested structure but different keys.
original_tree = {
    'a': 1,
    'b': 2,
}
# Each key of `transformations` names a key in new_tree; `original_key` says
# which key of original_tree supplies its value.
transformations = {
    'a2': ocp.Transform(original_key='a'),
    'b2': ocp.Transform(original_key='b'),
}
# `...` marks the leaves to be filled in from original_tree.
new_tree = {
    'a2': ...,
    'b2': ...,
}
print(ocp.apply_transformations(original_tree, transformations, new_tree))

# Example 2: Renaming with regex
original_tree = {
    'a1': 1,
    'b5': 2,
}
# The transformation key is a regex matched against the new_tree keys
# ('a_1', 'b_5'); `original_key` refers back to the captured groups, so
# 'a_1' is read from 'a1' and 'b_5' from 'b5'.
transformations = {
    r'([a-z])_([0-9])': ocp.Transform(original_key=r'\1\2'),
}
new_tree = {
    'a_1': ...,
    'b_5': ...,
}
print(ocp.apply_transformations(original_tree, transformations, new_tree))

# Example 3: Renaming nested trees
original_tree = {
    'a': 1,
    'dense_1': {'kernel': 2, 'bias': 3},
    'dense_2': {'kernel': 4, 'bias': 5},
}
# Nested keys can be represented by a single string by separating each level
# with '/'.
transformations = {
    r'([a-z]+)_NEW': ocp.Transform(original_key=r'\1'),
    r'([a-z]+)_([0-9])_NEW/([a-z]+)_1': ocp.Transform(
        original_key=r'\1_\2/\3'),
}
# This is equivalent to the nested form below, which intentionally replaces
# the flat definition above. The nested regex keys are joined with '/', so the
# inner `original_key` can still refer to the outer key's groups (\1, \2).
transformations = {
    r'([a-z]+)_NEW': ocp.Transform(original_key=r'\1'),
    r'([a-z]+)_([0-9])_NEW': {
        r'([a-z]+)_1': ocp.Transform(original_key=r'\1_\2/\3'),
    },
}
new_tree = {
    'a_NEW': ...,
    'dense_1_NEW': {'kernel_1': ..., 'bias_1': ...},
    'dense_2_NEW': {'kernel_1': ..., 'bias_1': ...},
}
print(ocp.apply_transformations(original_tree, transformations, new_tree))

# Example: Transform the values in a tree.
original_tree = {
    'a': 1,
    'b': 2,
}
# `value_fn` is applied to the value taken from the original key; it can be
# combined with `original_key` to rename and transform at the same time.
transformations = {
    'a': ocp.Transform(value_fn=lambda v: v * 2),
    'b2': ocp.Transform(value_fn=lambda v: v * 3, original_key='b'),
}
new_tree = {
    'a': ...,
    'b2': ...,  # Output different key
}
print(ocp.apply_transformations(original_tree, transformations, new_tree))

# Example 2: Transform values in a tree with regex (multiply all 'a' keys by 2
# and all 'b' keys by 3).
original_tree = {
    'a1': 1,
    'a2': 2,
    'b': 3,
}
# The new keys carry the operation as a literal suffix ('*2' / '*3'); the
# escaped '\*' matches that literal asterisk, and the optional digit group is
# reused to locate the original key.
transformations = {
    r'a([0-9]?)\*2': ocp.Transform(
        value_fn=lambda v: v * 2, original_key=r'a\1'),
    r'b([0-9]?)\*3': ocp.Transform(
        value_fn=lambda v: v * 3, original_key=r'b\1'),
}
new_tree = {
    'a1*2': ...,
    'a2*2': ...,
    'b*3': ...,
}
print(ocp.apply_transformations(original_tree, transformations, new_tree))

# Example: Flatten nested structure
original_tree = {
    'a': 1,
    'dense_1': {'kernel': 2, 'bias': 3},
    'dense_2': {'kernel': 4, 'bias': 5},
}
# Flat new keys such as 'dense_1_kernel' are mapped back to the nested original
# path 'dense_1/kernel' ('/' separates nesting levels).
transformations = {
    r'([a-z]+)': ocp.Transform(original_key=r'\1'),
    r'([a-z]+)_([0-9])_([a-z]+)': ocp.Transform(original_key=r'\1_\2/\3'),
}
new_tree = {
    'a': ...,
    'dense_1_kernel': ...,
    'dense_1_bias': ...,
    'dense_2_kernel': ...,
    'dense_2_bias': ...,
}
print(ocp.apply_transformations(original_tree, transformations, new_tree))

# Example: various multi_value_fn usage
original_tree = {
    'a': np.array([1, 2, 3, 4]),
    'b': {'c': np.array([5, 6, 7, 8])},
}
# The lambdas below ignore their first argument and index the second one with
# original_tree's keys, i.e. a multi_value_fn can combine several original
# values into one new value.
transformations = {
    'a': ocp.Transform(multi_value_fn=lambda _, kv: kv['a'][-1]),
    'b': {
        'c': ocp.Transform(
            multi_value_fn=lambda _, kv: kv['a'] + kv['b']['c']),
    },
}
new_tree = {
    'a': ...,
    'b': {'c': ...},
}
print(ocp.apply_transformations(original_tree, transformations, new_tree))

# Example: Average the weights
original_tree = {
    'a': {'a_1': 1, 'a_2': 2},
    'b': {'b_1': 3, 'b_2': 4, 'b_3': 5},
}
# Here the first lambda argument is used as the key being produced ('a' or
# 'b'), and the matching original subtree is averaged into a single leaf.
transformations = {
    r'([a-z]+)': ocp.Transform(
        multi_value_fn=lambda k, kv: sum(kv[k].values()) / len(kv[k])),
}
new_tree = {
    'a': ...,
    'b': ...,
}
print(ocp.apply_transformations(original_tree, transformations, new_tree))

# Example: Restore a checkpoint into a smaller dataclass, optionally remapping
# keys while restoring.
import flax.struct


@flax.struct.dataclass
class Small:
    key1: int


@flax.struct.dataclass
class Big:
    key1: int
    key2: int


to_save = Big(key1=10, key2=100)
to_restore = Small(key1=0)

path = '/tmp/my-checkpoints/'
ckptr = ocp.PyTreeCheckpointer()
# Use the same `args=` API as the restores below. `force=True` allows
# overwriting an existing checkpoint directory, so the example can be re-run.
ckptr.save(path, args=ocp.args.PyTreeSave(to_save), force=True)

# No transforms: `key1` is restored from the saved `key1`.
restored1 = ckptr.restore(
    path,
    args=ocp.args.PyTreeRestore(
        to_restore,
        restore_args=ocp.checkpoint_utils.construct_restore_args(to_restore),
        transforms={},
    ),
)
# With a transform: `key1` in the restored tree is read from the saved `key2`.
restored2 = ckptr.restore(
    path,
    args=ocp.args.PyTreeRestore(
        to_restore,
        restore_args=ocp.checkpoint_utils.construct_restore_args(to_restore),
        transforms={
            r'(.*)key1(.*)': ocp.Transform(original_key=r'\1key2\2'),
        },
    ),
)
print(restored1)
print(restored2)

# Example: Flatten a nested structure using named regex groups.
original_tree = {
    'dense_1': {'kernel': 2, 'bias': 3},
}
# Named groups `(?P<label>...)` are referenced in the replacement with
# `\g<label>`. Both halves need the group names: `(?P[a-z]+)` is not valid regex
# syntax, and `\g` without `<label>` is not a valid group reference.
transformations = {
    r'(?P<name>[a-z]+)_(?P<num>[0-9])_(?P<param>[a-z]+)': ocp.Transform(
        original_key=r'\g<name>_\g<num>/\g<param>'),
}
new_tree = {
    'dense_1_kernel': ...,
    'dense_1_bias': ...,
}
print(ocp.apply_transformations(original_tree, transformations, new_tree))