#@title ##### Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

!pip install -Uq tf-nightly tfp-nightly arviz

# @title Imports
from collections import namedtuple
import warnings
import time

import arviz as az
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.mcmc.internal import util as mcmc_util
from tensorflow_probability.python.internal import unnest

tfd = tfp.distributions
tfb = tfp.bijectors
dtype = tf.float64

az.style.use('arviz-darkgrid')

# @title Helper functions
def get_flat_unconstraining_bijector(jd_model, initial_position, batch_rank):
  unconstraining_bij = jd_model.experimental_default_event_space_bijector()
  # First take the constrained space, and unconstrain it:
  bijectors = [unconstraining_bij]

  # In case of a JDNamed, pack the state_parts into a dictionary
  if isinstance(initial_position, dict):
    bijectors.append(tfb.Invert(tfb.Restructure(list(initial_position.keys()))))

  # Easiest to do a transform here instead of special-casing dictionaries again.
  list_position = tfb.Chain(bijectors[-1::-1])(initial_position)

  # This bijector takes flat tensors and reshapes them to the proper event shape.
  bijectors.append(tfb.JointMap(
      [tfb.Reshape(ps.shape(x)[batch_rank:]) for x in list_position]))

  # This splits a single tensor into state_parts.
  bijectors.append(tfb.Split(
      num_or_size_splits=np.asarray(
          [ps.reduce_prod(ps.shape(x)[batch_rank:]) for x in list_position]
      ).flatten(),
      axis=-1))

  # Notice that we want to apply the bijectors in the opposite order.
  return tfb.Invert(tfb.Chain(bijectors))


MCMCState = namedtuple("MCMCState", ['pinned_model',
                                     'initial_position',
                                     'initial_transformed_position',
                                     'bijector',
                                     'batch_rank',
                                     'batch_shape',
                                     'target_log_prob_fn'])


def setup_mcmc(model, n_chains, **pins):
  pinned_model = model.experimental_pin(**pins)

  # TODO: add argument to `sample_unpinned` to get chains
  initial_position = pinned_model.sample_unpinned(n_chains)

  target_log_prob_val = pinned_model.unnormalized_log_prob(initial_position)
  batch_rank = ps.rank(target_log_prob_val)
  batch_shape = ps.shape(target_log_prob_val)

  bijector = get_flat_unconstraining_bijector(
      pinned_model, initial_position, batch_rank)
  initial_transformed_position = bijector.forward(initial_position)

  # Jitter init
  initial_transformed_position = tfd.Uniform(-2., 2.).sample(
      initial_transformed_position.shape)
  initial_position = bijector.inverse(initial_transformed_position)

  target_log_prob_fn = lambda x: (
      pinned_model.unnormalized_log_prob(bijector.inverse(x)) +
      bijector.inverse_log_det_jacobian(x, event_ndims=1))
  # target_log_prob_fn = pinned_model.unnormalized_log_prob

  return MCMCState(pinned_model=pinned_model,
                   initial_position=initial_position,
                   initial_transformed_position=initial_transformed_position,
                   bijector=bijector,
                   batch_rank=batch_rank,
                   batch_shape=batch_shape,
                   target_log_prob_fn=target_log_prob_fn)
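
# A quick shape intuition (sketch; `eight_schools` is only defined later in this
# notebook): its unpinned parts are `avg_effect` (scalar), `avg_stddev`
# (positive scalar), and `school_effects_std` (vector of 8), so the flat
# unconstraining bijector maps the state to a single unconstrained tensor of
# width 1 + 1 + 8 = 10 per chain, e.g.
#
#   state = setup_mcmc(eight_schools, n_chains=4,
#                      treatment_effects=treatment_effects)
#   state.initial_transformed_position.shape  # => [4, 10]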

def make_base_kernel(mcmc_state, *, step_size, num_leapfrog_steps,
                     momentum_distribution):
  return tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
      mcmc_state.target_log_prob_fn,
      step_size=step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      momentum_distribution=momentum_distribution)


def make_fast_adapt_kernel(mcmc_state, *,
                           initial_step_size,
                           num_leapfrog_steps,
                           num_adaptation_steps,
                           target_accept_prob=0.75,
                           momentum_distribution=None):
  return tfp.mcmc.SimpleStepSizeAdaptation(
      make_base_kernel(mcmc_state,
                       step_size=initial_step_size,
                       num_leapfrog_steps=num_leapfrog_steps,
                       momentum_distribution=momentum_distribution),
      num_adaptation_steps=num_adaptation_steps,
      target_accept_prob=target_accept_prob)


def make_slow_adapt_kernel(mcmc_state, *,
                           initial_running_variance,
                           initial_step_size,
                           num_leapfrog_steps,
                           num_adaptation_steps,
                           target_accept_prob=0.75):
  return tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
      make_fast_adapt_kernel(mcmc_state,
                             initial_step_size=initial_step_size,
                             num_leapfrog_steps=num_leapfrog_steps,
                             num_adaptation_steps=num_adaptation_steps,
                             target_accept_prob=target_accept_prob),
      initial_running_variance=initial_running_variance)

# @title Kernels
@tf.function(jit_compile=True)
def fast_window(mcmc_state, num_leapfrog_steps, *, num_draws, initial_position,
                initial_step_size, momentum_distribution=None):
  kernel = make_fast_adapt_kernel(
      mcmc_state,
      initial_step_size=initial_step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      num_adaptation_steps=num_draws,
      momentum_distribution=momentum_distribution)

  with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    draws, _, fkr = tfp.mcmc.sample_chain(
        num_draws, initial_position,
        kernel=kernel,
        return_final_kernel_results=True,
        trace_fn=None)

  weighted_running_variance = tfp.experimental.stats.RunningVariance.from_stats(
      num_samples=num_draws / 2,
      mean=tf.reduce_mean(draws[-num_draws // 2:],
                          axis=ps.range(mcmc_state.batch_rank + 1)),
      variance=tf.math.reduce_variance(draws[-num_draws // 2:],
                                       axis=ps.range(mcmc_state.batch_rank + 1)))

  step_size = unnest.get_innermost(fkr, 'step_size')
  return draws[-1], step_size, weighted_running_variance


@tf.function(jit_compile=True)
def slow_window(mcmc_state, num_leapfrog_steps, *, num_draws, initial_position,
                initial_running_variance, initial_step_size):
  kernel = make_slow_adapt_kernel(
      mcmc_state,
      initial_running_variance=initial_running_variance,
      initial_step_size=initial_step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      num_adaptation_steps=num_draws)

  with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    draws, _, fkr = tfp.mcmc.sample_chain(
        num_draws, initial_position,
        kernel=kernel,
        return_final_kernel_results=True,
        trace_fn=None)

  weighted_running_variance = tfp.experimental.stats.RunningVariance.from_stats(
      num_samples=num_draws / 2,
      mean=tf.reduce_mean(draws[-num_draws // 2:],
                          axis=ps.range(mcmc_state.batch_rank + 1)),
      variance=tf.math.reduce_variance(draws[-num_draws // 2:],
                                       axis=ps.range(mcmc_state.batch_rank + 1)))

  step_size = unnest.get_innermost(fkr, 'step_size')
  momentum_distribution = unnest.get_outermost(fkr, 'momentum_distribution')
  return draws[-1], step_size, weighted_running_variance, momentum_distribution


@tf.function(jit_compile=True)
def do_sampling(mcmc_state, num_leapfrog_steps, *, num_draws, initial_position,
                step_size, momentum_distribution, trace_fn=None):
  kernel = make_base_kernel(
      mcmc_state,
      step_size=step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      momentum_distribution=momentum_distribution)

  return tfp.mcmc.sample_chain(
      num_draws, initial_position,
      kernel=kernel,
      trace_fn=trace_fn)

# @title Implementation of `sample_model`
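# The driver below mimics Stan-style windowed warmup: an initial fast window
# that adapts only the step size, four doubling slow windows (25, 50, 100, and
# 200 draws) that additionally estimate a diagonal mass matrix from the running
# variance of the draws, one more fast window to re-tune the step size against
# the new mass matrix, and finally the main sampling run with both the step
# size and the momentum distribution held fixed.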
def sample_model(n_draws, joint_dist, *, n_chains=16, num_leapfrog_steps=8,
                 trace_fn=None, **pins):
  mcmc_state = setup_mcmc(joint_dist, n_chains=n_chains, **pins)

  fast_window_size = 75
  slow_window_size = 25

  print(f"Fast window {fast_window_size}")
  position, step_size, running_variance = fast_window(
      mcmc_state,
      num_leapfrog_steps,
      num_draws=fast_window_size,
      initial_position=mcmc_state.initial_transformed_position,
      initial_step_size=tf.ones(mcmc_state.batch_shape)[..., tf.newaxis])

  for idx in range(4):
    window_size = slow_window_size * (2 ** idx)
    print(f"Slow window {window_size}")
    position, step_size, running_variance, momentum_distribution = slow_window(
        mcmc_state,
        num_leapfrog_steps,
        num_draws=window_size,
        initial_position=position,
        initial_running_variance=running_variance,
        initial_step_size=step_size)

  print(f"Fast window {fast_window_size}")
  position, step_size, running_variance = fast_window(
      mcmc_state,
      num_leapfrog_steps,
      num_draws=fast_window_size,
      initial_position=position,
      initial_step_size=step_size,
      momentum_distribution=momentum_distribution)

  print(f"main draws {n_draws}")
  draws = do_sampling(
      mcmc_state,
      num_leapfrog_steps,
      num_draws=n_draws,
      initial_position=position,
      step_size=step_size,
      momentum_distribution=momentum_distribution,
      trace_fn=trace_fn)

  return mcmc_state.bijector.inverse(draws), mcmc_state

# @title Write down model
num_schools = 8  # number of schools
treatment_effects = tf.constant(
    [28., 8, -3, 7, -1, 1, 18, 12])    # treatment effects
treatment_stddevs = tf.constant(
    [15., 10, 16, 11, 9, 11, 10, 18])  # treatment SE

eight_schools = tfd.JointDistributionSequential([
    tfd.Normal(0., 5., name='avg_effect'),
    tfd.HalfNormal(5., name='avg_stddev'),
    tfd.Independent(
        tfd.Normal(loc=tf.zeros(num_schools), scale=tf.ones(num_schools)),
        reinterpreted_batch_ndims=1,
        name='school_effects_std'),
    lambda school_effects_std, avg_stddev, avg_effect: tfd.Independent(
        tfd.Normal(loc=(avg_effect[..., tf.newaxis] +
                        avg_stddev[..., tf.newaxis] * school_effects_std),
                   scale=treatment_stddevs),
        reinterpreted_batch_ndims=1,
        name='treatment_effects')
])

%time draws, mcmc_state = sample_model(500, eight_schools, n_chains=64, treatment_effects=treatment_effects)

draws_dict = {k[0]: tf.einsum('ij...->ji...', v)
              for v, k in zip(draws, eight_schools.resolve_graph())}
idata = az.from_dict(posterior=draws_dict)
idata

# Only plotting the first 4 chains
az.plot_trace(idata.sel(chain=[0, 1, 2, 3]), combined=False, compact=True);

# @title TFP Results
az.summary(idata)

# @title PyMC3 NUTS results
data = az.load_arviz_data("non_centered_eight")
az.summary(data, var_names=['mu', 'tau', 'theta_t'])

az.plot_trace(data, var_names=['mu', 'tau', 'theta_t'], compact=True);

np.random.seed(0)

ndims = 5
ndata = 100
X = np.random.randn(ndata, ndims).astype(np.float32)
w_ = np.random.randn(ndims).astype(np.float32)  # hidden
noise_ = 0.1 * np.random.randn(ndata).astype(np.float32)  # hidden
y_obs = X.dot(w_) + noise_

# @title Write down model
linear_model = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(tf.zeros(ndims), tf.ones(ndims), name='w'),
    lambda w: tfd.Normal(tf.linalg.matvec(X, w), 0.1, name='y')
])

%time draws, mcmc_state = sample_model(500, linear_model, n_chains=256, y=y_obs)

draws_dict = {k[0]: tf.einsum('ij...->ji...', v)
              for v, k in zip(draws, linear_model.resolve_graph())}
idata = az.from_dict(posterior=draws_dict)
idata

# Only plotting the first 4 chains
axes = az.plot_trace(idata.sel(chain=[0, 1, 2, 3]))
for idx, (true_val, ax) in enumerate(zip(w_, axes)):
  ax[0].axvline(true_val, color=f'C{idx}', linestyle='dashed');

az.summary(idata)
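
# A minimal sanity check (sketch; assumes `draws_dict` and the hidden `w_` from
# the linear-model cells above): average the posterior draws of `w` over the
# chain and draw dimensions and compare them with the weights used to simulate
# the data.
w_post_mean = tf.reduce_mean(draws_dict['w'], axis=[0, 1])
print(np.stack([w_post_mean.numpy(), w_]))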