#@title ##### Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This notebook demonstrates putting together a sample_model function that implements Stan's windowed adaptation algorithm.
The sampling done by sample_model is comparable in quality to what PyMC3 or Stan produce (both of those use NUTS rather than plain HMC, and I did not do any careful tuning of the integration length).
This incidentally shows how to run inference with hundreds of chains.
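For orientation, the warmup schedule implemented below mirrors Stan's: a short "fast" window that adapts only the step size, a sequence of doubling "slow" windows that additionally estimate a diagonal mass matrix, and a final fast window before the main draws. Here is a minimal sketch of the schedule that sample_model uses further down (nothing beyond the window sizes hard-coded there is assumed):
# Illustrative only: the warmup schedule used by sample_model below.
fast_window_size = 75
slow_window_size = 25
num_slow_windows = 4  # doubling slow windows: 25, 50, 100, 200
schedule = ([('fast', fast_window_size)] +
            [('slow', slow_window_size * 2**k) for k in range(num_slow_windows)] +
            [('fast', fast_window_size)])
print(schedule)  # [('fast', 75), ('slow', 25), ('slow', 50), ('slow', 100), ('slow', 200), ('fast', 75)]
print(sum(n for _, n in schedule))  # 525 warmup draws before the main draws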
!pip install -Uq tf-nightly tfp-nightly arviz
# @title Imports
from collections import namedtuple
import warnings
import time
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.mcmc.internal import util as mcmc_util
from tensorflow_probability.python.internal import unnest
tfd = tfp.distributions
tfb = tfp.bijectors
dtype = tf.float64
az.style.use('arviz-darkgrid')
# @title Helper functions
def get_flat_unconstraining_bijector(jd_model, initial_position, batch_rank):
unconstraining_bij = jd_model.experimental_default_event_space_bijector()
# First take the constrained space, and unconstrain it:
bijectors = [unconstraining_bij]
# In case of a JDNamed, pack the state_parts into a dictionary
if isinstance(initial_position, dict):
bijectors.append(tfb.Invert(tfb.Restructure(list(initial_position.keys()))))
# Easiest to do a transform here instead of casing for dictionaries again.
list_position = tfb.Chain(bijectors[-1::-1])(initial_position)
# This bijector takes flat tensors and reshapes them to their proper shapes
bijectors.append(tfb.JointMap(
[tfb.Reshape(ps.shape(x)[batch_rank:]) for x in list_position]))
# This splits a single tensor into state_parts
bijectors.append(tfb.Split(
num_or_size_splits=np.asarray(
[ps.reduce_prod(ps.shape(x)[batch_rank:]) for x in list_position]
).flatten(),
axis=-1))
# Notice that we want to apply the bijectors in the opposite order
return tfb.Invert(tfb.Chain(bijectors))
MCMCState = namedtuple("MCMCState", ['pinned_model',
'initial_position',
'initial_transformed_position',
'bijector',
'batch_rank',
'batch_shape',
'target_log_prob_fn'])
def setup_mcmc(model, n_chains, **pins):
pinned_model = model.experimental_pin(**pins)
# TODO: add argument to `sample_unpinned` to get chains
initial_position = pinned_model.sample_unpinned(n_chains)
target_log_prob_val = pinned_model.unnormalized_log_prob(initial_position)
batch_rank = ps.rank(target_log_prob_val)
batch_shape = ps.shape(target_log_prob_val)
bijector = get_flat_unconstraining_bijector(pinned_model,
initial_position,
batch_rank)
initial_transformed_position = bijector.forward(initial_position)
# Jitter init
initial_transformed_position = tfd.Uniform(-2., 2.).sample(
initial_transformed_position.shape)
initial_position = bijector.inverse(initial_transformed_position)
target_log_prob_fn = lambda x: pinned_model.unnormalized_log_prob(bijector.inverse(x)) + bijector.inverse_log_det_jacobian(x, event_ndims=1)
# target_log_prob_fn = pinned_model.unnormalized_log_prob
return MCMCState(pinned_model=pinned_model,
initial_position=initial_position,
initial_transformed_position=initial_transformed_position,
bijector=bijector,
batch_rank=batch_rank,
batch_shape=batch_shape,
target_log_prob_fn=target_log_prob_fn)
def make_base_kernel(mcmc_state, *,
step_size,
num_leapfrog_steps,
momentum_distribution):
return tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
mcmc_state.target_log_prob_fn, step_size=step_size,
num_leapfrog_steps=num_leapfrog_steps,
momentum_distribution=momentum_distribution)
def make_fast_adapt_kernel(mcmc_state, *,
initial_step_size,
num_leapfrog_steps,
num_adaptation_steps,
target_accept_prob=0.75,
momentum_distribution=None):
return tfp.mcmc.SimpleStepSizeAdaptation(
make_base_kernel(mcmc_state,
step_size=initial_step_size,
num_leapfrog_steps=num_leapfrog_steps,
momentum_distribution=momentum_distribution),
num_adaptation_steps=num_adaptation_steps,
target_accept_prob=target_accept_prob)
def make_slow_adapt_kernel(mcmc_state, *,
initial_running_variance,
initial_step_size,
num_leapfrog_steps,
num_adaptation_steps,
target_accept_prob=0.75):
return tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
make_fast_adapt_kernel(mcmc_state,
initial_step_size=initial_step_size,
num_leapfrog_steps=num_leapfrog_steps,
num_adaptation_steps=num_adaptation_steps,
target_accept_prob=target_accept_prob),
initial_running_variance=initial_running_variance)
# @title Kernels
@tf.function(jit_compile=True)
def fast_window(mcmc_state, num_leapfrog_steps, *,
num_draws,
initial_position,
initial_step_size,
momentum_distribution=None):
kernel = make_fast_adapt_kernel(
mcmc_state,
initial_step_size=initial_step_size,
num_leapfrog_steps=num_leapfrog_steps,
num_adaptation_steps=num_draws,
momentum_distribution=momentum_distribution)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
draws, _, fkr = tfp.mcmc.sample_chain(num_draws, initial_position,
kernel=kernel,
return_final_kernel_results=True,
trace_fn=None)
weighted_running_variance = tfp.experimental.stats.RunningVariance.from_stats(
num_samples=num_draws / 2,
mean=tf.reduce_mean(draws[-num_draws // 2:],
axis=ps.range(mcmc_state.batch_rank+1)),
variance=tf.math.reduce_variance(draws[-num_draws // 2:],
axis=ps.range(mcmc_state.batch_rank+1)))
step_size = unnest.get_innermost(fkr, 'step_size')
return draws[-1], step_size, weighted_running_variance
@tf.function(jit_compile=True)
def slow_window(mcmc_state, num_leapfrog_steps, *,
num_draws,
initial_position,
initial_running_variance,
initial_step_size):
kernel = make_slow_adapt_kernel(
mcmc_state,
initial_running_variance=initial_running_variance,
initial_step_size=initial_step_size,
num_leapfrog_steps=num_leapfrog_steps,
num_adaptation_steps=num_draws)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
draws, _, fkr = tfp.mcmc.sample_chain(num_draws, initial_position,
kernel=kernel,
return_final_kernel_results=True,
trace_fn=None)
weighted_running_variance = tfp.experimental.stats.RunningVariance.from_stats(
num_samples=num_draws / 2,
mean=tf.reduce_mean(draws[-num_draws // 2:],
axis=ps.range(mcmc_state.batch_rank+1)),
variance=tf.math.reduce_variance(draws[-num_draws // 2:],
axis=ps.range(mcmc_state.batch_rank+1)))
step_size = unnest.get_innermost(fkr, 'step_size')
momentum_distribution = unnest.get_outermost(fkr, 'momentum_distribution')
return draws[-1], step_size, weighted_running_variance, momentum_distribution
@tf.function(jit_compile=True)
def do_sampling(mcmc_state, num_leapfrog_steps, *,
num_draws,
initial_position,
step_size,
momentum_distribution,
trace_fn=None):
kernel = make_base_kernel(mcmc_state,
step_size=step_size,
num_leapfrog_steps=num_leapfrog_steps,
momentum_distribution=momentum_distribution)
return tfp.mcmc.sample_chain(num_draws, initial_position,
kernel=kernel,
trace_fn=trace_fn)
# @title Implementation of `sample_model`
def sample_model(n_draws, joint_dist, *,
n_chains=16,
num_leapfrog_steps=8,
trace_fn=None,
**pins):
mcmc_state = setup_mcmc(joint_dist, n_chains=n_chains, **pins)
fast_window_size = 75
slow_window_size = 25
print(f"Fast window {fast_window_size}")
position, step_size, running_variance = fast_window(
mcmc_state, num_leapfrog_steps,
num_draws=fast_window_size,
initial_position=mcmc_state.initial_transformed_position,
initial_step_size=tf.ones(mcmc_state.batch_shape)[..., tf.newaxis])
for idx in range(4):
window_size = slow_window_size * (2 ** idx)
print(f"Slow window {window_size}")
position, step_size, running_variance, momentum_distribution = slow_window(
mcmc_state, num_leapfrog_steps,
num_draws=window_size,
initial_position=position,
initial_running_variance=running_variance,
initial_step_size=step_size)
print(f"Fast window {fast_window_size}")
position, step_size, running_variance = fast_window(
mcmc_state, num_leapfrog_steps,
num_draws=fast_window_size,
initial_position=position,
initial_step_size=step_size,
momentum_distribution=momentum_distribution)
print(f"main draws {n_draws}")
draws = do_sampling(mcmc_state, num_leapfrog_steps,
num_draws=n_draws,
initial_position=position,
step_size=step_size,
momentum_distribution=momentum_distribution,
trace_fn=trace_fn)
return mcmc_state.bijector.inverse(draws), mcmc_state
# @title Write down model
num_schools = 8 # number of schools
treatment_effects = tf.constant([28., 8, -3, 7, -1, 1, 18, 12]) # treatment effects
treatment_stddevs = tf.constant([15., 10, 16, 11, 9, 11, 10, 18]) # treatment SE
eight_schools = tfd.JointDistributionSequential([
tfd.Normal(0., 5., name='avg_effect'),
tfd.HalfNormal(5., name='avg_stddev'),
tfd.Independent(tfd.Normal(loc=tf.zeros(num_schools),
scale=tf.ones(num_schools)),
reinterpreted_batch_ndims=1,
name='school_effects_std'),
lambda school_effects_std, avg_stddev, avg_effect: tfd.Independent(
tfd.Normal(loc=(avg_effect[..., tf.newaxis] +
avg_stddev[..., tf.newaxis] * school_effects_std),
scale=treatment_stddevs),
reinterpreted_batch_ndims=1,
name='treatment_effects')
])
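As a quick sanity check (not part of the original notebook): with treatment_effects pinned, the free variables are avg_effect (scalar), avg_stddev (scalar), and school_effects_std (shape [8]), so setup_mcmc should produce a 10-dimensional flat unconstrained position per chain.
# Optional check of the flattening machinery on this model (illustrative only).
check_state = setup_mcmc(eight_schools, n_chains=4,
                         treatment_effects=treatment_effects)
print(check_state.initial_transformed_position.shape)  # expected: (4, 10)
print(check_state.batch_shape)                         # expected: [4]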
%time draws, mcmc_state = sample_model(500, eight_schools, n_chains=64, treatment_effects=treatment_effects)
Fast window 75
Slow window 25
Slow window 50
Slow window 100
Slow window 200
Fast window 75
main draws 500
CPU times: user 43.7 s, sys: 557 ms, total: 44.3 s
Wall time: 44.4 s
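# ArviZ expects arrays indexed as (chain, draw, ...); the sampler returns (draw, chain, ...), so swap the first two axes.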
draws_dict = {k[0]: tf.einsum('ij...->ji...', v) for v, k in zip(draws, eight_schools.resolve_graph())}
idata = az.from_dict(posterior=draws_dict)
idata
<xarray.Dataset>
Dimensions: (chain: 64, draw: 500, school_effects_std_dim_0: 8)
Coordinates:
* chain (chain) int64 0 1 2 3 4 5 6 ... 58 59 60 61 62 63
* draw (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499
* school_effects_std_dim_0 (school_effects_std_dim_0) int64 0 1 2 3 4 5 6 7
Data variables:
avg_effect (chain, draw) float32 7.079 2.308 ... -0.4771
avg_stddev (chain, draw) float32 1.502 3.871 ... 0.5463 4.895
school_effects_std (chain, draw, school_effects_std_dim_0) float32 ...
Attributes:
created_at: 2021-01-15T02:18:28.691877
arviz_version: 0.10.0
# Only plotting the first 4 chains
az.plot_trace(idata.sel(chain=[0, 1, 2, 3]), combined=False, compact=True);
# @title TFP Results
az.summary(idata)
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_mean | ess_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| avg_effect | 4.445 | 3.227 | -1.634 | 10.531 | 0.020 | 0.018 | 25183.0 | 16096.0 | 25111.0 | 16020.0 | 1.01 |
| avg_stddev | 3.330 | 2.519 | 0.000 | 7.847 | 0.018 | 0.014 | 20710.0 | 15395.0 | 26998.0 | 16036.0 | 1.00 |
| school_effects_std[0] | 0.314 | 0.985 | -1.540 | 2.144 | 0.007 | 0.009 | 18649.0 | 5716.0 | 18583.0 | 11180.0 | 1.01 |
| school_effects_std[1] | 0.092 | 0.955 | -1.751 | 1.851 | 0.006 | 0.009 | 25276.0 | 6109.0 | 25338.0 | 13053.0 | 1.01 |
| school_effects_std[2] | -0.088 | 0.952 | -1.822 | 1.762 | 0.005 | 0.010 | 30306.0 | 4443.0 | 30298.0 | 14300.0 | 1.01 |
| school_effects_std[3] | 0.056 | 0.941 | -1.677 | 1.878 | 0.005 | 0.010 | 36250.0 | 4021.0 | 36117.0 | 11667.0 | 1.01 |
| school_effects_std[4] | -0.169 | 0.930 | -1.940 | 1.559 | 0.006 | 0.008 | 23681.0 | 6330.0 | 23750.0 | 13245.0 | 1.01 |
| school_effects_std[5] | -0.076 | 0.952 | -1.847 | 1.708 | 0.006 | 0.009 | 29886.0 | 5374.0 | 29891.0 | 13012.0 | 1.01 |
| school_effects_std[6] | 0.356 | 0.964 | -1.462 | 2.164 | 0.006 | 0.009 | 23934.0 | 6346.0 | 23924.0 | 13831.0 | 1.01 |
| school_effects_std[7] | 0.065 | 0.975 | -1.780 | 1.908 | 0.007 | 0.013 | 17462.0 | 2857.0 | 17581.0 | 4372.0 | 1.01 |
# @title PyMC3 NUTS results
data = az.load_arviz_data("non_centered_eight")
az.summary(data, var_names=['mu', 'tau', 'theta_t'])
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_mean | ess_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| mu | 4.494 | 3.286 | -2.187 | 10.201 | 0.068 | 0.053 | 2344.0 | 1896.0 | 2354.0 | 1401.0 | 1.00 |
| tau | 3.447 | 2.915 | 0.003 | 8.603 | 0.074 | 0.052 | 1556.0 | 1556.0 | 1268.0 | 900.0 | 1.00 |
| theta_t[0] | 0.338 | 0.997 | -1.587 | 2.112 | 0.021 | 0.023 | 2216.0 | 955.0 | 2215.0 | 1450.0 | 1.00 |
| theta_t[1] | 0.108 | 0.923 | -1.665 | 1.734 | 0.017 | 0.022 | 3109.0 | 900.0 | 3159.0 | 1514.0 | 1.00 |
| theta_t[2] | -0.087 | 0.948 | -1.799 | 1.661 | 0.018 | 0.023 | 2911.0 | 871.0 | 2926.0 | 1530.0 | 1.00 |
| theta_t[3] | 0.092 | 0.976 | -1.762 | 1.908 | 0.019 | 0.025 | 2524.0 | 783.0 | 2515.0 | 1237.0 | 1.00 |
| theta_t[4] | -0.194 | 0.940 | -1.900 | 1.512 | 0.020 | 0.022 | 2289.0 | 931.0 | 2313.0 | 1487.0 | 1.00 |
| theta_t[5] | -0.044 | 0.954 | -1.902 | 1.610 | 0.019 | 0.023 | 2560.0 | 839.0 | 2553.0 | 1464.0 | 1.00 |
| theta_t[6] | 0.312 | 0.956 | -1.609 | 2.070 | 0.018 | 0.021 | 2668.0 | 1004.0 | 2678.0 | 1207.0 | 1.00 |
| theta_t[7] | 0.061 | 0.935 | -1.620 | 1.794 | 0.019 | 0.021 | 2525.0 | 966.0 | 2522.0 | 1589.0 | 1.01 |
az.plot_trace(data, var_names=['mu', 'tau', 'theta_t'], compact=True);
See this GitHub issue, which references this blog post, for the setup here and what the situation used to look like.
np.random.seed(0)
ndims = 5
ndata = 100
X = np.random.randn(ndata, ndims).astype(np.float32)
w_ = np.random.randn(ndims).astype(np.float32) # hidden
noise_ = 0.1 * np.random.randn(ndata).astype(np.float32) # hidden
y_obs = X.dot(w_) + noise_
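Since the observation noise is small relative to the standard-normal prior on w, the posterior mean should land close to the ordinary least-squares solution; a quick reference point (not in the original notebook):
# Illustrative: OLS estimate of w as a reference for the posterior summaries below.
w_ols, *_ = np.linalg.lstsq(X, y_obs, rcond=None)
print(np.round(w_ols, 3))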
# @title Write down model
linear_model = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(tf.zeros(ndims), tf.ones(ndims), name='w'),
lambda w: tfd.Normal(tf.linalg.matvec(X, w), 0.1, name='y')
])
%time draws, mcmc_state = sample_model(500, linear_model, n_chains=256, y=y_obs)
WARNING:tensorflow:Note that RandomStandardNormal inside pfor op may not give same output as inside a sequential loop.
WARNING:tensorflow:Note that RandomStandardNormal inside pfor op may not give same output as inside a sequential loop.
Fast window 75
Slow window 25
WARNING:tensorflow:5 out of the last 5 calls to <function slow_window at 0x7f9456031ea0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
Slow window 50
WARNING:tensorflow:6 out of the last 6 calls to <function slow_window at 0x7f9456031ea0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
Slow window 100
Slow window 200
Fast window 75
main draws 500
CPU times: user 42.9 s, sys: 583 ms, total: 43.4 s
Wall time: 43.6 s
draws_dict = {k[0]: tf.einsum('ij...->ji...', v) for v, k in zip(draws, linear_model.resolve_graph())}
idata = az.from_dict(posterior=draws_dict)
idata
<xarray.Dataset>
Dimensions: (chain: 256, draw: 500, w_dim_0: 5)
Coordinates:
* chain (chain) int64 0 1 2 3 4 5 6 7 8 ... 248 249 250 251 252 253 254 255
* draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
* w_dim_0 (w_dim_0) int64 0 1 2 3 4
Data variables:
w (chain, draw, w_dim_0) float32 0.3728 -0.03278 ... -0.2354 -0.3461
Attributes:
created_at: 2021-01-15T02:19:18.402213
arviz_version: 0.10.0
# Only plotting the first 4 chains
axes = az.plot_trace(idata.sel(chain=[0, 1, 2, 3]))
for idx, (true_val, ax) in enumerate(zip(w_, axes)):
ax[0].axvline(true_val, color=f'C{idx}', linestyle='dashed');
az.summary(idata)
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_mean | ess_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| w[0] | 0.373 | 0.010 | 0.354 | 0.391 | 0.0 | 0.0 | 3424.0 | 3424.0 | 3429.0 | 13667.0 | 1.05 |
| w[1] | -0.036 | 0.011 | -0.055 | -0.015 | 0.0 | 0.0 | 15409.0 | 15409.0 | 15426.0 | 36882.0 | 1.01 |
| w[2] | 1.095 | 0.010 | 1.076 | 1.114 | 0.0 | 0.0 | 14131.0 | 14131.0 | 14139.0 | 26735.0 | 1.01 |
| w[3] | -0.234 | 0.010 | -0.252 | -0.216 | 0.0 | 0.0 | 5990.0 | 5988.0 | 5991.0 | 19865.0 | 1.03 |
| w[4] | -0.339 | 0.010 | -0.359 | -0.320 | 0.0 | 0.0 | 5707.0 | 5707.0 | 5708.0 | 23421.0 | 1.03 |