#@title ##### Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tour of TensorFlow Probability (TFP) running on the JAX backend.

Install the dependencies before running this script (in a notebook cell,
prefix the command with ``!``)::

    pip3 install -q 'tfp-nightly[jax]' tf-nightly-cpu

The original notebook notes that TF was (at the time) still required, and
that TF's smaller CPU build is sufficient.
"""

# Jax imports.
import jax
import jax.numpy as np  # NOTE: `np` is JAX's NumPy in this file, not NumPy.
from jax import random

# Other imports.
import matplotlib.pyplot as plt
import seaborn as sns

# Importing TFP with the JAX backend. `tensorflow_probability.substrates.jax`
# is the supported location of the JAX substrate; older releases only exposed
# it as `tfp.experimental.substrates.jax`, so fall back to that path when the
# new one is unavailable.
try:
    from tensorflow_probability.substrates import jax as tfp
except ImportError:
    import tensorflow_probability as _tfp_root
    tfp = _tfp_root.experimental.substrates.jax

# `tf2jax` is TFP's TensorFlow-style API implemented on top of JAX.
tf = tfp.tf2jax

# Standard TFP imports.
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

sns.set(style='white')

# --- tf2jax: TensorFlow-style ops executed by JAX ---------------------------
# (In the notebook these were bare expressions shown as cell output; print
# them so the script produces the same information.)
print(tf.ones(5))
print(tf.matmul(tf.ones([1, 2]), tf.ones([2, 4])))
print(tf.ones(5).shape)
# Randomness is stateless: the seed is an explicit JAX PRNG key.
print(tf.random.stateless_uniform([1, 2], seed=random.PRNGKey(0)))
print(tf.compat.v1.placeholder_with_default(tf.ones(5), (5,)))

# --- Bijectors ---------------------------------------------------------------
# Calling one bijector on another composes them, applying the inner one
# first: forward(x) = Shift(1)(Scale(3)(x)) = 3 * x + 1.
bij = tfb.Shift(1.)(tfb.Scale(3.))
print(bij.forward(np.ones(5)))
print(bij.inverse(np.ones(5)))

# Maps a length-3 vector to a 2x2 lower-triangular matrix; the Exp
# diag_bijector keeps the diagonal positive.
b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
print(b.forward(x=[0., 0., 0.]))
print(b.inverse(y=[[1., 0], [.5, 2]]))

# Chain applies its bijectors right to left: forward(x) = Exp(Softplus(x)).
b = tfb.Chain([tfb.Exp(), tfb.Softplus()])
# or:
# b = tfb.Exp()(tfb.Softplus())
print(b.forward(-np.ones(5)))

# A scalar standard Normal; sampled in the statements that follow.
dist = tfd.Normal(loc=0., scale=1.)
# On the JAX backend, sampling requires an explicit PRNG key as the seed.
print(dist.sample(seed=random.PRNGKey(0)))

# A batch of 5 independent scalar Normals. log_prob is evaluated per batch
# member, so its shape is sample_shape + batch_shape = (10, 2, 5).
dist = tfd.Normal(np.zeros(5), np.ones(5))
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)

# Independent(..., 1) reinterprets the rightmost batch dimension as an event
# dimension, so log_prob sums over it and its shape is (10, 2).
dist = tfd.Independent(tfd.Normal(np.zeros(5), np.ones(5)), 1)
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)

# Push a 5-dimensional diagonal Gaussian through Exp.
dist = tfd.TransformedDistribution(
    tfd.MultivariateNormalDiag(tf.zeros(5), tf.ones(5)),
    tfb.Exp())
# or:
# dist = tfb.Exp()(tfd.MultivariateNormalDiag(tf.zeros(5), tf.ones(5)))
s = dist.sample(sample_shape=2, seed=random.PRNGKey(0))
print(s)
print(dist.log_prob(s).shape)

# --- Gaussian process regression --------------------------------------------
# Separate keys for the input locations, the observation noise and the GP
# samples, so the three random draws are independent.
k1, k2, k3 = random.split(random.PRNGKey(0), 3)

num_observations = 50
observation_noise_variance = 0.01


def f(x):
    """Latent function to recover; uses only feature 0 of `x` (shape [..., 1])."""
    return np.sin(10 * x[..., 0]) * np.exp(-x[..., 0] ** 2)


# Input locations drawn uniformly in [-1, 1), shaped [num_observations, 1].
observation_index_points = tf.random.stateless_uniform(
    [num_observations], minval=-1., maxval=1., seed=k1)[..., np.newaxis]
# Draw one noise value per observation. Calling `.sample(seed=k2)` without a
# sample shape would return a single scalar and shift every observation by
# the same constant instead of adding i.i.d. noise.
noise = tfd.Normal(
    loc=0., scale=np.sqrt(observation_noise_variance)).sample(
        num_observations, seed=k2)
observations = f(observation_index_points) + noise

index_points = np.linspace(-1., 1., 100)[..., np.newaxis]
kernel = tfpk.ExponentiatedQuadratic(length_scale=0.1)
gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=index_points,
    observation_index_points=observation_index_points,
    observations=observations,
    observation_noise_variance=observation_noise_variance)

# Ten draws from the GP regression model (conditioned on the observations)
# at the 100 index points; plot each as a curve.
samples = gprm.sample(10, seed=k3)
for sample in samples:
    plt.plot(index_points, sample)
plt.show()