%tensorflow_version 2.x
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb, tfd = tfp.bijectors, tfp.distributions

# Split the single physical GPU into 4 logical GPUs so MirroredStrategy has
# multiple replicas to mirror across.
physical_gpus = tf.config.experimental.list_physical_devices('GPU')
print(physical_gpus)
tf.config.experimental.set_virtual_device_configuration(
    physical_gpus[0],
    [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)] * 4)
gpus = tf.config.list_logical_devices('GPU')
print(gpus)

st = tf.distribute.MirroredStrategy(devices=tf.config.list_logical_devices('GPU'))
print(st.extended.worker_devices)

# Draw samples from an MVN, then sort them. This way we can easily visually
# verify that the correct partitions end up on the correct GPUs.
ndim = 3


def model():
  Root = tfd.JointDistributionCoroutine.Root
  loc = yield Root(tfb.Shift(.5)(tfd.MultivariateNormalDiag(loc=tf.zeros([ndim]))))
  scale_tril = yield Root(tfb.FillScaleTriL()(
      tfd.MultivariateNormalDiag(loc=tf.zeros([ndim * (ndim + 1) // 2]))))
  yield tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril)


dist = tfd.JointDistributionCoroutine(model)

tf.random.set_seed(1)
loc, scale_tril, _ = dist.sample(seed=2)
samples = dist.sample(value=([loc] * 1024, scale_tril, None), seed=3)[2]
samples = tf.round(samples * 1000) / 1000
for dim in reversed(range(ndim)):
  samples = tf.gather(samples, tf.argsort(samples[:, dim]))
print(samples)
print(loc)
print(scale_tril)
print(tf.reduce_mean(samples, 0))

%%time
# Single-pass version: each replica evaluates its whole shard of the data at
# once, and the per-shard log prob and gradients are combined with all_reduce.
def dataset_fn(ctx):
  batch_size = ctx.get_per_replica_batch_size(len(samples))
  d = tf.data.Dataset.from_tensor_slices(samples).batch(batch_size)
  return d.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)


ds = st.experimental_distribute_datasets_from_function(dataset_fn)
observations = next(iter(ds))
# print(observations)


@tf.function(autograph=False)
def log_prob_and_grad(loc, scale_tril, observations):
  ctx = tf.distribute.get_replica_context()
  with tf.GradientTape() as tape:
    tape.watch((loc, scale_tril))
    lp = tf.reduce_sum(dist.log_prob(loc, scale_tril, observations)) / len(samples)
  grad = tape.gradient(lp, (loc, scale_tril))
  return ctx.all_reduce('sum', lp), [ctx.all_reduce('sum', g) for g in grad]


# Wrap the distributed computation in tf.custom_gradient so that HMC sees an
# ordinary callable with a value and a gradient.
@tf.function(autograph=False)
@tf.custom_gradient
def target_log_prob(loc, scale_tril):
  lp, grads = st.run(log_prob_and_grad, (loc, scale_tril, observations))
  return lp.values[0], lambda grad_lp: [grad_lp * g.values[0] for g in grads]


singleton_vals = tfp.math.value_and_gradient(target_log_prob, (loc, scale_tril))

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob, step_size=.35, num_leapfrog_steps=2)
kernel = tfp.mcmc.TransformedTransitionKernel(
    kernel, bijector=[tfb.Identity(), tfb.FillScaleTriL()])


@tf.function(autograph=False)
def sample_chain():
  return tfp.mcmc.sample_chain(
      num_results=200,
      num_burnin_steps=100,
      current_state=[tf.ones_like(loc), tf.linalg.eye(scale_tril.shape[-1])],
      kernel=kernel,
      trace_fn=lambda _, kr: kr.inner_results.is_accepted)


samps, is_accepted = sample_chain()
print(f'accept rate: {np.mean(is_accepted)}')
print(f'ess: {tfp.mcmc.effective_sample_size(samps)}')
print(tf.reduce_mean(samps[0], axis=0))
# print(tf.reduce_mean(samps[1], axis=0))

import matplotlib.pyplot as plt
for dim in range(ndim):
  plt.figure(figsize=(10, 1))
  plt.hist(samps[0][:, dim], bins=50)
  plt.title(f'loc[{dim}]: prior mean = 0.5, observation = {loc[dim]}')
  plt.show()
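# Optional inspection sketch (added for illustration, assuming that
# `observations.values` is a tuple of per-replica tensors, one per logical
# GPU): since `samples` was sorted above, each replica's batch should be a
# contiguous, ordered slice of the data, which makes it easy to check which
# partition landed on which GPU.
for i, obs in enumerate(observations.values):
  print(f'replica {i}: {obs.shape[0]} rows, '
        f'first = {obs[0].numpy()}, last = {obs[-1].numpy()}')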
%%time
# Multi-batch version: each replica's shard is further split into
# `batches_per_eval` batches, and the per-batch log prob and gradient
# contributions are accumulated with a dataset reduce before being summed
# across replicas, so a full shard never has to be evaluated in one pass.
batches_per_eval = 2


def dataset_fn(ctx):
  batch_size = ctx.get_per_replica_batch_size(len(samples))
  d = tf.data.Dataset.from_tensor_slices(samples).batch(batch_size // batches_per_eval)
  return d.shard(ctx.num_input_pipelines, ctx.input_pipeline_id).prefetch(2)


ds = st.experimental_distribute_datasets_from_function(dataset_fn)


@tf.function(autograph=False)
def log_prob_and_grad(loc, scale_tril, observations, prev_sum_lp, prev_sum_grads):
  with tf.GradientTape() as tape:
    tape.watch((loc, scale_tril))
    lp = tf.reduce_sum(dist.log_prob(loc, scale_tril, observations)) / len(samples)
  grad = tape.gradient(lp, (loc, scale_tril))
  return lp + prev_sum_lp, [g + pg for (g, pg) in zip(grad, prev_sum_grads)]


@tf.function(autograph=False)
@tf.custom_gradient
def target_log_prob(loc, scale_tril):
  sum_lp = tf.zeros([])
  sum_grads = [tf.zeros_like(x) for x in (loc, scale_tril)]
  sum_lp, sum_grads = st.run(
      lambda *x: tf.nest.map_structure(tf.identity, x), (sum_lp, sum_grads))

  def reduce_fn(state, observations):
    sum_lp, sum_grads = state
    return st.run(
        log_prob_and_grad, (loc, scale_tril, observations, sum_lp, sum_grads))

  sum_lp, sum_grads = ds.reduce((sum_lp, sum_grads), reduce_fn)
  sum_lp = st.reduce('sum', sum_lp, None)
  sum_grads = [st.reduce('sum', sg, None) for sg in sum_grads]
  return sum_lp, lambda grad_lp: [grad_lp * sg for sg in sum_grads]


multibatch_vals = tfp.math.value_and_gradient(target_log_prob, (loc, scale_tril))

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob, step_size=.35, num_leapfrog_steps=2)
kernel = tfp.mcmc.TransformedTransitionKernel(
    kernel, bijector=[tfb.Identity(), tfb.FillScaleTriL()])


@tf.function(autograph=False)
def sample_chain():
  return tfp.mcmc.sample_chain(
      num_results=200,
      num_burnin_steps=100,
      current_state=[tf.ones_like(loc), tf.linalg.eye(scale_tril.shape[-1])],
      kernel=kernel,
      trace_fn=lambda _, kr: kr.inner_results.is_accepted)


samps, is_accepted = sample_chain()
print(f'accept rate: {np.mean(is_accepted)}')
print(f'ess: {tfp.mcmc.effective_sample_size(samps)}')
print(tf.reduce_mean(samps[0], axis=0))
# print(tf.reduce_mean(samps[1], axis=0))

import matplotlib.pyplot as plt
for dim in range(ndim):
  plt.figure(figsize=(10, 1))
  plt.hist(samps[0][:, dim], bins=50)
  plt.title(f'loc[{dim}]: prior mean = 0.5, observation = {loc[dim]}')
  plt.show()

# The single-pass and multi-batch evaluations should agree numerically.
for i, (sv, mv) in enumerate(zip(tf.nest.flatten(singleton_vals),
                                 tf.nest.flatten(multibatch_vals))):
  np.testing.assert_allclose(sv, mv, err_msg=i, rtol=1e-5)
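# A further consistency check (a sketch; `reference_log_prob` is a hypothetical
# helper introduced here for illustration): the distributed results should also
# match a plain single-device evaluation of the mean joint log prob over all of
# `samples`.
def reference_log_prob(loc, scale_tril):
  return tf.reduce_sum(dist.log_prob(loc, scale_tril, samples)) / len(samples)


reference_vals = tfp.math.value_and_gradient(reference_log_prob, (loc, scale_tril))
for i, (sv, rv) in enumerate(zip(tf.nest.flatten(singleton_vals),
                                 tf.nest.flatten(reference_vals))):
  np.testing.assert_allclose(sv, rv, err_msg=i, rtol=1e-5)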