Licensed under the Apache License, Version 2.0 (the "License");
#@title ##### Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Run in Google Colab
|
View source on GitHub
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
from discussion import nn as tfp_nn
# Uncomment to enable XLA JIT compilation globally.
# tf.config.optimizer.set_jit(True)

# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all
# up front. The memory-growth setting has to be the same for every visible
# GPU, so it is applied to each of them; with no GPU the loop is a no-op.
try:
  physical_devices = tf.config.list_physical_devices('GPU')
  for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)
except (ValueError, RuntimeError):
  # ValueError: invalid device.
  # RuntimeError: devices cannot be modified once the runtime is initialized
  # (e.g. when this cell is re-run). Both are safe to ignore here; any other
  # error is a real problem and is allowed to propagate.
  pass
# Short aliases for the TFP submodules.
tfb = tfp.bijectors
tfd = tfp.distributions

# MNIST train/test splits as (image, label) pairs (`as_supervised=True`),
# together with the dataset metadata used below for split sizes.
(train_dataset, eval_dataset), datasets_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)
def _preprocess(image, label):
  """Cast an (image, label) pair to model-ready dtypes.

  Args:
    image: image tensor. It is divided by 255, so it is presumably uint8
      pixels in [0, 255] (TODO confirm), giving values in [0, 1].
    label: class-label tensor.

  Returns:
    `(image, label)` with `image` as float32 (divided by 255) and `label`
    cast to int32.
  """
  scaled_image = tf.cast(image, dtype=tf.float32) / 255.
  int_label = tf.cast(label, dtype=tf.int32)
  return scaled_image, int_label
train_size = datasets_info.splits['train'].num_examples
batch_size = 32

# Training pipeline: `_preprocess` applied per element, batches of
# `batch_size`, and a shuffle buffer of about 1/7 of the training set.
train_dataset = tfp_nn.util.tune_dataset(
    train_dataset,
    preprocess_fn=_preprocess,
    batch_size=batch_size,
    shuffle_size=int(train_size / 7),
)
# Evaluation pipeline: `_preprocess` applied, `repeat_count=None`. It is
# batched separately at each use site below.
eval_dataset = tfp_nn.util.tune_dataset(
    eval_dataset,
    preprocess_fn=_preprocess,
    repeat_count=None,
)

# Preview the first ten evaluation images with their labels.
x, y = next(iter(eval_dataset.batch(10)))
tfp_nn.util.display_imgs(x, y);
# 2x2, stride-2 max-pooling layer. It has no tf.Variables, so a single
# instance is reused after both convolutions below.
max_pool = tf.keras.layers.MaxPooling2D(
    pool_size=(2, 2),
    strides=(2, 2),
    padding='SAME',
    data_format='channels_last',
)

# Bayesian CNN: two `ConvolutionVariationalFlipout` layers and an
# `AffineVariationalReparameterizationLocal` layer. The final 10 outputs are
# wrapped in an int32 Categorical over the digit classes. Each variational
# layer's penalty is weighted by 1 / train_size, presumably to put it on the
# same per-example scale as the mean NLL computed in `loss`.
bnn = tfp_nn.Sequential(
    [
        tfp_nn.ConvolutionVariationalFlipout(
            input_size=1,
            output_size=8,
            filter_shape=5,
            padding='SAME',
            penalty_weight=1. / train_size,
            name='conv1'),
        tf.nn.leaky_relu,
        max_pool,  # [28, 28, 8] -> [14, 14, 8]
        tfp_nn.ConvolutionVariationalFlipout(
            input_size=8,
            output_size=16,
            filter_shape=5,
            padding='SAME',
            penalty_weight=1. / train_size,
            name='conv2'),
        tf.nn.leaky_relu,
        max_pool,  # [14, 14, 16] -> [7, 7, 16]
        tfp_nn.util.flatten_rightmost,
        tfp_nn.AffineVariationalReparameterizationLocal(
            input_size=7 * 7 * 16,
            output_size=10,
            penalty_weight=1. / train_size,
            name='affine1'),
        lambda logits: tfd.Categorical(logits=logits, dtype=tf.int32),
    ],
    name='BNN',
)
print(bnn.summary())
=== BNN ==================================================
SIZE SHAPE TRAIN NAME
200 [5, 5, 1, 8] True posterior_kernel_loc:0
200 [5, 5, 1, 8] True posterior_kernel_scale:0
8 [8] True posterior_bias_loc:0
8 [8] True posterior_bias_scale:0
3200 [5, 5, 8, 16] True posterior_kernel_loc:0
3200 [5, 5, 8, 16] True posterior_kernel_scale:0
16 [16] True posterior_bias_loc:0
16 [16] True posterior_bias_scale:0
7840 [784, 10] True posterior_kernel_loc:0
7840 [784, 10] True posterior_kernel_scale:0
10 [10] True posterior_bias_loc:0
10 [10] True posterior_bias_scale:0
trainable size: 22548 / 0.086 MiB / {float32: 22548}
# Python iterators over the input pipelines. `loss` and `eval` each pull one
# batch per call; evaluation uses 2000-example batches, cycled indefinitely.
train_iter = iter(train_dataset)
eval_iter = iter(
    eval_dataset
    .batch(2000)
    .repeat()
)
def loss():
  """Draw one training batch and return the BNN's training objective.

  Returns:
    `(total, (nll, kl))`, where:
      - `nll` is the negative mean log-probability of the batch labels under
        `bnn`'s output distribution;
      - `kl` is `bnn.extra_loss` (presumably the weighted variational
        penalty; it is logged as "kl");
      - `total = nll + kl`.
  """
  images, labels = next(train_iter)
  predictive = bnn(images)
  nll = -tf.reduce_mean(predictive.log_prob(labels), axis=-1)
  # Read only after the forward pass, preserving the original ordering.
  kl = bnn.extra_loss
  total = nll + kl
  return total, (nll, kl)
# Adam optimizer plus a fit op built from `loss` over the BNN's trainable
# variables. Each `fit()` call in the training loop returns
# (loss, aux, gradient summary). The summary here is the per-variable
# gradient norms.
opt = tf.optimizers.Adam(learning_rate=1e-2)
fit = tfp_nn.util.make_fit_op(
    loss,
    opt,
    bnn.trainable_variables,
    grad_summary_fn=lambda grads: tf.nest.map_structure(tf.norm, grads),
)
@tf.function(autograph=False)
def eval():  # NOTE: shadows the builtin `eval`; name kept for its callers.
  """Score the BNN on the next batch from `eval_iter`, under an XLA scope.

  Returns:
    `(loss, acc, nll, kl)`:
      - `nll` is the negative mean label log-probability;
      - `kl` is `bnn.extra_loss`;
      - `loss = nll + kl`;
      - `acc` is the fraction of labels equal to the output distribution's
        mode.
  """
  with tf.xla.experimental.jit_scope(compile_ops=True):
    images, labels = next(eval_iter)
    predictive = bnn(images)
    nll = -tf.reduce_mean(predictive.log_prob(labels))
    kl = bnn.extra_loss
    total = nll + kl
    hits = tf.cast(tf.equal(labels, predictive.mode()), tf.float32)
    accuracy = tf.reduce_mean(hits, axis=-1)
    return total, accuracy, nll, kl
num_train_steps = 20e3  # @param { isTemplate: true}
num_train_steps = int(num_train_steps)  # Enforce correct type when overridden.

# Train the BNN. A progress line is printed every 100 steps and on the final
# step. `dur_sec` / `dur_num` accumulate the wall-clock time of `fit()` calls
# since the previous report, so the printed ms/it is a windowed average.
dur_sec = dur_num = 0
for i in range(num_train_steps):
  start = time.time()
  trn_loss, (trn_nll, trn_kl), g = fit()
  stop = time.time()
  dur_sec += stop - start
  dur_num += 1

  if i % 100 != 0 and i != num_train_steps - 1:
    continue

  tst_loss, tst_acc, tst_nll, tst_kl = eval()
  # Pair each column's format spec with its value, then split into two tuples.
  f, x = zip(*[
      ('it:{:5}', opt.iterations),
      ('ms/it:{:6.4f}', dur_sec / max(1., dur_num) * 1000.),
      ('tst_acc:{:6.4f}', tst_acc),
      ('trn_loss:{:6.4f}', trn_loss),
      ('tst_loss:{:6.4f}', tst_loss),
      ('tst_nll:{:6.4f}', tst_nll),
      ('tst_kl:{:6.4f}', tst_kl),
      ('sum_norm_grad:{:6.4f}', sum(g)),
  ])
  # Values with a `.numpy()` method are converted; anything else passes
  # through unchanged.
  print(' '.join(f).format(
      *[getattr(value, 'numpy', lambda: value)() for value in x]))
  sys.stdout.flush()
  dur_sec = dur_num = 0

  # if i % 1000 == 0 or i == num_train_steps - 1:
  #   bnn.save('/tmp/bnn.npz')
it: 1 ms/it:15928.3569 tst_acc:0.0960 trn_loss:1867.5190 tst_loss:1572.2639 tst_nll:1572.1337 tst_kl:0.1303 sum_norm_grad:1348.1011 it: 101 ms/it:3.9539 tst_acc:0.1085 trn_loss:141.2893 tst_loss:110.9031 tst_nll:110.7626 tst_kl:0.1405 sum_norm_grad:164.9449 it: 201 ms/it:3.9313 tst_acc:0.1215 trn_loss:77.2260 tst_loss:60.6951 tst_nll:60.5451 tst_kl:0.1500 sum_norm_grad:142.0284 it: 301 ms/it:3.8703 tst_acc:0.1155 trn_loss:51.5968 tst_loss:83.7521 tst_nll:83.5882 tst_kl:0.1640 sum_norm_grad:67.0755 it: 401 ms/it:3.8302 tst_acc:0.1625 trn_loss:47.8995 tst_loss:32.7923 tst_nll:32.6151 tst_kl:0.1771 sum_norm_grad:145.9635 it: 501 ms/it:3.7184 tst_acc:0.2150 trn_loss:41.3301 tst_loss:39.1208 tst_nll:38.9311 tst_kl:0.1897 sum_norm_grad:57.5402 it: 601 ms/it:4.1285 tst_acc:0.2480 trn_loss:23.3938 tst_loss:19.6674 tst_nll:19.4679 tst_kl:0.1995 sum_norm_grad:40.7257 it: 701 ms/it:3.8126 tst_acc:0.2775 trn_loss:19.2859 tst_loss:18.2370 tst_nll:18.0254 tst_kl:0.2116 sum_norm_grad:54.5467 it: 801 ms/it:3.7836 tst_acc:0.3325 trn_loss:8.0900 tst_loss:13.7760 tst_nll:13.5559 tst_kl:0.2201 sum_norm_grad:33.2975 it: 901 ms/it:3.8657 tst_acc:0.4215 trn_loss:7.8182 tst_loss:10.4693 tst_nll:10.2390 tst_kl:0.2303 sum_norm_grad:19.9062 it: 1001 ms/it:3.8200 tst_acc:0.3955 trn_loss:15.5580 tst_loss:9.7859 tst_nll:9.5479 tst_kl:0.2381 sum_norm_grad:41.5614 it: 1101 ms/it:3.9202 tst_acc:0.5750 trn_loss:7.6126 tst_loss:6.6138 tst_nll:6.3675 tst_kl:0.2463 sum_norm_grad:29.4877 it: 1201 ms/it:3.8557 tst_acc:0.4725 trn_loss:8.7019 tst_loss:7.1943 tst_nll:6.9418 tst_kl:0.2525 sum_norm_grad:45.5732 it: 1301 ms/it:3.8773 tst_acc:0.5930 trn_loss:8.4221 tst_loss:5.7505 tst_nll:5.4909 tst_kl:0.2595 sum_norm_grad:23.3249 it: 1401 ms/it:3.8556 tst_acc:0.6755 trn_loss:6.1047 tst_loss:4.1688 tst_nll:3.9022 tst_kl:0.2665 sum_norm_grad:20.6842 it: 1501 ms/it:3.8319 tst_acc:0.6690 trn_loss:2.4749 tst_loss:3.8920 tst_nll:3.6198 tst_kl:0.2722 sum_norm_grad:11.9391 it: 1601 ms/it:3.8669 tst_acc:0.6750 
trn_loss:3.2286 tst_loss:3.8155 tst_nll:3.5400 tst_kl:0.2755 sum_norm_grad:14.1307 it: 1701 ms/it:3.9143 tst_acc:0.7065 trn_loss:1.5178 tst_loss:3.1960 tst_nll:2.9166 tst_kl:0.2794 sum_norm_grad:7.7101 it: 1801 ms/it:4.0725 tst_acc:0.7675 trn_loss:2.8735 tst_loss:3.4984 tst_nll:3.2174 tst_kl:0.2809 sum_norm_grad:15.7313 it: 1901 ms/it:4.1345 tst_acc:0.7825 trn_loss:2.0038 tst_loss:2.5639 tst_nll:2.2759 tst_kl:0.2879 sum_norm_grad:9.2704 it: 2001 ms/it:4.0296 tst_acc:0.7815 trn_loss:1.4074 tst_loss:2.5748 tst_nll:2.2832 tst_kl:0.2916 sum_norm_grad:9.1460 it: 2101 ms/it:3.8173 tst_acc:0.8030 trn_loss:2.5845 tst_loss:2.1998 tst_nll:1.9072 tst_kl:0.2926 sum_norm_grad:8.6494 it: 2201 ms/it:3.9169 tst_acc:0.7925 trn_loss:1.4358 tst_loss:2.0969 tst_nll:1.8011 tst_kl:0.2959 sum_norm_grad:8.2437 it: 2301 ms/it:4.0515 tst_acc:0.7965 trn_loss:1.1560 tst_loss:1.8291 tst_nll:1.5286 tst_kl:0.3005 sum_norm_grad:7.8903 it: 2401 ms/it:3.8685 tst_acc:0.8605 trn_loss:3.1983 tst_loss:1.5584 tst_nll:1.2568 tst_kl:0.3016 sum_norm_grad:18.2812 it: 2501 ms/it:4.0031 tst_acc:0.8610 trn_loss:2.8698 tst_loss:1.4478 tst_nll:1.1429 tst_kl:0.3049 sum_norm_grad:12.8738 it: 2601 ms/it:3.8913 tst_acc:0.8395 trn_loss:1.6629 tst_loss:1.5258 tst_nll:1.2204 tst_kl:0.3054 sum_norm_grad:7.3639 it: 2701 ms/it:3.8980 tst_acc:0.8655 trn_loss:1.0928 tst_loss:1.4976 tst_nll:1.1868 tst_kl:0.3108 sum_norm_grad:6.4518 it: 2801 ms/it:3.9551 tst_acc:0.8525 trn_loss:1.4046 tst_loss:1.5539 tst_nll:1.2425 tst_kl:0.3114 sum_norm_grad:8.3178 it: 2901 ms/it:3.8766 tst_acc:0.8735 trn_loss:1.1183 tst_loss:1.3143 tst_nll:0.9988 tst_kl:0.3155 sum_norm_grad:7.0501 it: 3001 ms/it:3.8854 tst_acc:0.8670 trn_loss:1.7738 tst_loss:1.3476 tst_nll:1.0324 tst_kl:0.3152 sum_norm_grad:12.6129 it: 3101 ms/it:4.0011 tst_acc:0.8825 trn_loss:0.7332 tst_loss:1.2051 tst_nll:0.8894 tst_kl:0.3157 sum_norm_grad:4.8421 it: 3201 ms/it:3.9724 tst_acc:0.8675 trn_loss:1.2539 tst_loss:1.3021 tst_nll:0.9816 tst_kl:0.3205 sum_norm_grad:7.0985 it: 3301 
ms/it:4.1110 tst_acc:0.8570 trn_loss:0.8448 tst_loss:1.1390 tst_nll:0.8192 tst_kl:0.3198 sum_norm_grad:6.2651 it: 3401 ms/it:3.9468 tst_acc:0.8725 trn_loss:2.2540 tst_loss:1.0749 tst_nll:0.7523 tst_kl:0.3226 sum_norm_grad:10.1086 it: 3501 ms/it:4.0869 tst_acc:0.8835 trn_loss:0.6078 tst_loss:1.0931 tst_nll:0.7682 tst_kl:0.3248 sum_norm_grad:4.0761 it: 3601 ms/it:3.9498 tst_acc:0.8800 trn_loss:0.9377 tst_loss:1.0535 tst_nll:0.7263 tst_kl:0.3272 sum_norm_grad:7.7651 it: 3701 ms/it:3.9123 tst_acc:0.8770 trn_loss:1.5083 tst_loss:1.0264 tst_nll:0.6982 tst_kl:0.3281 sum_norm_grad:6.2144 it: 3801 ms/it:3.9928 tst_acc:0.8855 trn_loss:0.4369 tst_loss:0.9467 tst_nll:0.6162 tst_kl:0.3305 sum_norm_grad:2.8031 it: 3901 ms/it:3.9239 tst_acc:0.8925 trn_loss:0.7663 tst_loss:1.0723 tst_nll:0.7410 tst_kl:0.3312 sum_norm_grad:5.9763 it: 4001 ms/it:4.0469 tst_acc:0.8925 trn_loss:0.5692 tst_loss:0.9270 tst_nll:0.5935 tst_kl:0.3334 sum_norm_grad:2.8922 it: 4101 ms/it:3.7619 tst_acc:0.9060 trn_loss:0.8710 tst_loss:0.8875 tst_nll:0.5505 tst_kl:0.3369 sum_norm_grad:6.1845 it: 4201 ms/it:4.1047 tst_acc:0.8965 trn_loss:0.4071 tst_loss:0.9123 tst_nll:0.5755 tst_kl:0.3368 sum_norm_grad:2.1311 it: 4301 ms/it:3.9728 tst_acc:0.8850 trn_loss:0.9103 tst_loss:0.9304 tst_nll:0.5917 tst_kl:0.3387 sum_norm_grad:4.9114 it: 4401 ms/it:3.9363 tst_acc:0.9130 trn_loss:1.7282 tst_loss:0.7744 tst_nll:0.4357 tst_kl:0.3387 sum_norm_grad:6.8361 it: 4501 ms/it:3.8017 tst_acc:0.9155 trn_loss:1.2034 tst_loss:0.8359 tst_nll:0.4951 tst_kl:0.3408 sum_norm_grad:7.6702 it: 4601 ms/it:4.0642 tst_acc:0.9120 trn_loss:0.9882 tst_loss:0.8646 tst_nll:0.5207 tst_kl:0.3439 sum_norm_grad:9.5124 it: 4701 ms/it:3.8952 tst_acc:0.9100 trn_loss:1.1729 tst_loss:0.8813 tst_nll:0.5387 tst_kl:0.3426 sum_norm_grad:4.8135 it: 4801 ms/it:3.7730 tst_acc:0.9160 trn_loss:0.4045 tst_loss:0.7455 tst_nll:0.4022 tst_kl:0.3433 sum_norm_grad:2.2956 it: 4901 ms/it:3.6960 tst_acc:0.9075 trn_loss:0.7228 tst_loss:0.8334 tst_nll:0.4903 tst_kl:0.3431 
sum_norm_grad:5.9699 it: 5001 ms/it:3.7376 tst_acc:0.9210 trn_loss:0.5268 tst_loss:0.7140 tst_nll:0.3685 tst_kl:0.3455 sum_norm_grad:4.6464 it: 5101 ms/it:3.8780 tst_acc:0.9370 trn_loss:0.8155 tst_loss:0.7131 tst_nll:0.3670 tst_kl:0.3461 sum_norm_grad:5.6356 it: 5201 ms/it:3.6609 tst_acc:0.9345 trn_loss:0.7713 tst_loss:0.7443 tst_nll:0.3966 tst_kl:0.3477 sum_norm_grad:4.9599 it: 5301 ms/it:3.9859 tst_acc:0.9215 trn_loss:1.0918 tst_loss:0.7639 tst_nll:0.4144 tst_kl:0.3496 sum_norm_grad:6.9639 it: 5401 ms/it:3.8696 tst_acc:0.9305 trn_loss:0.7758 tst_loss:0.6449 tst_nll:0.2954 tst_kl:0.3495 sum_norm_grad:5.4554 it: 5501 ms/it:3.8664 tst_acc:0.9115 trn_loss:1.4699 tst_loss:0.7773 tst_nll:0.4255 tst_kl:0.3518 sum_norm_grad:5.8081 it: 5601 ms/it:3.9588 tst_acc:0.9150 trn_loss:1.1072 tst_loss:0.7865 tst_nll:0.4364 tst_kl:0.3501 sum_norm_grad:6.9779 it: 5701 ms/it:3.9183 tst_acc:0.9170 trn_loss:0.7125 tst_loss:0.7487 tst_nll:0.3958 tst_kl:0.3530 sum_norm_grad:3.6534 it: 5801 ms/it:3.9031 tst_acc:0.9360 trn_loss:0.3692 tst_loss:0.7300 tst_nll:0.3763 tst_kl:0.3538 sum_norm_grad:0.9757 it: 5901 ms/it:4.0894 tst_acc:0.9205 trn_loss:0.5741 tst_loss:0.7402 tst_nll:0.3861 tst_kl:0.3542 sum_norm_grad:5.3837 it: 6001 ms/it:4.0470 tst_acc:0.9160 trn_loss:0.4374 tst_loss:0.7195 tst_nll:0.3669 tst_kl:0.3527 sum_norm_grad:3.5918 it: 6101 ms/it:3.8831 tst_acc:0.9325 trn_loss:0.4947 tst_loss:0.7233 tst_nll:0.3692 tst_kl:0.3541 sum_norm_grad:2.9750 it: 6201 ms/it:4.0328 tst_acc:0.9400 trn_loss:1.2672 tst_loss:0.6815 tst_nll:0.3261 tst_kl:0.3553 sum_norm_grad:6.6143 it: 6301 ms/it:3.8980 tst_acc:0.9380 trn_loss:0.5253 tst_loss:0.6950 tst_nll:0.3414 tst_kl:0.3536 sum_norm_grad:5.1917 it: 6401 ms/it:3.8323 tst_acc:0.9415 trn_loss:1.1518 tst_loss:0.5930 tst_nll:0.2386 tst_kl:0.3544 sum_norm_grad:3.3670 it: 6501 ms/it:4.0591 tst_acc:0.9255 trn_loss:1.2271 tst_loss:0.7223 tst_nll:0.3683 tst_kl:0.3541 sum_norm_grad:4.4156 it: 6601 ms/it:4.0571 tst_acc:0.9345 trn_loss:0.4177 tst_loss:0.6636 
tst_nll:0.3090 tst_kl:0.3546 sum_norm_grad:1.8167 it: 6701 ms/it:3.8964 tst_acc:0.9510 trn_loss:1.6205 tst_loss:0.6293 tst_nll:0.2762 tst_kl:0.3531 sum_norm_grad:4.6092 it: 6801 ms/it:4.0024 tst_acc:0.9370 trn_loss:0.7933 tst_loss:0.6722 tst_nll:0.3198 tst_kl:0.3524 sum_norm_grad:5.5259 it: 6901 ms/it:3.9365 tst_acc:0.9400 trn_loss:0.4914 tst_loss:0.6975 tst_nll:0.3438 tst_kl:0.3538 sum_norm_grad:3.1655 it: 7001 ms/it:3.9592 tst_acc:0.9215 trn_loss:0.5987 tst_loss:0.7676 tst_nll:0.4150 tst_kl:0.3525 sum_norm_grad:3.5757 it: 7101 ms/it:3.8877 tst_acc:0.9335 trn_loss:1.0931 tst_loss:0.6915 tst_nll:0.3364 tst_kl:0.3551 sum_norm_grad:4.8359 it: 7201 ms/it:3.9275 tst_acc:0.9505 trn_loss:1.0986 tst_loss:0.5902 tst_nll:0.2369 tst_kl:0.3533 sum_norm_grad:6.8004 it: 7301 ms/it:4.0631 tst_acc:0.9575 trn_loss:0.5730 tst_loss:0.5933 tst_nll:0.2381 tst_kl:0.3551 sum_norm_grad:3.5288 it: 7401 ms/it:3.9955 tst_acc:0.9440 trn_loss:0.3934 tst_loss:0.6002 tst_nll:0.2470 tst_kl:0.3532 sum_norm_grad:1.1933 it: 7501 ms/it:3.8963 tst_acc:0.9455 trn_loss:0.8857 tst_loss:0.6165 tst_nll:0.2613 tst_kl:0.3553 sum_norm_grad:6.5109 it: 7601 ms/it:4.0886 tst_acc:0.9425 trn_loss:0.5190 tst_loss:0.6719 tst_nll:0.3193 tst_kl:0.3525 sum_norm_grad:4.7453 it: 7701 ms/it:3.7783 tst_acc:0.9530 trn_loss:0.9799 tst_loss:0.5732 tst_nll:0.2224 tst_kl:0.3509 sum_norm_grad:8.2814 it: 7801 ms/it:3.8785 tst_acc:0.9540 trn_loss:0.6683 tst_loss:0.6285 tst_nll:0.2801 tst_kl:0.3484 sum_norm_grad:4.8551 it: 7901 ms/it:4.0093 tst_acc:0.9555 trn_loss:0.5054 tst_loss:0.5828 tst_nll:0.2363 tst_kl:0.3465 sum_norm_grad:2.4672 it: 8001 ms/it:3.9008 tst_acc:0.9440 trn_loss:0.5357 tst_loss:0.6174 tst_nll:0.2674 tst_kl:0.3500 sum_norm_grad:4.9554 it: 8101 ms/it:4.0858 tst_acc:0.9560 trn_loss:0.7299 tst_loss:0.5604 tst_nll:0.2145 tst_kl:0.3459 sum_norm_grad:1.9412 it: 8201 ms/it:3.9650 tst_acc:0.9520 trn_loss:0.4379 tst_loss:0.6517 tst_nll:0.3039 tst_kl:0.3478 sum_norm_grad:3.0030 it: 8301 ms/it:4.0518 tst_acc:0.9520 
trn_loss:1.3436 tst_loss:0.6454 tst_nll:0.2974 tst_kl:0.3480 sum_norm_grad:5.6112 it: 8401 ms/it:3.9783 tst_acc:0.9625 trn_loss:0.3490 tst_loss:0.5599 tst_nll:0.2135 tst_kl:0.3464 sum_norm_grad:0.2034 it: 8501 ms/it:3.9764 tst_acc:0.9540 trn_loss:0.6815 tst_loss:0.5576 tst_nll:0.2129 tst_kl:0.3447 sum_norm_grad:3.7303 it: 8601 ms/it:3.9883 tst_acc:0.9560 trn_loss:0.5182 tst_loss:0.5690 tst_nll:0.2242 tst_kl:0.3448 sum_norm_grad:3.3792 it: 8701 ms/it:3.8798 tst_acc:0.9315 trn_loss:0.8600 tst_loss:0.7801 tst_nll:0.4356 tst_kl:0.3445 sum_norm_grad:3.5876 it: 8801 ms/it:4.0914 tst_acc:0.9575 trn_loss:0.7861 tst_loss:0.5740 tst_nll:0.2328 tst_kl:0.3413 sum_norm_grad:3.4349 it: 8901 ms/it:3.9658 tst_acc:0.9500 trn_loss:0.4085 tst_loss:0.5710 tst_nll:0.2302 tst_kl:0.3407 sum_norm_grad:2.1678 it: 9001 ms/it:4.0408 tst_acc:0.9530 trn_loss:0.5396 tst_loss:0.5652 tst_nll:0.2243 tst_kl:0.3408 sum_norm_grad:3.5100 it: 9101 ms/it:3.9666 tst_acc:0.9515 trn_loss:0.3412 tst_loss:0.5904 tst_nll:0.2523 tst_kl:0.3381 sum_norm_grad:0.1454 it: 9201 ms/it:3.9417 tst_acc:0.9515 trn_loss:0.5546 tst_loss:0.5627 tst_nll:0.2242 tst_kl:0.3385 sum_norm_grad:4.2297 it: 9301 ms/it:3.8457 tst_acc:0.9535 trn_loss:0.6075 tst_loss:0.5458 tst_nll:0.2080 tst_kl:0.3378 sum_norm_grad:4.9227 it: 9401 ms/it:3.9689 tst_acc:0.9375 trn_loss:0.4993 tst_loss:0.6201 tst_nll:0.2855 tst_kl:0.3346 sum_norm_grad:2.9346 it: 9501 ms/it:4.0548 tst_acc:0.9510 trn_loss:1.0234 tst_loss:0.5534 tst_nll:0.2223 tst_kl:0.3312 sum_norm_grad:5.5114 it: 9601 ms/it:3.9514 tst_acc:0.9580 trn_loss:0.6258 tst_loss:0.5995 tst_nll:0.2681 tst_kl:0.3314 sum_norm_grad:6.6073 it: 9701 ms/it:3.8920 tst_acc:0.9400 trn_loss:0.4452 tst_loss:0.6250 tst_nll:0.2959 tst_kl:0.3291 sum_norm_grad:3.0340 it: 9801 ms/it:4.0399 tst_acc:0.9580 trn_loss:0.5026 tst_loss:0.5449 tst_nll:0.2179 tst_kl:0.3270 sum_norm_grad:3.2578 it: 9901 ms/it:4.0137 tst_acc:0.9585 trn_loss:0.3263 tst_loss:0.5291 tst_nll:0.2014 tst_kl:0.3277 sum_norm_grad:0.0510 it:10001 
ms/it:4.0287 tst_acc:0.9605 trn_loss:0.5603 tst_loss:0.5282 tst_nll:0.2026 tst_kl:0.3256 sum_norm_grad:2.9182 it:10101 ms/it:4.0722 tst_acc:0.9515 trn_loss:0.7266 tst_loss:0.6692 tst_nll:0.3464 tst_kl:0.3229 sum_norm_grad:5.5061 it:10201 ms/it:4.1431 tst_acc:0.9615 trn_loss:0.5190 tst_loss:0.5144 tst_nll:0.1918 tst_kl:0.3226 sum_norm_grad:3.6430 it:10301 ms/it:4.1492 tst_acc:0.9635 trn_loss:0.3812 tst_loss:0.4837 tst_nll:0.1635 tst_kl:0.3202 sum_norm_grad:3.0847 it:10401 ms/it:3.9624 tst_acc:0.9610 trn_loss:0.4839 tst_loss:0.5172 tst_nll:0.1983 tst_kl:0.3189 sum_norm_grad:3.1248 it:10501 ms/it:4.0594 tst_acc:0.9665 trn_loss:0.6110 tst_loss:0.4930 tst_nll:0.1716 tst_kl:0.3214 sum_norm_grad:5.1827 it:10601 ms/it:4.1912 tst_acc:0.9530 trn_loss:0.4877 tst_loss:0.5494 tst_nll:0.2295 tst_kl:0.3199 sum_norm_grad:6.0416 it:10701 ms/it:3.9832 tst_acc:0.9645 trn_loss:0.3775 tst_loss:0.5252 tst_nll:0.2038 tst_kl:0.3215 sum_norm_grad:2.7238 it:10801 ms/it:4.0256 tst_acc:0.9725 trn_loss:0.3877 tst_loss:0.4734 tst_nll:0.1526 tst_kl:0.3208 sum_norm_grad:2.9169 it:10901 ms/it:4.0747 tst_acc:0.9580 trn_loss:0.3180 tst_loss:0.6106 tst_nll:0.2944 tst_kl:0.3162 sum_norm_grad:0.1334 it:11001 ms/it:3.9968 tst_acc:0.9455 trn_loss:0.3194 tst_loss:0.5343 tst_nll:0.2152 tst_kl:0.3191 sum_norm_grad:0.0178 it:11101 ms/it:3.9972 tst_acc:0.9515 trn_loss:0.3889 tst_loss:0.5569 tst_nll:0.2421 tst_kl:0.3147 sum_norm_grad:2.8468 it:11201 ms/it:4.0366 tst_acc:0.9610 trn_loss:0.6569 tst_loss:0.4738 tst_nll:0.1589 tst_kl:0.3149 sum_norm_grad:3.3640 it:11301 ms/it:4.1191 tst_acc:0.9580 trn_loss:0.3132 tst_loss:0.4906 tst_nll:0.1786 tst_kl:0.3120 sum_norm_grad:0.0134 it:11401 ms/it:4.0271 tst_acc:0.9610 trn_loss:0.3254 tst_loss:0.5107 tst_nll:0.2023 tst_kl:0.3084 sum_norm_grad:1.3171 it:11501 ms/it:3.9839 tst_acc:0.9590 trn_loss:0.3049 tst_loss:0.4978 tst_nll:0.1914 tst_kl:0.3064 sum_norm_grad:0.0171 it:11601 ms/it:4.1491 tst_acc:0.9555 trn_loss:0.7551 tst_loss:0.5123 tst_nll:0.2050 tst_kl:0.3073 
sum_norm_grad:3.4431 it:11701 ms/it:4.0772 tst_acc:0.9590 trn_loss:0.3967 tst_loss:0.4841 tst_nll:0.1806 tst_kl:0.3036 sum_norm_grad:3.3256 it:11801 ms/it:4.0272 tst_acc:0.9640 trn_loss:0.3794 tst_loss:0.4489 tst_nll:0.1470 tst_kl:0.3020 sum_norm_grad:2.0461 it:11901 ms/it:4.0927 tst_acc:0.9645 trn_loss:0.5983 tst_loss:0.5148 tst_nll:0.2124 tst_kl:0.3024 sum_norm_grad:2.4591 it:12001 ms/it:4.1528 tst_acc:0.9585 trn_loss:0.4211 tst_loss:0.4888 tst_nll:0.1880 tst_kl:0.3008 sum_norm_grad:3.2418 it:12101 ms/it:3.9710 tst_acc:0.9545 trn_loss:0.3040 tst_loss:0.5454 tst_nll:0.2443 tst_kl:0.3011 sum_norm_grad:0.0830 it:12201 ms/it:4.0374 tst_acc:0.9645 trn_loss:0.5958 tst_loss:0.4543 tst_nll:0.1528 tst_kl:0.3015 sum_norm_grad:3.9811 it:12301 ms/it:4.1819 tst_acc:0.9655 trn_loss:0.3156 tst_loss:0.5346 tst_nll:0.2348 tst_kl:0.2998 sum_norm_grad:1.6105 it:12401 ms/it:3.9935 tst_acc:0.9635 trn_loss:0.3303 tst_loss:0.4585 tst_nll:0.1592 tst_kl:0.2993 sum_norm_grad:2.2599 it:12501 ms/it:4.0487 tst_acc:0.9645 trn_loss:0.3004 tst_loss:0.5061 tst_nll:0.2083 tst_kl:0.2978 sum_norm_grad:0.0668 it:12601 ms/it:3.9198 tst_acc:0.9625 trn_loss:1.6642 tst_loss:0.4656 tst_nll:0.1650 tst_kl:0.3005 sum_norm_grad:6.9393 it:12701 ms/it:3.7626 tst_acc:0.9560 trn_loss:0.5358 tst_loss:0.4966 tst_nll:0.1965 tst_kl:0.3000 sum_norm_grad:7.2128 it:12801 ms/it:3.7579 tst_acc:0.9680 trn_loss:0.3825 tst_loss:0.4433 tst_nll:0.1396 tst_kl:0.3038 sum_norm_grad:2.9145 it:12901 ms/it:3.7854 tst_acc:0.9720 trn_loss:0.3692 tst_loss:0.4615 tst_nll:0.1624 tst_kl:0.2991 sum_norm_grad:2.2373 it:13001 ms/it:3.9328 tst_acc:0.9580 trn_loss:0.3034 tst_loss:0.4807 tst_nll:0.1820 tst_kl:0.2987 sum_norm_grad:0.2544 it:13101 ms/it:4.0091 tst_acc:0.9625 trn_loss:0.4843 tst_loss:0.4470 tst_nll:0.1499 tst_kl:0.2971 sum_norm_grad:2.6624 it:13201 ms/it:4.1551 tst_acc:0.9605 trn_loss:0.3185 tst_loss:0.4487 tst_nll:0.1522 tst_kl:0.2966 sum_norm_grad:0.9811 it:13301 ms/it:3.8945 tst_acc:0.9640 trn_loss:0.5640 tst_loss:0.4398 
tst_nll:0.1449 tst_kl:0.2950 sum_norm_grad:2.3270 it:13401 ms/it:4.0474 tst_acc:0.9715 trn_loss:0.7745 tst_loss:0.4601 tst_nll:0.1676 tst_kl:0.2925 sum_norm_grad:4.6689 it:13501 ms/it:4.1910 tst_acc:0.9535 trn_loss:0.3593 tst_loss:0.5454 tst_nll:0.2533 tst_kl:0.2921 sum_norm_grad:1.9409 it:13601 ms/it:3.9952 tst_acc:0.9635 trn_loss:0.4736 tst_loss:0.4601 tst_nll:0.1676 tst_kl:0.2924 sum_norm_grad:3.5644 it:13701 ms/it:4.0719 tst_acc:0.9710 trn_loss:0.2986 tst_loss:0.4606 tst_nll:0.1661 tst_kl:0.2946 sum_norm_grad:0.6279 it:13801 ms/it:4.0196 tst_acc:0.9660 trn_loss:0.9478 tst_loss:0.4852 tst_nll:0.1958 tst_kl:0.2893 sum_norm_grad:5.0802 it:13901 ms/it:4.0258 tst_acc:0.9695 trn_loss:0.4310 tst_loss:0.4311 tst_nll:0.1410 tst_kl:0.2901 sum_norm_grad:5.7558 it:14001 ms/it:3.9523 tst_acc:0.9630 trn_loss:0.3845 tst_loss:0.4306 tst_nll:0.1413 tst_kl:0.2893 sum_norm_grad:3.2005 it:14101 ms/it:4.0503 tst_acc:0.9605 trn_loss:0.3007 tst_loss:0.4851 tst_nll:0.1945 tst_kl:0.2906 sum_norm_grad:0.8452 it:14201 ms/it:4.1792 tst_acc:0.9670 trn_loss:0.3004 tst_loss:0.4574 tst_nll:0.1676 tst_kl:0.2898 sum_norm_grad:0.5864 it:14301 ms/it:4.0907 tst_acc:0.9610 trn_loss:0.3002 tst_loss:0.4524 tst_nll:0.1630 tst_kl:0.2894 sum_norm_grad:0.6309 it:14401 ms/it:4.0256 tst_acc:0.9725 trn_loss:0.3457 tst_loss:0.4301 tst_nll:0.1397 tst_kl:0.2904 sum_norm_grad:2.7863 it:14501 ms/it:4.1129 tst_acc:0.9680 trn_loss:0.6517 tst_loss:0.4271 tst_nll:0.1370 tst_kl:0.2901 sum_norm_grad:5.5336 it:14601 ms/it:4.1169 tst_acc:0.9605 trn_loss:0.6168 tst_loss:0.4444 tst_nll:0.1557 tst_kl:0.2887 sum_norm_grad:6.3000 it:14701 ms/it:4.0285 tst_acc:0.9600 trn_loss:0.5961 tst_loss:0.4503 tst_nll:0.1622 tst_kl:0.2881 sum_norm_grad:6.2231 it:14801 ms/it:3.9670 tst_acc:0.9640 trn_loss:0.3150 tst_loss:0.4470 tst_nll:0.1587 tst_kl:0.2883 sum_norm_grad:1.6966 it:14901 ms/it:4.0100 tst_acc:0.9755 trn_loss:0.3967 tst_loss:0.4127 tst_nll:0.1235 tst_kl:0.2892 sum_norm_grad:3.2471 it:15001 ms/it:3.9892 tst_acc:0.9695 
trn_loss:0.3209 tst_loss:0.4563 tst_nll:0.1675 tst_kl:0.2888 sum_norm_grad:3.6252 it:15101 ms/it:4.1697 tst_acc:0.9565 trn_loss:0.4164 tst_loss:0.5404 tst_nll:0.2558 tst_kl:0.2846 sum_norm_grad:6.1996 it:15201 ms/it:3.8684 tst_acc:0.9700 trn_loss:0.6764 tst_loss:0.4735 tst_nll:0.1879 tst_kl:0.2856 sum_norm_grad:4.1625 it:15301 ms/it:3.8755 tst_acc:0.9675 trn_loss:0.7920 tst_loss:0.4526 tst_nll:0.1680 tst_kl:0.2847 sum_norm_grad:5.0280 it:15401 ms/it:3.9478 tst_acc:0.9690 trn_loss:0.5737 tst_loss:0.4505 tst_nll:0.1667 tst_kl:0.2838 sum_norm_grad:4.9883 it:15501 ms/it:4.0127 tst_acc:0.9505 trn_loss:0.3146 tst_loss:0.4777 tst_nll:0.1902 tst_kl:0.2875 sum_norm_grad:1.1997 it:15601 ms/it:4.1600 tst_acc:0.9640 trn_loss:0.2892 tst_loss:0.4651 tst_nll:0.1764 tst_kl:0.2887 sum_norm_grad:0.0423 it:15701 ms/it:4.0222 tst_acc:0.9645 trn_loss:0.4415 tst_loss:0.4608 tst_nll:0.1754 tst_kl:0.2854 sum_norm_grad:4.5139 it:15801 ms/it:3.9192 tst_acc:0.9685 trn_loss:0.3805 tst_loss:0.4093 tst_nll:0.1281 tst_kl:0.2812 sum_norm_grad:5.8602 it:15901 ms/it:3.6926 tst_acc:0.9760 trn_loss:0.2826 tst_loss:0.4172 tst_nll:0.1363 tst_kl:0.2808 sum_norm_grad:0.0702 it:16001 ms/it:4.0114 tst_acc:0.9655 trn_loss:0.5406 tst_loss:0.4407 tst_nll:0.1565 tst_kl:0.2842 sum_norm_grad:7.7410 it:16101 ms/it:3.9677 tst_acc:0.9705 trn_loss:0.2867 tst_loss:0.4966 tst_nll:0.2131 tst_kl:0.2834 sum_norm_grad:0.4872 it:16201 ms/it:3.9714 tst_acc:0.9705 trn_loss:0.5924 tst_loss:0.4186 tst_nll:0.1330 tst_kl:0.2856 sum_norm_grad:4.3815 it:16301 ms/it:3.9763 tst_acc:0.9730 trn_loss:0.4129 tst_loss:0.4093 tst_nll:0.1237 tst_kl:0.2856 sum_norm_grad:3.9573 it:16401 ms/it:3.9081 tst_acc:0.9530 trn_loss:0.2875 tst_loss:0.4477 tst_nll:0.1610 tst_kl:0.2867 sum_norm_grad:0.1167 it:16501 ms/it:4.0766 tst_acc:0.9725 trn_loss:0.3180 tst_loss:0.4415 tst_nll:0.1535 tst_kl:0.2881 sum_norm_grad:2.1457 it:16601 ms/it:4.0789 tst_acc:0.9595 trn_loss:0.4518 tst_loss:0.4384 tst_nll:0.1529 tst_kl:0.2856 sum_norm_grad:3.7270 it:16701 
ms/it:4.0770 tst_acc:0.9540 trn_loss:0.4883 tst_loss:0.5747 tst_nll:0.2888 tst_kl:0.2859 sum_norm_grad:6.3127 it:16801 ms/it:3.9686 tst_acc:0.9470 trn_loss:0.7660 tst_loss:0.5449 tst_nll:0.2607 tst_kl:0.2841 sum_norm_grad:6.8640 it:16901 ms/it:4.0843 tst_acc:0.9720 trn_loss:0.6791 tst_loss:0.4385 tst_nll:0.1543 tst_kl:0.2842 sum_norm_grad:5.0611 it:17001 ms/it:4.0326 tst_acc:0.9485 trn_loss:0.3662 tst_loss:0.6400 tst_nll:0.3543 tst_kl:0.2857 sum_norm_grad:5.5587 it:17101 ms/it:3.9873 tst_acc:0.9550 trn_loss:1.0070 tst_loss:0.6144 tst_nll:0.3246 tst_kl:0.2898 sum_norm_grad:11.7100 it:17201 ms/it:3.9656 tst_acc:0.9695 trn_loss:0.2885 tst_loss:0.4724 tst_nll:0.1841 tst_kl:0.2883 sum_norm_grad:0.0180 it:17301 ms/it:4.0687 tst_acc:0.9645 trn_loss:0.2902 tst_loss:0.4296 tst_nll:0.1415 tst_kl:0.2881 sum_norm_grad:0.0060 it:17401 ms/it:4.0105 tst_acc:0.9755 trn_loss:0.4903 tst_loss:0.4250 tst_nll:0.1388 tst_kl:0.2861 sum_norm_grad:5.6156 it:17501 ms/it:4.0767 tst_acc:0.9665 trn_loss:0.5377 tst_loss:0.4181 tst_nll:0.1331 tst_kl:0.2850 sum_norm_grad:4.3493 it:17601 ms/it:3.8058 tst_acc:0.9630 trn_loss:0.3337 tst_loss:0.4663 tst_nll:0.1846 tst_kl:0.2817 sum_norm_grad:2.6130 it:17701 ms/it:3.9994 tst_acc:0.9630 trn_loss:0.2957 tst_loss:0.4675 tst_nll:0.1863 tst_kl:0.2811 sum_norm_grad:0.9803 it:17801 ms/it:4.0825 tst_acc:0.9715 trn_loss:0.3633 tst_loss:0.4168 tst_nll:0.1360 tst_kl:0.2808 sum_norm_grad:4.0092 it:17901 ms/it:4.0679 tst_acc:0.9810 trn_loss:0.2838 tst_loss:0.3677 tst_nll:0.0852 tst_kl:0.2825 sum_norm_grad:0.1164 it:18001 ms/it:4.1126 tst_acc:0.9640 trn_loss:0.4023 tst_loss:0.4335 tst_nll:0.1508 tst_kl:0.2827 sum_norm_grad:2.9085 it:18101 ms/it:4.0763 tst_acc:0.9635 trn_loss:0.3012 tst_loss:0.4868 tst_nll:0.2053 tst_kl:0.2815 sum_norm_grad:1.4692 it:18201 ms/it:4.0795 tst_acc:0.9625 trn_loss:0.2843 tst_loss:0.4739 tst_nll:0.1924 tst_kl:0.2815 sum_norm_grad:0.2426 it:18301 ms/it:3.9620 tst_acc:0.9580 trn_loss:0.4541 tst_loss:0.4786 tst_nll:0.1970 tst_kl:0.2816 
sum_norm_grad:3.2329 it:18401 ms/it:4.0137 tst_acc:0.9665 trn_loss:0.4026 tst_loss:0.4390 tst_nll:0.1543 tst_kl:0.2847 sum_norm_grad:3.5079 it:18501 ms/it:4.0719 tst_acc:0.9545 trn_loss:0.5524 tst_loss:0.5126 tst_nll:0.2262 tst_kl:0.2864 sum_norm_grad:3.8132 it:18601 ms/it:3.9031 tst_acc:0.9670 trn_loss:0.2877 tst_loss:0.4704 tst_nll:0.1829 tst_kl:0.2875 sum_norm_grad:0.0500 it:18701 ms/it:3.9321 tst_acc:0.9685 trn_loss:0.2878 tst_loss:0.4687 tst_nll:0.1828 tst_kl:0.2860 sum_norm_grad:0.0418 it:18801 ms/it:4.0339 tst_acc:0.9740 trn_loss:0.4590 tst_loss:0.4106 tst_nll:0.1223 tst_kl:0.2883 sum_norm_grad:3.0058 it:18901 ms/it:4.1064 tst_acc:0.9515 trn_loss:0.4362 tst_loss:0.4768 tst_nll:0.1898 tst_kl:0.2870 sum_norm_grad:4.9499 it:19001 ms/it:4.1348 tst_acc:0.9645 trn_loss:0.2843 tst_loss:0.4390 tst_nll:0.1531 tst_kl:0.2859 sum_norm_grad:0.0411 it:19101 ms/it:3.8992 tst_acc:0.9665 trn_loss:0.2848 tst_loss:0.4721 tst_nll:0.1884 tst_kl:0.2837 sum_norm_grad:0.0044 it:19201 ms/it:4.0857 tst_acc:0.9695 trn_loss:0.5177 tst_loss:0.4461 tst_nll:0.1582 tst_kl:0.2878 sum_norm_grad:4.7725 it:19301 ms/it:3.9766 tst_acc:0.9710 trn_loss:0.2896 tst_loss:0.3844 tst_nll:0.1000 tst_kl:0.2844 sum_norm_grad:0.2153 it:19401 ms/it:3.9142 tst_acc:0.9725 trn_loss:0.3389 tst_loss:0.4276 tst_nll:0.1438 tst_kl:0.2838 sum_norm_grad:3.5745 it:19501 ms/it:3.8347 tst_acc:0.9680 trn_loss:0.5121 tst_loss:0.4608 tst_nll:0.1741 tst_kl:0.2867 sum_norm_grad:5.2784 it:19601 ms/it:4.0078 tst_acc:0.9660 trn_loss:0.6609 tst_loss:0.4748 tst_nll:0.1863 tst_kl:0.2886 sum_norm_grad:4.7520 it:19701 ms/it:4.0738 tst_acc:0.9670 trn_loss:0.4750 tst_loss:0.5153 tst_nll:0.2284 tst_kl:0.2869 sum_norm_grad:6.4839 it:19801 ms/it:4.0431 tst_acc:0.9700 trn_loss:0.3464 tst_loss:0.4074 tst_nll:0.1198 tst_kl:0.2876 sum_norm_grad:2.0778 it:19901 ms/it:4.0963 tst_acc:0.9705 trn_loss:0.3211 tst_loss:0.4235 tst_nll:0.1377 tst_kl:0.2858 sum_norm_grad:1.9506 it:20000 ms/it:3.8401 tst_acc:0.9710 trn_loss:0.3530 tst_loss:0.4541 
tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:4.1883
# Number of BNN forward passes over the test set whose predictions are
# averaged below (editable as a Colab form field).
num_inferences = 10  # @param { isTemplate: true}
@tf.function(autograph=False)
def predicted_log_prob(x):
  """Return per-class log-probabilities from one BNN forward pass on `x`.

  The BNN's output logits are normalized with `log_softmax` over the last
  axis, under an XLA JIT scope.
  """
  with tf.xla.experimental.jit_scope(compile_ops=True):
    logits = bnn(x).logits
    return tf.math.log_softmax(logits, axis=-1)
# Run the BNN over the full test set `num_inferences` times, then average the
# predicted class probabilities per example (in log space).
# NOTE(review): this assumes `eval_dataset` yields examples in the same order
# on every pass (and again for `x_final`/`y_final` below); confirm it is not
# shuffled between passes.
n = datasets_info.splits['test'].num_examples
eval_iter = iter(eval_dataset.batch(2000).repeat(int(num_inferences)))
# Concatenate along the example axis rather than stacking equal-sized
# batches. The result is identical when `n` is a multiple of 2000, and it
# still works when the last batch is partial (where `tf.stack` would fail).
before_avg_predicted_log_probs = tf.reshape(
    tf.concat([predicted_log_prob(x) for x, _ in eval_iter], axis=0),
    shape=[int(num_inferences), n, -1])
# log(mean_k p_k(class | x)) over the inference passes.
bnn_predicted_log_probs = tfp.math.reduce_logmeanexp(
    before_avg_predicted_log_probs, axis=0)
decision = tf.argmax(bnn_predicted_log_probs, axis=-1, output_type=tf.int32)
# Log of each example's largest averaged class probability.
confidence = tf.reduce_max(bnn_predicted_log_probs, axis=-1)

# An example is "decided" only if its averaged top-class probability exceeds
# `threshold`; the rest are counted as undecided.
threshold = 0.95
decided_idx = tf.where(confidence > np.log(threshold))
# Example indices sorted from least to most confident.
ordered = tf.argsort(confidence)

x_final, y_final = next(iter(eval_dataset.batch(n)))
print('Number of examples undecided: {}'.format(n - tf.size(decided_idx)))
accuracy = tf.reduce_mean(
    tf.cast(tf.equal(tf.gather(y_final, decided_idx),
                     tf.gather(decision, decided_idx)),
            tf.float32))
accurary = accuracy  # Backward-compatible alias for the original misspelling.
print('Accuracy after excluding undecided ones: {}'.format(accuracy))

# Show the 50 least-confident test images with their true labels.
tfp_nn.util.display_imgs(
    tf.gather(x_final, ordered[0:50]),
    tf.gather(y_final, ordered[0:50]));
Number of examples undecided: 705 Accurary after excluding undecided ones: 0.995266258717
from sklearn import metrics
# One-vs-rest ROC AUC for each of the 10 classes. The true-class indicator is
# scored against the averaged log-probability the BNN assigns to that class.
bnn_auc = np.array(
    [metrics.roc_auc_score(tf.equal(y_final, digit),
                           bnn_predicted_log_probs[:, digit])
     for digit in range(10)])
print('Per class AUC:\n{}'.format(bnn_auc[:, None]))
Per class AUC: [[0.99994366] [0.99992725] [0.99974284] [0.99959576] [0.99985348] [0.99967074] [0.99962562] [0.99963218] [0.99862774] [0.99943871]]
# Rebuild the same 2x2, stride-2 max-pooling layer (it has no tf.Variables).
max_pool = tf.keras.layers.MaxPooling2D(
    pool_size=(2, 2),
    strides=(2, 2),
    padding='SAME',
    data_format='channels_last',
)

# Deterministic baseline with the same topology as the BNN: two
# `Convolution` layers and an `Affine` layer, ending in an int32 Categorical
# over the 10 classes. There are no variational layers, so no penalty weight.
dnn = tfp_nn.Sequential(
    [
        tfp_nn.Convolution(
            input_size=1,
            output_size=8,
            filter_shape=5,
            padding='SAME',
            name='conv1'),
        tf.nn.leaky_relu,
        max_pool,  # [28, 28, 8] -> [14, 14, 8]
        tfp_nn.Convolution(
            input_size=8,
            output_size=16,
            filter_shape=5,
            padding='SAME',
            name='conv2'),
        tf.nn.leaky_relu,
        max_pool,  # [14, 14, 16] -> [7, 7, 16]
        tfp_nn.util.flatten_rightmost,
        tfp_nn.Affine(
            input_size=7 * 7 * 16,
            output_size=10,
            name='affine1'),
        lambda logits: tfd.Categorical(logits=logits, dtype=tf.int32),
    ],
    name='DNN',
)
print(dnn.summary())
=== DNN ==================================================
SIZE SHAPE TRAIN NAME
8 [8] True bias:0
200 [5, 5, 1, 8] True kernel:0
16 [16] True bias:0
3200 [5, 5, 8, 16] True kernel:0
10 [10] True bias:0
7840 [784, 10] True kernel:0
trainable size: 11274 / 0.043 MiB / {float32: 11274}
# Fresh iterators for the DNN run: training batches plus 2000-example
# evaluation batches, cycled indefinitely.
train_iter = iter(train_dataset)
eval_iter = iter(
    eval_dataset
    .batch(2000)
    .repeat()
)
def loss():
  """Draw one training batch and return the DNN's training objective.

  Returns:
    `(nll, None)`. `nll` is the negative mean log-probability of the batch
    labels under `dnn`'s output distribution; `None` fills the auxiliary slot,
    which the DNN training loop discards.
  """
  images, labels = next(train_iter)
  nll = -tf.reduce_mean(dnn(images).log_prob(labels), axis=-1)
  return nll, None
# New Adam optimizer plus a fit op built from the DNN `loss` over the DNN's
# trainable variables. The gradient summary is the per-variable gradient
# norms.
opt = tf.optimizers.Adam(learning_rate=1e-2)
fit = tfp_nn.util.make_fit_op(
    loss,
    opt,
    dnn.trainable_variables,
    grad_summary_fn=lambda grads: tf.nest.map_structure(tf.norm, grads),
)
@tf.function(autograph=False)
def eval():  # pylint: disable=redefined-builtin
  """Score the DNN on the next held-out batch drawn from `eval_iter`.

  Returns:
    Tuple `(nll, acc)`:
      nll: mean negative log-likelihood of the true labels.
      acc: fraction of examples whose predicted mode equals the true label.
  """
  with tf.xla.experimental.jit_scope(compile_ops=True):
    images, labels = next(eval_iter)
    predictive = dnn(images)
    mean_nll = -tf.reduce_mean(predictive.log_prob(labels), axis=-1)
    hits = tf.cast(tf.equal(labels, predictive.mode()), tf.float32)
    accuracy = tf.reduce_mean(hits, axis=-1)
    return mean_nll, accuracy
num_train_steps = 20e3  # @param { isTemplate: true}
num_train_steps = int(num_train_steps)  # Enforce correct type when overridden.

# Wall-clock seconds and step count accumulated since the last log line; used
# to report the average milliseconds per training step.
dur_sec = dur_num = 0
for i in range(num_train_steps):
  start = time.time()
  trn_loss, _, g = fit()
  stop = time.time()
  dur_sec += stop - start
  dur_num += 1
  if i % 100 == 0 or i == num_train_steps - 1:
    tst_loss, tst_acc = eval()
    # Log only quantities computed in this loop. The DNN loss has no KL term,
    # and `tst_nll`/`tst_kl` are never assigned here, so printing them would
    # show stale globals (or raise NameError if they do not exist).
    log_fmts, log_vals = zip(*[
        ('it:{:5}', opt.iterations),
        ('ms/it:{:6.4f}', dur_sec / max(1., dur_num) * 1000.),
        ('tst_acc:{:6.4f}', tst_acc),
        ('trn_loss:{:6.4f}', trn_loss),
        ('tst_loss:{:6.4f}', tst_loss),
        ('sum_norm_grad:{:6.4f}', sum(g)),
    ])
    # Unwrap tensors/variables via `.numpy()` so the float format specs apply;
    # plain Python numbers pass through unchanged.
    print(' '.join(log_fmts).format(
        *[getattr(v, 'numpy', lambda: v)() for v in log_vals]))
    sys.stdout.flush()
    dur_sec = dur_num = 0
  # if i % 1000 == 0 or i == num_train_steps - 1:
  #   dnn.save('/tmp/vae.npz')
it: 1 ms/it:947.1798 tst_acc:0.2020 trn_loss:2.4360 tst_loss:2.2867 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.2128 it: 101 ms/it:1.8313 tst_acc:0.9485 trn_loss:0.0494 tst_loss:0.1961 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.5276 it: 201 ms/it:1.8815 tst_acc:0.9515 trn_loss:0.0971 tst_loss:0.1553 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.6970 it: 301 ms/it:1.8873 tst_acc:0.9595 trn_loss:0.0675 tst_loss:0.1329 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.6763 it: 401 ms/it:1.9162 tst_acc:0.9590 trn_loss:0.0372 tst_loss:0.1261 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.1034 it: 501 ms/it:1.8792 tst_acc:0.9565 trn_loss:0.2008 tst_loss:0.1358 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.5304 it: 601 ms/it:1.8019 tst_acc:0.9670 trn_loss:0.1382 tst_loss:0.0962 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.8372 it: 701 ms/it:1.8842 tst_acc:0.9595 trn_loss:0.1764 tst_loss:0.1383 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.3745 it: 801 ms/it:1.7390 tst_acc:0.9735 trn_loss:0.1286 tst_loss:0.0840 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.1147 it: 901 ms/it:1.9072 tst_acc:0.9845 trn_loss:0.0826 tst_loss:0.0450 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.0499 it: 1001 ms/it:1.9544 tst_acc:0.9670 trn_loss:0.3418 tst_loss:0.1032 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.6575 it: 1101 ms/it:1.9279 tst_acc:0.9745 trn_loss:0.2445 tst_loss:0.0854 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.6256 it: 1201 ms/it:1.8666 tst_acc:0.9740 trn_loss:0.1481 tst_loss:0.0832 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.9186 it: 1301 ms/it:1.9622 tst_acc:0.9770 trn_loss:0.1692 tst_loss:0.0729 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.6448 it: 1401 ms/it:1.8534 tst_acc:0.9785 trn_loss:0.0597 tst_loss:0.1007 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.5306 it: 1501 ms/it:1.9520 tst_acc:0.9700 trn_loss:0.3685 tst_loss:0.0880 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.7521 it: 1601 ms/it:1.9009 tst_acc:0.9775 trn_loss:0.4346 tst_loss:0.0606 tst_nll:0.1681 tst_kl:0.2860 
sum_norm_grad:4.2722 it: 1701 ms/it:1.9912 tst_acc:0.9780 trn_loss:0.1151 tst_loss:0.0710 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.3924 it: 1801 ms/it:1.8722 tst_acc:0.9755 trn_loss:0.0110 tst_loss:0.0782 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.3810 it: 1901 ms/it:1.9005 tst_acc:0.9795 trn_loss:0.0059 tst_loss:0.0759 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.2433 it: 2001 ms/it:1.8488 tst_acc:0.9750 trn_loss:0.0583 tst_loss:0.0659 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.1850 it: 2101 ms/it:1.9041 tst_acc:0.9670 trn_loss:0.0043 tst_loss:0.1234 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.2029 it: 2201 ms/it:1.8319 tst_acc:0.9645 trn_loss:0.0655 tst_loss:0.1298 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.9481 it: 2301 ms/it:1.9331 tst_acc:0.9660 trn_loss:0.0004 tst_loss:0.1582 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0237 it: 2401 ms/it:1.9130 tst_acc:0.9685 trn_loss:0.6777 tst_loss:0.1454 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.7664 it: 2501 ms/it:1.9730 tst_acc:0.9660 trn_loss:0.1712 tst_loss:0.1015 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.4604 it: 2601 ms/it:1.8234 tst_acc:0.9745 trn_loss:0.0069 tst_loss:0.1264 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.3191 it: 2701 ms/it:1.8260 tst_acc:0.9625 trn_loss:0.0064 tst_loss:0.1271 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.2002 it: 2801 ms/it:1.8385 tst_acc:0.9750 trn_loss:0.2281 tst_loss:0.0854 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.8069 it: 2901 ms/it:1.8029 tst_acc:0.9815 trn_loss:0.0291 tst_loss:0.0668 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.6005 it: 3001 ms/it:1.8537 tst_acc:0.9610 trn_loss:0.2334 tst_loss:0.1418 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.1746 it: 3101 ms/it:1.8889 tst_acc:0.9740 trn_loss:0.0018 tst_loss:0.1433 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0881 it: 3201 ms/it:1.8358 tst_acc:0.9735 trn_loss:0.1684 tst_loss:0.1178 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.3806 it: 3301 ms/it:1.9122 tst_acc:0.9745 trn_loss:0.0497 tst_loss:0.1153 
tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.1474 it: 3401 ms/it:1.9089 tst_acc:0.9720 trn_loss:0.0006 tst_loss:0.0930 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0316 it: 3501 ms/it:1.9271 tst_acc:0.9770 trn_loss:0.3862 tst_loss:0.0926 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.2212 it: 3601 ms/it:1.8609 tst_acc:0.9780 trn_loss:0.0328 tst_loss:0.0887 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.1865 it: 3701 ms/it:1.9096 tst_acc:0.9720 trn_loss:0.2323 tst_loss:0.1357 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.7994 it: 3801 ms/it:1.8436 tst_acc:0.9785 trn_loss:0.0211 tst_loss:0.1066 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.2262 it: 3901 ms/it:1.9345 tst_acc:0.9815 trn_loss:0.2188 tst_loss:0.0651 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.5738 it: 4001 ms/it:1.9165 tst_acc:0.9730 trn_loss:0.1181 tst_loss:0.0939 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.1097 it: 4101 ms/it:1.8918 tst_acc:0.9765 trn_loss:0.0021 tst_loss:0.1239 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.2101 it: 4201 ms/it:1.8934 tst_acc:0.9780 trn_loss:0.1797 tst_loss:0.1366 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.6820 it: 4301 ms/it:1.9533 tst_acc:0.9760 trn_loss:0.0326 tst_loss:0.1189 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.5648 it: 4401 ms/it:1.9542 tst_acc:0.9755 trn_loss:0.0188 tst_loss:0.0899 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.4059 it: 4501 ms/it:1.9615 tst_acc:0.9770 trn_loss:0.0020 tst_loss:0.0911 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.1239 it: 4601 ms/it:1.8829 tst_acc:0.9740 trn_loss:0.0258 tst_loss:0.1248 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.3504 it: 4701 ms/it:1.9402 tst_acc:0.9695 trn_loss:0.0004 tst_loss:0.1254 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0192 it: 4801 ms/it:1.9671 tst_acc:0.9795 trn_loss:0.0648 tst_loss:0.1293 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.5944 it: 4901 ms/it:1.9212 tst_acc:0.9795 trn_loss:0.0717 tst_loss:0.0998 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.1639 it: 5001 ms/it:1.8767 tst_acc:0.9630 
trn_loss:0.0079 tst_loss:0.1598 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.8263 it: 5101 ms/it:1.8408 tst_acc:0.9730 trn_loss:0.1447 tst_loss:0.1446 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.1445 it: 5201 ms/it:1.8568 tst_acc:0.9830 trn_loss:0.0450 tst_loss:0.0646 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.0922 it: 5301 ms/it:1.8645 tst_acc:0.9810 trn_loss:0.0002 tst_loss:0.0877 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0133 it: 5401 ms/it:2.0107 tst_acc:0.9715 trn_loss:0.0026 tst_loss:0.0888 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.1400 it: 5501 ms/it:1.9240 tst_acc:0.9740 trn_loss:0.0513 tst_loss:0.1046 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.0739 it: 5601 ms/it:1.9938 tst_acc:0.9755 trn_loss:0.0002 tst_loss:0.0968 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0092 it: 5701 ms/it:1.8752 tst_acc:0.9775 trn_loss:0.0002 tst_loss:0.0956 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0181 it: 5801 ms/it:1.9496 tst_acc:0.9655 trn_loss:0.0488 tst_loss:0.1383 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.2051 it: 5901 ms/it:1.9216 tst_acc:0.9385 trn_loss:0.0019 tst_loss:0.3988 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.2937 it: 6001 ms/it:1.9092 tst_acc:0.9735 trn_loss:0.3313 tst_loss:0.1342 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.1321 it: 6101 ms/it:1.8636 tst_acc:0.9705 trn_loss:0.3192 tst_loss:0.2283 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.3080 it: 6201 ms/it:1.9832 tst_acc:0.9725 trn_loss:0.5824 tst_loss:0.2525 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:7.1433 it: 6301 ms/it:1.9989 tst_acc:0.9790 trn_loss:0.0720 tst_loss:0.0891 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.1926 it: 6401 ms/it:2.0012 tst_acc:0.9745 trn_loss:0.0110 tst_loss:0.1168 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.1464 it: 6501 ms/it:1.9625 tst_acc:0.9715 trn_loss:0.2354 tst_loss:0.1131 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.7642 it: 6601 ms/it:1.8740 tst_acc:0.9795 trn_loss:0.0018 tst_loss:0.1038 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.1598 it: 6701 
ms/it:1.9265 tst_acc:0.9785 trn_loss:0.0911 tst_loss:0.1080 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.6253 it: 6801 ms/it:1.9781 tst_acc:0.9655 trn_loss:0.0008 tst_loss:0.1389 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0917 it: 6901 ms/it:1.9172 tst_acc:0.9855 trn_loss:0.1280 tst_loss:0.0714 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.6832 it: 7001 ms/it:1.8954 tst_acc:0.9725 trn_loss:0.0057 tst_loss:0.1199 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.5952 it: 7101 ms/it:1.9048 tst_acc:0.9830 trn_loss:0.0526 tst_loss:0.0875 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.1944 it: 7201 ms/it:1.9000 tst_acc:0.9795 trn_loss:0.0036 tst_loss:0.1177 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.2082 it: 7301 ms/it:1.8918 tst_acc:0.9855 trn_loss:0.1575 tst_loss:0.0595 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:6.4531 it: 7401 ms/it:1.8148 tst_acc:0.9770 trn_loss:0.0023 tst_loss:0.0912 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.2389 it: 7501 ms/it:1.8609 tst_acc:0.9615 trn_loss:0.1333 tst_loss:0.1984 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:4.0536 it: 7601 ms/it:1.8783 tst_acc:0.9625 trn_loss:0.1348 tst_loss:0.1635 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.3381 it: 7701 ms/it:1.9814 tst_acc:0.9670 trn_loss:0.0914 tst_loss:0.2154 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.3460 it: 7801 ms/it:1.8573 tst_acc:0.9790 trn_loss:0.5642 tst_loss:0.1225 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.2740 it: 7901 ms/it:1.9789 tst_acc:0.9825 trn_loss:0.0746 tst_loss:0.0923 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.0229 it: 8001 ms/it:1.9607 tst_acc:0.9765 trn_loss:0.0647 tst_loss:0.1554 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:4.3073 it: 8101 ms/it:1.7003 tst_acc:0.9775 trn_loss:0.0000 tst_loss:0.1622 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0052 it: 8201 ms/it:1.9164 tst_acc:0.9740 trn_loss:0.1375 tst_loss:0.1565 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.1471 it: 8301 ms/it:2.0367 tst_acc:0.9825 trn_loss:0.0443 tst_loss:0.0957 tst_nll:0.1681 tst_kl:0.2860 
sum_norm_grad:2.6082 it: 8401 ms/it:1.9045 tst_acc:0.9780 trn_loss:0.4172 tst_loss:0.1610 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:4.2836 it: 8501 ms/it:1.8642 tst_acc:0.9680 trn_loss:0.1343 tst_loss:0.1856 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:4.9268 it: 8601 ms/it:1.9150 tst_acc:0.9760 trn_loss:0.1492 tst_loss:0.1116 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.3113 it: 8701 ms/it:1.8917 tst_acc:0.9840 trn_loss:0.0010 tst_loss:0.1036 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.1165 it: 8801 ms/it:1.9829 tst_acc:0.9760 trn_loss:0.0072 tst_loss:0.0998 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.7192 it: 8901 ms/it:1.8685 tst_acc:0.9835 trn_loss:0.0003 tst_loss:0.0713 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0218 it: 9001 ms/it:1.8986 tst_acc:0.9755 trn_loss:0.5329 tst_loss:0.1177 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:4.3075 it: 9101 ms/it:1.8849 tst_acc:0.9745 trn_loss:0.0001 tst_loss:0.1427 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0074 it: 9201 ms/it:1.9143 tst_acc:0.9740 trn_loss:0.0064 tst_loss:0.1286 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.6379 it: 9301 ms/it:1.9335 tst_acc:0.9770 trn_loss:0.0049 tst_loss:0.0912 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.3331 it: 9401 ms/it:1.8829 tst_acc:0.9820 trn_loss:0.1389 tst_loss:0.0869 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.7373 it: 9501 ms/it:1.8093 tst_acc:0.9800 trn_loss:0.0000 tst_loss:0.0994 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0026 it: 9601 ms/it:1.8548 tst_acc:0.9775 trn_loss:0.1294 tst_loss:0.1379 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:4.6775 it: 9701 ms/it:1.8653 tst_acc:0.9690 trn_loss:0.2082 tst_loss:0.1807 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.6929 it: 9801 ms/it:1.8750 tst_acc:0.9790 trn_loss:0.0000 tst_loss:0.1341 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0016 it: 9901 ms/it:1.8391 tst_acc:0.9840 trn_loss:0.0005 tst_loss:0.1072 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0553 it:10001 ms/it:1.8475 tst_acc:0.9740 trn_loss:0.0170 tst_loss:0.1504 
tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.6424 it:10101 ms/it:1.8375 tst_acc:0.9790 trn_loss:0.0000 tst_loss:0.1647 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0001 it:10201 ms/it:1.8481 tst_acc:0.9715 trn_loss:0.0693 tst_loss:0.1810 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.2210 it:10301 ms/it:1.9360 tst_acc:0.9710 trn_loss:0.2064 tst_loss:0.1916 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.9218 it:10401 ms/it:1.7590 tst_acc:0.9795 trn_loss:0.0003 tst_loss:0.1485 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0427 it:10501 ms/it:1.9131 tst_acc:0.9600 trn_loss:0.0003 tst_loss:0.2860 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0428 it:10601 ms/it:1.8459 tst_acc:0.9820 trn_loss:0.0566 tst_loss:0.1933 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.4583 it:10701 ms/it:1.9201 tst_acc:0.9780 trn_loss:0.0000 tst_loss:0.1731 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0001 it:10801 ms/it:1.8856 tst_acc:0.9760 trn_loss:0.2365 tst_loss:0.2043 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.9638 it:10901 ms/it:1.9326 tst_acc:0.9810 trn_loss:0.0040 tst_loss:0.1054 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.5333 it:11001 ms/it:1.8518 tst_acc:0.9750 trn_loss:0.8869 tst_loss:0.1666 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:7.3510 it:11101 ms/it:1.8174 tst_acc:0.9800 trn_loss:0.0000 tst_loss:0.1264 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0013 it:11201 ms/it:1.9216 tst_acc:0.9795 trn_loss:0.2843 tst_loss:0.1280 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:6.0041 it:11301 ms/it:1.8474 tst_acc:0.9800 trn_loss:0.0000 tst_loss:0.1312 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:11401 ms/it:1.9509 tst_acc:0.9860 trn_loss:0.0717 tst_loss:0.0760 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.2902 it:11501 ms/it:1.9183 tst_acc:0.9835 trn_loss:0.0720 tst_loss:0.0706 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.8035 it:11601 ms/it:1.8588 tst_acc:0.9810 trn_loss:0.0000 tst_loss:0.1667 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:11701 ms/it:1.9243 tst_acc:0.9780 
trn_loss:0.0000 tst_loss:0.1288 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0002 it:11801 ms/it:2.0181 tst_acc:0.9800 trn_loss:0.0015 tst_loss:0.1283 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.1053 it:11901 ms/it:1.9545 tst_acc:0.9825 trn_loss:0.0000 tst_loss:0.1035 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0013 it:12001 ms/it:1.8887 tst_acc:0.9790 trn_loss:0.0000 tst_loss:0.1556 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0020 it:12101 ms/it:1.7490 tst_acc:0.9795 trn_loss:0.0061 tst_loss:0.1599 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.5866 it:12201 ms/it:1.9567 tst_acc:0.9840 trn_loss:0.0086 tst_loss:0.1277 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.0934 it:12301 ms/it:1.8599 tst_acc:0.9850 trn_loss:0.0163 tst_loss:0.1061 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.7902 it:12401 ms/it:1.9244 tst_acc:0.9800 trn_loss:0.0166 tst_loss:0.1223 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.7924 it:12501 ms/it:1.8721 tst_acc:0.9780 trn_loss:0.6067 tst_loss:0.1491 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:11.0514 it:12601 ms/it:1.8351 tst_acc:0.9735 trn_loss:0.0553 tst_loss:0.2212 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.6150 it:12701 ms/it:1.8683 tst_acc:0.9770 trn_loss:0.0578 tst_loss:0.1618 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:4.5222 it:12801 ms/it:1.8916 tst_acc:0.9780 trn_loss:0.1388 tst_loss:0.1443 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:4.4016 it:12901 ms/it:1.8777 tst_acc:0.9795 trn_loss:0.3753 tst_loss:0.0841 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.2822 it:13001 ms/it:1.9720 tst_acc:0.9800 trn_loss:0.0006 tst_loss:0.1374 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.1209 it:13101 ms/it:1.9213 tst_acc:0.9760 trn_loss:1.0306 tst_loss:0.1779 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:7.6652 it:13201 ms/it:1.9153 tst_acc:0.9760 trn_loss:0.0217 tst_loss:0.2003 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.4286 it:13301 ms/it:1.9395 tst_acc:0.9820 trn_loss:0.3963 tst_loss:0.1308 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:4.9430 it:13401 
ms/it:1.9037 tst_acc:0.9815 trn_loss:0.0000 tst_loss:0.1967 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:13501 ms/it:1.9530 tst_acc:0.9780 trn_loss:0.1648 tst_loss:0.1603 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:6.8833 it:13601 ms/it:1.9008 tst_acc:0.9745 trn_loss:0.0000 tst_loss:0.2226 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0014 it:13701 ms/it:1.9314 tst_acc:0.9800 trn_loss:0.3639 tst_loss:0.1396 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:6.0991 it:13801 ms/it:1.9677 tst_acc:0.9825 trn_loss:0.0000 tst_loss:0.1507 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0003 it:13901 ms/it:1.7624 tst_acc:0.9835 trn_loss:0.0634 tst_loss:0.1499 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.6273 it:14001 ms/it:1.6495 tst_acc:0.9755 trn_loss:0.0000 tst_loss:0.1899 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:14101 ms/it:1.6304 tst_acc:0.9745 trn_loss:0.0000 tst_loss:0.2126 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:14201 ms/it:1.6777 tst_acc:0.9800 trn_loss:0.0026 tst_loss:0.1826 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.2822 it:14301 ms/it:1.8394 tst_acc:0.9785 trn_loss:0.0000 tst_loss:0.1636 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0018 it:14401 ms/it:1.8427 tst_acc:0.9770 trn_loss:0.5020 tst_loss:0.1936 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:7.6895 it:14501 ms/it:1.7796 tst_acc:0.9830 trn_loss:0.0012 tst_loss:0.1847 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.1890 it:14601 ms/it:1.7793 tst_acc:0.9780 trn_loss:1.4099 tst_loss:0.2085 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:6.8055 it:14701 ms/it:1.5060 tst_acc:0.9790 trn_loss:0.0023 tst_loss:0.1367 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.2968 it:14801 ms/it:2.0035 tst_acc:0.9760 trn_loss:1.0047 tst_loss:0.2220 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:10.6048 it:14901 ms/it:1.6726 tst_acc:0.9760 trn_loss:0.9350 tst_loss:0.2156 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.9722 it:15001 ms/it:1.9433 tst_acc:0.9810 trn_loss:0.0089 tst_loss:0.1182 tst_nll:0.1681 tst_kl:0.2860 
sum_norm_grad:1.2063 it:15101 ms/it:1.8444 tst_acc:0.9750 trn_loss:0.0584 tst_loss:0.2196 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:4.6017 it:15201 ms/it:1.9109 tst_acc:0.9795 trn_loss:0.0000 tst_loss:0.1789 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0005 it:15301 ms/it:1.8718 tst_acc:0.9810 trn_loss:0.0000 tst_loss:0.1174 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0003 it:15401 ms/it:1.8953 tst_acc:0.9785 trn_loss:0.0000 tst_loss:0.1475 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0001 it:15501 ms/it:1.8311 tst_acc:0.9765 trn_loss:0.0001 tst_loss:0.1564 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0073 it:15601 ms/it:1.7217 tst_acc:0.9750 trn_loss:0.3296 tst_loss:0.1664 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:6.9605 it:15701 ms/it:1.9053 tst_acc:0.9855 trn_loss:0.0009 tst_loss:0.1667 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.1999 it:15801 ms/it:1.7717 tst_acc:0.9825 trn_loss:0.0000 tst_loss:0.1057 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:15901 ms/it:1.8217 tst_acc:0.9830 trn_loss:0.3892 tst_loss:0.1688 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:6.3087 it:16001 ms/it:1.9229 tst_acc:0.9745 trn_loss:0.1859 tst_loss:0.2158 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:4.9582 it:16101 ms/it:1.8082 tst_acc:0.9795 trn_loss:-0.0000 tst_loss:0.2029 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:16201 ms/it:1.9254 tst_acc:0.9775 trn_loss:0.2363 tst_loss:0.2955 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:8.8553 it:16301 ms/it:1.8827 tst_acc:0.9750 trn_loss:0.8202 tst_loss:0.2255 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:10.0281 it:16401 ms/it:1.9574 tst_acc:0.9645 trn_loss:0.0000 tst_loss:0.3707 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:16501 ms/it:1.9046 tst_acc:0.9820 trn_loss:0.0359 tst_loss:0.1896 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.8955 it:16601 ms/it:1.8141 tst_acc:0.9775 trn_loss:0.0399 tst_loss:0.1710 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:2.4733 it:16701 ms/it:1.9342 tst_acc:0.9835 trn_loss:0.0000 tst_loss:0.1501 
tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:16801 ms/it:1.9684 tst_acc:0.9760 trn_loss:0.0001 tst_loss:0.1900 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0106 it:16901 ms/it:1.8685 tst_acc:0.9820 trn_loss:0.7164 tst_loss:0.1348 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:8.9684 it:17001 ms/it:1.8551 tst_acc:0.9840 trn_loss:0.0000 tst_loss:0.1458 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:17101 ms/it:1.9224 tst_acc:0.9760 trn_loss:0.0006 tst_loss:0.2259 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0942 it:17201 ms/it:1.9630 tst_acc:0.9735 trn_loss:0.0003 tst_loss:0.3040 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0378 it:17301 ms/it:1.9050 tst_acc:0.9840 trn_loss:0.0033 tst_loss:0.1675 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.4517 it:17401 ms/it:1.8947 tst_acc:0.9800 trn_loss:0.0000 tst_loss:0.1678 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:17501 ms/it:1.9690 tst_acc:0.9815 trn_loss:0.0000 tst_loss:0.1366 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:17601 ms/it:1.8694 tst_acc:0.9770 trn_loss:0.9932 tst_loss:0.2249 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.7315 it:17701 ms/it:1.8724 tst_acc:0.9775 trn_loss:0.0000 tst_loss:0.1323 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0002 it:17801 ms/it:1.8554 tst_acc:0.9740 trn_loss:0.0519 tst_loss:0.1686 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.2556 it:17901 ms/it:1.8612 tst_acc:0.9740 trn_loss:0.0001 tst_loss:0.1772 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0100 it:18001 ms/it:1.9047 tst_acc:0.9805 trn_loss:0.6135 tst_loss:0.1413 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:7.8780 it:18101 ms/it:1.8002 tst_acc:0.9820 trn_loss:0.0053 tst_loss:0.2359 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.9542 it:18201 ms/it:1.8813 tst_acc:0.9775 trn_loss:0.0003 tst_loss:0.2230 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0574 it:18301 ms/it:1.8955 tst_acc:0.9865 trn_loss:-0.0000 tst_loss:0.1916 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:18401 ms/it:1.9161 tst_acc:0.9735 
trn_loss:0.0001 tst_loss:0.1726 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0062 it:18501 ms/it:1.9954 tst_acc:0.9870 trn_loss:0.5568 tst_loss:0.1441 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:13.3536 it:18601 ms/it:1.8806 tst_acc:0.9735 trn_loss:0.1310 tst_loss:0.3379 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.4664 it:18701 ms/it:1.8249 tst_acc:0.9680 trn_loss:0.1082 tst_loss:0.3077 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:5.1476 it:18801 ms/it:1.8865 tst_acc:0.9810 trn_loss:-0.0000 tst_loss:0.1839 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:18901 ms/it:1.8765 tst_acc:0.9825 trn_loss:0.0060 tst_loss:0.2067 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.9773 it:19001 ms/it:2.0174 tst_acc:0.9795 trn_loss:0.0072 tst_loss:0.2197 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.6490 it:19101 ms/it:1.8839 tst_acc:0.9760 trn_loss:0.4187 tst_loss:0.3126 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:6.0084 it:19201 ms/it:1.9739 tst_acc:0.9865 trn_loss:0.2705 tst_loss:0.1166 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:7.4380 it:19301 ms/it:1.8977 tst_acc:0.9805 trn_loss:0.3683 tst_loss:0.1763 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:6.3306 it:19401 ms/it:1.8479 tst_acc:0.9790 trn_loss:0.0000 tst_loss:0.1369 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0001 it:19501 ms/it:1.8434 tst_acc:0.9820 trn_loss:0.0000 tst_loss:0.1748 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0064 it:19601 ms/it:1.8677 tst_acc:0.9775 trn_loss:0.0005 tst_loss:0.3086 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0919 it:19701 ms/it:1.9377 tst_acc:0.9685 trn_loss:0.0092 tst_loss:0.3331 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:1.5948 it:19801 ms/it:1.8509 tst_acc:0.9825 trn_loss:-0.0000 tst_loss:0.1835 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:19901 ms/it:1.9325 tst_acc:0.9780 trn_loss:-0.0000 tst_loss:0.1955 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:0.0000 it:20000 ms/it:1.8418 tst_acc:0.9750 trn_loss:0.0257 tst_loss:0.2688 tst_nll:0.1681 tst_kl:0.2860 sum_norm_grad:3.1025
@tf.function(autograph=False)
def dnn_predicted_log_prob(x):
  """Per-class log-probabilities predicted by the DNN for a batch of images.

  Args:
    x: batch of images accepted by `dnn`.

  Returns:
    `log_softmax` of the DNN's logits, normalized over the last axis.
  """
  with tf.xla.experimental.jit_scope(compile_ops=True):
    logits = dnn(x).logits
    log_probs = tf.math.log_softmax(logits, axis=-1)
    return log_probs
# Score the whole test set in a single pass. Images, labels and predictions
# are taken from the same pass over `eval_dataset`, so they line up
# row-for-row. The data is loaded with `shuffle_files=True`, so two separate
# passes are not guaranteed to share an order. `tf.concat`, unlike
# stack + reshape, also handles a final batch smaller than 2000. The global
# `eval_iter` used by `eval()` is deliberately left untouched.
x_final, y_final, dnn_predicted_log_probs = (
    tf.concat(parts, axis=0)
    for parts in zip(*[
        (x_batch, y_batch, dnn_predicted_log_prob(x_batch))
        for x_batch, y_batch in eval_dataset.batch(2000)]))

# Predicted class and its log-probability for every test example.
decision = tf.argmax(dnn_predicted_log_probs, axis=-1, output_type=tf.int32)
confidence = tf.reduce_max(dnn_predicted_log_probs, axis=-1)
# An example is "decided" when its top-class probability exceeds `threshold`.
threshold = 0.95
decided_idx = tf.where(confidence > np.log(threshold))
# Indices sorted by ascending confidence: least confident first.
ordered = tf.argsort(confidence)
n = int(y_final.shape[0])
print('Number of examples undecided: {}'.format(n - tf.size(decided_idx)))
accurary = tf.reduce_mean(
    tf.cast(tf.equal(tf.gather(y_final, decided_idx),
                     tf.gather(decision, decided_idx)),
            tf.float32))
print('Accuracy after excluding undecided ones: {}'.format(accurary))
# Show the 50 least-confident test examples with their true labels.
tfp_nn.util.display_imgs(
    tf.gather(x_final, ordered[0:50]),
    tf.gather(y_final, ordered[0:50]));
Number of examples undecided: 127 Accurary after excluding undecided ones: 0.98480707407
from sklearn import metrics

# One-vs-rest ROC AUC for every class: the binary target is "label == k" and
# the score is the DNN's predicted log-probability of class k. The number of
# classes is read from the prediction tensor instead of being hard-coded.
dnn_auc = np.array([
    metrics.roc_auc_score(tf.equal(y_final, k), dnn_predicted_log_probs[:, k])
    for k in range(dnn_predicted_log_probs.shape[-1])])
print('Per class AUC:\n{}'.format(dnn_auc[:, np.newaxis]))
Per class AUC: [[0.99905177] [0.99976495] [0.9995598 ] [0.99976883] [0.99948807] [0.99970995] [0.99963312] [0.99960009] [0.99943057] [0.99946715]]