from __future__ import absolute_import from __future__ import division from __future__ import print_function import matplotlib.pyplot as plt import numpy as np import seaborn as sns import tensorflow as tf import tensorflow_probability as tfp from tensorflow_probability import edward2 as ed tfd = tfp.distributions # tf.enable_eager_execution() #Model Class Source import copy import collections class MetaModel(type): def __call__(cls, *args, **kwargs): obj = type.__call__(cls, *args, **kwargs) obj._load_observed() obj._load_unobserved() return obj class BaseModel(object): __metaclass__ = MetaModel def _load_unobserved(self): unobserved_fun = self._unobserved_vars() self.unobserved = unobserved_fun() def _load_observed(self): self.observed = copy.copy(vars(self)) def _unobserved_vars(self): def unobserved_fn(*args, **kwargs): unobserved_vars = collections.OrderedDict() def interceptor(f, *args, **kwargs): name = kwargs.get("name") rv = f(*args, **kwargs) if name not in self.observed: unobserved_vars[name] = rv.shape return rv with ed.interception(interceptor): self.__call__() return unobserved_vars return unobserved_fn # def observe(self, states): # for name, value in states.iteritems(): # setattr(self, name, value) def target_log_prob_fn(self, *args, **kwargs): """Unnormalized target density as a function of unobserved states.""" def log_joint_fn(*args, **kwargs): states = dict(zip(self.unobserved, args)) states.update(self.observed) log_probs = [] def interceptor(f, *args, **kwargs): name = kwargs.get("name") for name, value in states.iteritems(): if kwargs.get("name") == name: kwargs["value"] = value rv = f(*args, **kwargs) log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value)) log_probs.append(log_prob) return rv with ed.interception(interceptor): self.__call__() log_prob = sum(log_probs) return log_prob return log_joint_fn def get_posterior_fn(self, states={}, *args, **kwargs): """Get the log joint prob given arbitrary values for vars""" def 
posterior_fn(*args, **kwargs): def interceptor(f, *args, **kwargs): name = kwargs.get("name") for name, value in states.iteritems(): if kwargs.get("name") == name: kwargs["value"] = value rv = f(*args, **kwargs) return rv with ed.interception(interceptor): return self.__call__() return posterior_fn def __call__(self): return self.call() def call(self, *args, **kwargs): raise NotImplementedError # This is a really quick / hacky sample function. # Ideally the user could choose the kernel or inference method # e.g., I could imagine a user defining a variational approximation in the model # and then using VI as a sample option here where the sample method looks for # model.q() # Also, it's relatively straightforward to see how one could return arbitrary # diagnostics given the model. # Todo: Add diagnostics, multiple chains, more automatic inference. def sample(model, num_results=5000, num_burnin_steps=3000, step_size=.4, num_leapfrog_steps=3, numpy=True): initial_state = [] for name, shape in model.unobserved.iteritems(): initial_state.append(.5 * tf.ones(shape, name="init_{}".format(name))) states, kernel_results = tfp.mcmc.sample_chain( num_results=num_results, num_burnin_steps=num_burnin_steps, current_state=initial_state, kernel=tfp.mcmc.HamiltonianMonteCarlo( target_log_prob_fn=model.target_log_prob_fn(), step_size=step_size, num_leapfrog_steps=num_leapfrog_steps)) if numpy: with tf.Session() as sess: states, is_accepted_ = sess.run([states, kernel_results.is_accepted]) accepted = np.sum(is_accepted_) print("Acceptance rate: {}".format(accepted / num_results)) return dict(zip(model.unobserved.keys(), states)) num_schools = 8 # number of schools treatment_effects = np.array( [28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32) # treatment effects treatment_stddevs = np.array( [15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32) # treatment SE fig, ax = plt.subplots() plt.bar(range(num_schools), treatment_effects, yerr=treatment_stddevs) plt.title("8 Schools treatment 
effects") plt.xlabel("School") plt.ylabel("Treatment effect") fig.set_size_inches(10, 8) plt.show() class Model(BaseModel): def __init__(self, num_schools, treatment_stddevs, treatment_effects): super(BaseModel, self).__init__() self.num_schools = num_schools self.treatment_stddevs = treatment_stddevs self.treatment_effects = treatment_effects def call(self): avg_effect = ed.Normal(loc=0., scale=10., name="avg_effect") # `mu` above avg_stddev = ed.Normal( loc=5., scale=1., name="avg_stddev") # `log(tau)` above school_effects_standard = ed.Normal( loc=tf.zeros(self.num_schools), scale=tf.ones(self.num_schools), name="school_effects_standard") # `theta_prime` above school_effects = avg_effect + tf.exp( avg_stddev) * school_effects_standard # `theta` above treatment_effects = ed.Normal( loc=school_effects, scale=self.treatment_stddevs, name="treatment_effects") # `y` above return treatment_effects model = Model(num_schools, treatment_stddevs, treatment_effects) with tf.Session() as sess: print(model().eval()) %%time trace = sample(model) model.observed model.unobserved # The sampling function can return the trace with names because the model object # knows the names of the unobserved variables it wants to sample. 
trace

school_effects_samples = (
    trace['avg_effect'][:, np.newaxis] +
    np.exp(trace['avg_stddev'])[:, np.newaxis] *
    trace['school_effects_standard'])

# Report posterior means (not the raw traces) for the scalar parameters.
print("E[avg_effect] = {}".format(trace['avg_effect'].mean()))
print("E[avg_stddev] = {}".format(trace['avg_stddev'].mean()))
print("E[school_effects_standard] =")
print(trace['school_effects_standard'].mean(0))
print("E[school_effects] =")
print(school_effects_samples.mean(0))

import warnings

warnings.filterwarnings('ignore')


def _plot_school_effect_chains(samples):
    """Plot the chain and kernel density of each school's effect samples.

    Args:
      samples: array indexed as `samples[:, i]` for school `i`.
    """
    fig, axes = plt.subplots(num_schools, 2, sharex='col', sharey='col')
    fig.set_size_inches(12, 10)
    for i in range(num_schools):
        axes[i][0].plot(samples[:, i])
        axes[i][0].title.set_text(
            "School {} treatment effect chain".format(i))
        sns.kdeplot(samples[:, i], ax=axes[i][1], shade=True)
        axes[i][1].title.set_text(
            "School {} treatment effect distribution".format(i))
    axes[num_schools - 1][0].set_xlabel("Iteration")
    axes[num_schools - 1][1].set_xlabel("School effect")
    fig.tight_layout()
    plt.show()


def _plot_school_kde_grid(samples, label):
    """Plot one kernel density per school on a 4x2 grid.

    Args:
      samples: array indexed as `samples[:, i]` for school `i`.
      label: text appended to each panel title after
        "School <i> treatment effect ".
    """
    fig, axes = plt.subplots(4, 2, sharex=True, sharey=True)
    fig.set_size_inches(12, 10)
    fig.tight_layout()
    # `axes.flat` walks the grid row by row, so panel i shows school i.
    for i, ax in enumerate(axes.flat):
        sns.kdeplot(samples[:, i], ax=ax, shade=True)
        ax.title.set_text("School {} treatment effect {}".format(i, label))
    plt.show()


_plot_school_effect_chains(school_effects_samples)

# Compute the 95% interval for school_effects
school_effects_low, school_effects_med, school_effects_hi = np.percentile(
    school_effects_samples, [2.5, 50, 97.5], axis=0)

school_ids = np.arange(num_schools)

fig, ax = plt.subplots(nrows=1, ncols=1, sharex=True)
ax.scatter(school_ids, school_effects_med, color='red', s=60)
ax.scatter(school_ids + 0.1, treatment_effects, color='blue', s=60)

avg_effect = trace['avg_effect'].mean()

# Horizontal reference line spanning all schools (7.4 when num_schools == 8).
plt.plot([-0.2, num_schools - 0.6], [avg_effect, avg_effect], 'k',
         linestyle='--')

ax.errorbar(
    school_ids,
    school_effects_med,
    yerr=[
        school_effects_med - school_effects_low,
        school_effects_hi - school_effects_med
    ],
    fmt='none')

ax.legend(('avg_effect', 'Edward2/HMC', 'Observed effect'), fontsize=14)

plt.xlabel('School')
plt.ylabel('Treatment effect')
plt.title('Edward2 HMC estimated school treatment effects vs. observed data')
fig.set_size_inches(10, 8)
plt.show()

print("Inferred posterior mean: {0:.2f}".format(
    np.mean(school_effects_samples)))
print("Inferred posterior mean se: {0:.2f}".format(
    np.std(school_effects_samples)))

trace_mean = {
    "avg_effect": trace['avg_effect'].mean(0),
    "avg_stddev": trace['avg_stddev'].mean(0),
    "school_effects_standard": trace['school_effects_standard'].mean(0)
}

# Run the model with the latent variables pinned to their posterior means and
# sample from the resulting distribution of the returned random variable.
posterior = model.get_posterior_fn(states=trace_mean)

with tf.Session() as sess:
    posterior_predictive = sess.run(posterior().distribution.sample(5000))

_plot_school_kde_grid(posterior_predictive, "posterior predictive")

# The mean predicted treatment effects for each of the eight schools.
prediction = posterior_predictive.mean(axis=0)

treatment_effects - prediction

residuals = treatment_effects - posterior_predictive

_plot_school_kde_grid(residuals, "residuals")

# This is identical to
# https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Eight_Schools.ipynb
# with one change:
# original: target_log_prob_fn=target_log_prob_fn
# here:     target_log_prob_fn=model.target_log_prob_fn()
num_results = 5000
num_burnin_steps = 3000

states, kernel_results = tfp.mcmc.sample_chain(
    num_results=num_results,
    num_burnin_steps=num_burnin_steps,
    current_state=[
        tf.zeros([], name='init_avg_effect'),
        tf.zeros([], name='init_avg_stddev'),
        tf.ones([num_schools], name='init_school_effects_standard'),
    ],
    kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=model.target_log_prob_fn(),
        step_size=0.4,
        num_leapfrog_steps=3))

avg_effect, avg_stddev, school_effects_standard = states

with tf.Session() as sess:
    [
        avg_effect_,
        avg_stddev_,
        school_effects_standard_,
        is_accepted_,
    ] = sess.run([
        avg_effect,
        avg_stddev,
        school_effects_standard,
        kernel_results.is_accepted,
    ])

school_effects_samples = (
    avg_effect_[:, np.newaxis] +
    np.exp(avg_stddev_)[:, np.newaxis] * school_effects_standard_)
num_accepted = np.sum(is_accepted_)
print('Acceptance rate: {}'.format(num_accepted / num_results))

_plot_school_effect_chains(school_effects_samples)

num_tumors = np.array([
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 3, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 1, 5, 2, 5, 2, 7, 7, 3, 3, 2, 9, 10, 4, 4, 4, 4, 4, 4,
    4, 10, 4, 4, 4, 5, 11, 12, 5, 5, 6, 5, 6, 6, 6, 6, 16, 15, 15, 9
], dtype=np.float32)

num_rats = np.array([
    20, 20, 20, 20, 20, 20, 20, 19, 19, 19, 19, 18, 18, 17, 20, 20, 20, 20, 19,
    19, 18, 18, 27, 25, 24, 23, 20, 20, 20, 20, 20, 20, 10, 49, 19, 46, 17, 49,
    47, 20, 20, 13, 48, 50, 20, 20, 20, 20, 20, 20, 20, 48, 19, 19, 19, 22, 46,
    49, 20, 20, 23, 19, 22, 20, 20, 20, 52, 46, 47, 24
], dtype=np.float32)

num_trials = num_tumors.shape[0]


class RatsModel(BaseModel):
    """Beta-binomial model; all three constructor arguments become observed."""

    def __init__(self, num_trials, num_rats, num_tumors):
        super(RatsModel, self).__init__()
        self.num_trials = num_trials
        self.num_rats = num_rats
        self.num_tumors = num_tumors

    def call(self):
        mu = ed.Uniform(low=0., high=1., name="mu")
        nu = ed.Uniform(low=0., high=1., name="nu")
        alpha = mu / (nu * nu)
        beta = (1. - mu) / (nu * nu)
        thetas = ed.Beta(
            alpha, beta, sample_shape=self.num_trials, name="thetas")
        # The zeros are only a placeholder value: when the log-density is
        # evaluated, the interceptor replaces `value` with the observed
        # `num_tumors` attribute.
        # NOTE(review): presumably supplied so that constructing the variable
        # does not need to sample from the Binomial -- confirm.
        num_tumors = ed.Binomial(
            total_count=self.num_rats,
            probs=thetas,
            value=tf.zeros(self.num_trials),
            name="num_tumors")
        return num_tumors


rats_model = RatsModel(num_trials, num_rats, num_tumors)

rats_model.observed
rats_model.unobserved

# %%time  -- IPython cell magic; kept as a comment so the file is valid Python.
trace = sample(
    rats_model, step_size=0.007, num_results=20000, num_burnin_steps=30000)

trace

alpha_ = trace['mu'] / (trace['nu']**2)
beta_ = (1. - trace['mu']) / (trace['nu']**2)


def plot_trace(trace, name):
    """Plot a sample trace next to its kernel density estimate.

    Args:
      trace: sequence of samples to plot.
      name: label used in both panel titles.
    """
    fig, axes = plt.subplots(1, 2, sharex='col', sharey='col')
    fig.set_size_inches(14, 5)
    axes[0].plot(trace)
    axes[0].title.set_text("{} Trace".format(name))
    sns.kdeplot(trace, ax=axes[1], shade=False)
    axes[1].title.set_text("{} Posterior Density".format(name))
    fig.tight_layout()
    plt.show()


plot_trace(beta_, "Beta")
plot_trace(alpha_, "Alpha")