Orbax Export is a library for exporting JAX models to the TensorFlow SavedModel format.
Orbax Export provides three classes:
JaxModule
wraps a JAX function and its parameters into an exportable and callable
closure.
ServingConfig
defines a serving configuration for a JaxModule, including
[a signature key and an input signature][1], and optionally pre- and
post-processing functions and extra [TrackableResources][2].
ExportManager
builds the actual [serving signatures][1] based on a JaxModule and a list
of ServingConfigs, and saves them to the SavedModel format. It targets CPU.
Users can subclass ExportManager to create their own export manager
for other hardware.
# Import Orbax Export classes.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import numpy as np
import jax
import jax.numpy as jnp
import tensorflow as tf
# Prepare the parameters and model function to export.
example1_params = { 'a': np.array(5.0), 'b': np.array(1.1), 'c': np.array(0.55)} # A pytree of the JAX model parameters.
# model f(x) = a * sin(x) + b * x + c, here (a, b, c) are model parameters
def example1_model_fn(params, inputs):  # The JAX model function to export.
  a, b, c = params['a'], params['b'], params['c']
  return a * jnp.sin(inputs) + b * inputs + c
def example1_preprocess(inputs):  # Optional: preprocessor in TF.
  norm_inputs = tf.nest.map_structure(lambda x: x / tf.math.reduce_max(x), inputs)
  return norm_inputs
def example1_postprocess(model_fn_outputs):  # Optional: post-processor in TF.
  return {'outputs': model_fn_outputs}
inputs = tf.random.normal([16], dtype=tf.float32)
model_outputs = example1_postprocess(example1_model_fn(example1_params, np.array(example1_preprocess(inputs))))
print("model output: ", model_outputs)
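The expression above chains the preprocessor, the model function, and the post-processor by hand. As a rough mental model of what a serving function does with these three pieces, here is a minimal sketch in plain Python (standard library only, no JAX or TensorFlow). The helper `make_serving_fn` is hypothetical and is not part of Orbax Export; it only illustrates binding the parameters into a closure and composing the optional processing steps.

```python
import math
from functools import partial

# Plain-Python stand-ins for the example above (assumption: lists of floats
# replace tensors so the sketch runs without JAX or TensorFlow).
params = {"a": 5.0, "b": 1.1, "c": 0.55}

def model_fn(params, inputs):
    # f(x) = a * sin(x) + b * x + c, applied element-wise.
    a, b, c = params["a"], params["b"], params["c"]
    return [a * math.sin(x) + b * x + c for x in inputs]

def preprocess(inputs):
    # Normalize by the maximum value, like example1_preprocess.
    peak = max(inputs)
    return [x / peak for x in inputs]

def postprocess(outputs):
    # Wrap the raw outputs in a dictionary, like example1_postprocess.
    return {"outputs": outputs}

def make_serving_fn(params, model_fn, preprocess=None, postprocess=None):
    """Hypothetical helper: bind params and compose the optional steps."""
    bound_model = partial(model_fn, params)  # closure over the parameters

    def serving_fn(inputs):
        if preprocess is not None:
            inputs = preprocess(inputs)
        outputs = bound_model(inputs)
        if postprocess is not None:
            outputs = postprocess(outputs)
        return outputs

    return serving_fn

serving_fn = make_serving_fn(params, model_fn, preprocess, postprocess)
result = serving_fn([1.0, 2.0, 4.0])
print(result)
```

The caller only ever sees `serving_fn(inputs)`; the parameters and both processing steps travel with the closure, which is the property that makes the wrapped function convenient to export.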
Exporting a JAX model to a CPU SavedModel
import tensorflow as tf
# Construct a JaxModule where JAX->TF conversion happens.
jax_module = JaxModule(example1_params, example1_model_fn)
# Export the JaxModule along with one or more serving configs.
export_mgr = ExportManager(
    jax_module,
    [
        ServingConfig(
            'serving_default',
            input_signature=[tf.TensorSpec(shape=[16], dtype=tf.float32)],
            tf_preprocessor=example1_preprocess,
            tf_postprocessor=example1_postprocess,
        ),
    ],
)
output_dir='/tmp/example1_output_dir'
export_mgr.save(output_dir)
Load the TF SavedModel back and run it
loaded_model = tf.saved_model.load(output_dir)
loaded_model_outputs = loaded_model(inputs)
print("loaded model output: ", loaded_model_outputs)
np.testing.assert_allclose(model_outputs['outputs'], loaded_model_outputs['outputs'], atol=1e-5, rtol=1e-5)
This error message means that the JAX function model_fn can take only a single argument as its input.
Orbax expects a JAX module to be a Callable that takes
parameters of type PyTree and model inputs of type PyTree. If your JAX function
takes multiple inputs, you must pack them into a single JAX PyTree; otherwise,
you will encounter this error message.
To solve this problem, you can update the ServingConfig.tf_preprocessor
function to pack the inputs into a single JAX PyTree. For example, suppose the model
takes two inputs, x and y. You can define ServingConfig.tf_preprocessor
to pack them into a list [x, y].
example2_params = {} # A pytree of the JAX model parameters.
def example2_model_fn(params, inputs):
  x, y = inputs
  return x + y
def example2_preprocessor(x, y):
  # Put the usual tf_preprocessor code here.
  return [x, y]  # Pack the inputs into a single list for the JAX model function.
jax_module = JaxModule(example2_params, example2_model_fn)
export_mgr = ExportManager(
    jax_module,
    [
        ServingConfig(
            'serving_default',
            input_signature=[tf.TensorSpec([16]), tf.TensorSpec([16])],
            tf_preprocessor=example2_preprocessor,
        )
    ],
)
output_dir='/tmp/example2_output_dir'
export_mgr.save(output_dir)
loaded_model = tf.saved_model.load(output_dir)
loaded_model_outputs = loaded_model(tf.random.normal([16]), tf.random.normal([16]))
print("loaded model output: ", loaded_model_outputs)
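The example above packs the inputs on the preprocessor side. If you already have a model function with several positional inputs, a complementary option is to adapt the function itself to the `(params, inputs)` convention. The sketch below uses plain Python with no JAX or TensorFlow; `single_input` and `add_scaled` are hypothetical names introduced here for illustration and are not part of Orbax Export.

```python
def single_input(fn):
    """Adapt fn(params, *args) to the (params, inputs) calling convention.

    `inputs` is expected to be a list or tuple holding the packed arguments.
    """
    def wrapped(params, inputs):
        return fn(params, *inputs)  # unpack the single container
    return wrapped

# A hypothetical model function that takes two separate inputs.
def add_scaled(params, x, y):
    return params["scale"] * (x + y)

# After adaptation, the function takes exactly one `inputs` argument.
model_fn = single_input(add_scaled)
result = model_fn({"scale": 2.0}, [1.0, 3.0])
print(result)
```

The preprocessor still has to return the inputs as one list, as in `example2_preprocessor`; the adapter only saves rewriting the body of an existing multi-argument function.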
Orbax.export.validate is a library for validating a JAX model against its exported TF SavedModel.
Users must finish exporting the JAX model first, either with orbax.export or manually.
Orbax.export.validate provides these classes:
ValidationJob
takes the model and data as input, then outputs the result.
ValidationReport
compares the JAX model and TF SavedModel results, then generates a formatted
report.
ValidationManager
takes a JaxModule as input and wraps the end-to-end validation flow.
Here we use the same example as for ExportManager.
from orbax.export.validate import ValidationManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
jax_module = JaxModule(example1_params, example1_model_fn)
batch_inputs = [inputs] * 16
serving_configs = [
    ServingConfig(
        'serving_default',
        input_signature=[tf.TensorSpec(shape=[16], dtype=tf.float32)],
        tf_preprocessor=example1_preprocess,
        tf_postprocessor=example1_postprocess,
    ),
]
# Provide computation method for the baseline.
validation_mgr = ValidationManager(jax_module, serving_configs, batch_inputs)
tf_saved_model_path = "/tmp/example1_output_dir"
loaded_model = tf.saved_model.load(tf_saved_model_path)
# Provide the computation method for the candidate.
validation_reports = validation_mgr.validate(loaded_model)
# `validation_reports` is a Python dict keyed by the TF SavedModel serving key.
for key in validation_reports:
  assert validation_reports[key].status.name == 'Pass'
  # Users can also save the converted JSON to a file.
  print(validation_reports[key].to_json(indent=2))
Limitations of the Orbax.export validate module:
ValidationReport module
does an apples-to-apples comparison between the JAX model result and the TF model result, so we
suggest that users structure the model output as a dictionary.
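To see why dictionary outputs make such a comparison straightforward, consider the minimal sketch below. It is not the ValidationReport implementation; `compare_outputs` is a hypothetical helper written in plain Python, and it assumes each output is a dictionary mapping a name to a flat list of floats. Matching by key means each baseline value is compared against the candidate value with the same name, and a missing or renamed key is reported instead of being silently misaligned.

```python
def compare_outputs(baseline, candidate, rtol=1e-5, atol=1e-5):
    """Compare two dict outputs key by key; return a list of mismatch messages."""
    mismatches = []
    for key in sorted(set(baseline) | set(candidate)):
        if key not in baseline or key not in candidate:
            mismatches.append(f"{key}: present in only one output")
            continue
        expected, actual = baseline[key], candidate[key]
        if len(expected) != len(actual):
            mismatches.append(f"{key}: length {len(expected)} vs {len(actual)}")
            continue
        for i, (e, a) in enumerate(zip(expected, actual)):
            # Same form of tolerance check as numpy.testing.assert_allclose:
            # |actual - expected| <= atol + rtol * |expected|
            if abs(a - e) > atol + rtol * abs(e):
                mismatches.append(f"{key}[{i}]: {e} vs {a}")
    return mismatches

baseline = {"outputs": [1.0, 2.0, 3.0]}
close = {"outputs": [1.0, 2.000001, 3.0]}   # within tolerance
far = {"outputs": [1.0, 2.5, 3.0]}          # one element differs
renamed = {"logits": [1.0, 2.0, 3.0]}       # same values, different key

print(compare_outputs(baseline, close))
print(compare_outputs(baseline, far))
print(compare_outputs(baseline, renamed))
```

An empty list means the two outputs agree within tolerance. The default tolerances mirror the `atol=1e-5, rtol=1e-5` used with `np.testing.assert_allclose` earlier in this document.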