This notebook demonstrates splitting a log_prob and gradient computation across multiple GPU devices. For development purposes it was prototyped in Colab with a single physical GPU partitioned into several logical GPUs.
Note: since everything runs on one physical GPU, the timings below are not representative of what real multi-GPU hardware can achieve, and the tf.data input pipeline would likely benefit from further tuning when deployed across multiple GPUs.
Needs a GPU: Edit > Notebook Settings: Hardware Accelerator => GPU
%tensorflow_version 2.x
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfb, tfd = tfp.bijectors, tfp.distributions
physical_gpus = tf.config.experimental.list_physical_devices('GPU')
print(physical_gpus)
# Partition the single physical GPU into 4 logical GPUs (~2GB each).
tf.config.experimental.set_virtual_device_configuration(
    physical_gpus[0],
    [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)] * 4)
gpus = tf.config.list_logical_devices('GPU')
print(gpus)
# Mirror computation across all of the logical GPUs.
st = tf.distribute.MirroredStrategy(devices=tf.config.list_logical_devices('GPU'))
print(st.extended.worker_devices)
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
[LogicalDevice(name='/device:GPU:0', device_type='GPU'), LogicalDevice(name='/device:GPU:1', device_type='GPU'), LogicalDevice(name='/device:GPU:2', device_type='GPU'), LogicalDevice(name='/device:GPU:3', device_type='GPU')]
WARNING:tensorflow:NCCL is not supported when using virtual GPUs, fallingback to reduction to one device
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
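In outline, each replica computes the log-prob and gradient of its own shard of the data inside strategy.run, and the per-replica contributions are summed with an all-reduce. Here is a minimal sketch of that pattern, using the `st` strategy created above and a toy Normal model; it is purely illustrative, and the names `sketch_*` and `theta` are not used elsewhere in the notebook.
# Minimal sketch of the pattern the rest of the notebook follows (toy Normal model).
sketch_data = tf.random.normal([1024])
sketch_ds = st.experimental_distribute_dataset(
    tf.data.Dataset.from_tensor_slices(sketch_data).batch(1024))

@tf.function(autograph=False)
def sketch_log_prob_and_grad(theta, shard):
  # Each replica computes log-prob and gradient on its own shard; all_reduce
  # sums the contributions across replicas.
  ctx = tf.distribute.get_replica_context()
  with tf.GradientTape() as tape:
    tape.watch(theta)
    lp = tf.reduce_sum(tfd.Normal(theta, 1.).log_prob(shard))
  grad = tape.gradient(lp, theta)
  return ctx.all_reduce('sum', lp), ctx.all_reduce('sum', grad)

theta = tf.constant(0.)
sketch_lp, sketch_grad = st.run(
    sketch_log_prob_and_grad, (theta, next(iter(sketch_ds))))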
# Draw samples from an MVN, then sort them. This way we can easily visually
# verify the correct partition ends up on the correct GPUs.
ndim = 3
def model():
  Root = tfd.JointDistributionCoroutine.Root
  loc = yield Root(tfb.Shift(.5)(tfd.MultivariateNormalDiag(loc=tf.zeros([ndim]))))
  scale_tril = yield Root(tfb.FillScaleTriL()(tfd.MultivariateNormalDiag(loc=tf.zeros([ndim * (ndim + 1) // 2]))))
  yield tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril)
dist = tfd.JointDistributionCoroutine(model)
tf.random.set_seed(1)
# Sample 'ground truth' loc and scale_tril, then 1024 observations conditioned on them.
loc, scale_tril, _ = dist.sample(seed=2)
samples = dist.sample(value=([loc] * 1024, scale_tril, None), seed=3)[2]
# Round for readability, then lexicographically sort so each replica's shard is a
# recognizable contiguous block.
samples = tf.round(samples * 1000) / 1000
for dim in reversed(range(ndim)):
  samples = tf.gather(samples, tf.argsort(samples[:, dim]))
print(samples)
tf.Tensor( [[-4.534 -5.856 3.606] [-4.527 -5.875 1.671] [-4.269 -4.346 5.697] ... [ 2.158 5.71 -2.926] [ 2.302 6.658 -3.491] [ 2.632 5.67 -4.854]], shape=(1024, 3), dtype=float32)
print(loc)
print(scale_tril)
print(tf.reduce_mean(samples, 0))
tf.Tensor([-1.0574996 0.24829748 1.0737331 ], shape=(3,), dtype=float32) tf.Tensor( [[ 1.1475685 0. 0. ] [ 1.9094281 0.5724521 0. ] [-1.1899896 0.49813363 1.5088601 ]], shape=(3, 3), dtype=float32) tf.Tensor([-0.9953702 0.3626416 1.0675195], shape=(3,), dtype=float32)
%%time
def dataset_fn(ctx):
  # The global batch is the full sorted dataset; each replica gets
  # len(samples) / num_replicas rows.
  batch_size = ctx.get_per_replica_batch_size(len(samples))
  d = tf.data.Dataset.from_tensor_slices(samples).batch(batch_size)
  return d.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
ds = st.experimental_distribute_datasets_from_function(dataset_fn)
observations = next(iter(ds))
# print(observations)
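# Optional sanity check (a sketch): confirm each replica received a contiguous
# block of the sorted samples by inspecting the per-replica components.
for i, obs in enumerate(st.experimental_local_results(observations)):
  print(f'replica {i}: {obs.shape[0]} rows, first = {obs[0].numpy()}, last = {obs[-1].numpy()}')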
@tf.function(autograph=False)
def log_prob_and_grad(loc, scale_tril, observations):
  ctx = tf.distribute.get_replica_context()
  with tf.GradientTape() as tape:
    tape.watch((loc, scale_tril))
    lp = tf.reduce_sum(dist.log_prob(loc, scale_tril, observations)) / len(samples)
  grad = tape.gradient(lp, (loc, scale_tril))
  # Sum the per-replica contributions across all devices.
  return ctx.all_reduce('sum', lp), [ctx.all_reduce('sum', g) for g in grad]
@tf.function(autograph=False)
@tf.custom_gradient
def target_log_prob(loc, scale_tril):
  # Run the distributed computation and unwrap the (identical) per-replica
  # results, exposing an ordinary log-prob function with a custom gradient.
  lp, grads = st.run(log_prob_and_grad, (loc, scale_tril, observations))
  return lp.values[0], lambda grad_lp: [grad_lp * g.values[0] for g in grads]
# Evaluate once, for later comparison with the multi-batch implementation.
singleton_vals = tfp.math.value_and_gradient(target_log_prob, (loc, scale_tril))
kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob, step_size=.35, num_leapfrog_steps=2)
kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bijector=[tfb.Identity(), tfb.FillScaleTriL()])
@tf.function(autograph=False)
def sample_chain():
  return tfp.mcmc.sample_chain(
      num_results=200, num_burnin_steps=100,
      current_state=[tf.ones_like(loc), tf.linalg.eye(scale_tril.shape[-1])],
      kernel=kernel, trace_fn=lambda _, kr: kr.inner_results.is_accepted)
samps, is_accepted = sample_chain()
print(f'accept rate: {np.mean(is_accepted)}')
print(f'ess: {tfp.mcmc.effective_sample_size(samps)}')
print(tf.reduce_mean(samps[0], axis=0))
# print(tf.reduce_mean(samps[1], axis=0))
import matplotlib.pyplot as plt
for dim in range(ndim):
  plt.figure(figsize=(10, 1))
  plt.hist(samps[0][:, dim], bins=50)
  plt.title(f'loc[{dim}]: prior mean = 0.5, true value = {loc[dim]}')
  plt.show()
accept rate: 0.8
ess: [<tf.Tensor: shape=(3,), dtype=float32, numpy=array([42.702564, 56.667793, 43.04328 ], dtype=float32)>, <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[57.469185, nan, nan],
[34.60636 , 43.471287, nan],
[13.848625, 30.807411, 72.34708 ]], dtype=float32)>]
tf.Tensor([-0.18598816 0.7396643 0.4074543 ], shape=(3,), dtype=float32)
CPU times: user 14.3 s, sys: 1.57 s, total: 15.8 s Wall time: 12.9 s
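The next cell repeats the computation, but instead of evaluating each replica's full shard in one pass, it streams the shard through ds.reduce in batches_per_eval smaller batches, accumulating the log-prob and gradients as it goes; the per-replica accumulators are then summed across devices with st.reduce. This trades extra per-step overhead for a lower peak memory footprint per device, which can matter when a replica's shard does not fit in GPU memory at once.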
%%time
# Split each replica's shard into `batches_per_eval` smaller batches.
batches_per_eval = 2
def dataset_fn(ctx):
  batch_size = ctx.get_per_replica_batch_size(len(samples))
  d = tf.data.Dataset.from_tensor_slices(samples).batch(batch_size // batches_per_eval)
  return d.shard(ctx.num_input_pipelines, ctx.input_pipeline_id).prefetch(2)
ds = st.experimental_distribute_datasets_from_function(dataset_fn)
@tf.function(autograph=False)
def log_prob_and_grad(loc, scale_tril, observations, prev_sum_lp, prev_sum_grads):
  with tf.GradientTape() as tape:
    tape.watch((loc, scale_tril))
    lp = tf.reduce_sum(dist.log_prob(loc, scale_tril, observations)) / len(samples)
  grad = tape.gradient(lp, (loc, scale_tril))
  # Accumulate onto the running sums carried through ds.reduce.
  return lp + prev_sum_lp, [g + pg for (g, pg) in zip(grad, prev_sum_grads)]
@tf.function(autograph=False)
@tf.custom_gradient
def target_log_prob(loc, scale_tril):
  sum_lp = tf.zeros([])
  sum_grads = [tf.zeros_like(x) for x in (loc, scale_tril)]
  # Turn the zero accumulators into per-replica values.
  sum_lp, sum_grads = st.run(
      lambda *x: tf.nest.map_structure(tf.identity, x), (sum_lp, sum_grads))
  def reduce_fn(state, observations):
    sum_lp, sum_grads = state
    return st.run(
        log_prob_and_grad, (loc, scale_tril, observations, sum_lp, sum_grads))
  # Stream over the batches, accumulating per-replica log-prob and gradients,
  # then sum the per-replica accumulators across devices.
  sum_lp, sum_grads = ds.reduce((sum_lp, sum_grads), reduce_fn)
  sum_lp = st.reduce('sum', sum_lp, None)
  sum_grads = [st.reduce('sum', sg, None) for sg in sum_grads]
  return sum_lp, lambda grad_lp: [grad_lp * sg for sg in sum_grads]
multibatch_vals = tfp.math.value_and_gradient(target_log_prob, (loc, scale_tril))
kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob, step_size=.35, num_leapfrog_steps=2)
kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bijector=[tfb.Identity(), tfb.FillScaleTriL()])
@tf.function(autograph=False)
def sample_chain():
  return tfp.mcmc.sample_chain(
      num_results=200, num_burnin_steps=100,
      current_state=[tf.ones_like(loc), tf.linalg.eye(scale_tril.shape[-1])],
      kernel=kernel, trace_fn=lambda _, kr: kr.inner_results.is_accepted)
samps, is_accepted = sample_chain()
print(f'accept rate: {np.mean(is_accepted)}')
print(f'ess: {tfp.mcmc.effective_sample_size(samps)}')
print(tf.reduce_mean(samps[0], axis=0))
# print(tf.reduce_mean(samps[1], axis=0))
import matplotlib.pyplot as plt
for dim in range(ndim):
  plt.figure(figsize=(10, 1))
  plt.hist(samps[0][:, dim], bins=50)
  plt.title(f'loc[{dim}]: prior mean = 0.5, true value = {loc[dim]}')
  plt.show()
accept rate: 0.8
ess: [<tf.Tensor: shape=(3,), dtype=float32, numpy=array([42.702564, 56.667793, 43.043278], dtype=float32)>, <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[57.469193, nan, nan],
[34.606365, 43.471287, nan],
[13.848625, 30.807411, 72.34709 ]], dtype=float32)>]
tf.Tensor([-0.18598816 0.7396643 0.4074543 ], shape=(3,), dtype=float32)
CPU times: user 1min 8s, sys: 8.71 s, total: 1min 17s Wall time: 51.5 s
Sanity-check that the single-batch and multi-batch implementations produce the same log-prob and gradients.
for i, (sv, mv) in enumerate(zip(tf.nest.flatten(singleton_vals),
                                 tf.nest.flatten(multibatch_vals))):
  np.testing.assert_allclose(sv, mv, err_msg=str(i), rtol=1e-5)