# code modified based on https://github.com/Ste29/Uncertainty-analysis/blob/master/scripts/BNN%20MNIST.ipynb
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
from tensorflow_probability.python.layers import DenseVariational, DenseReparameterization, DenseFlipout, Convolution2DFlipout, Convolution2DReparameterization
from tensorflow_probability.python.layers import DistributionLambda
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Input, Dense, Flatten, BatchNormalization, Activation, LeakyReLU
from tensorflow.keras.utils import plot_model
from tensorflow.keras.optimizers import Adam
tf.compat.v1.enable_eager_execution()
import numpy as np
from scipy.special import softmax
import matplotlib.pyplot as plt
%matplotlib inline
import random
random.seed(5296)
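# For reproducibility of the NumPy/TensorFlow parts as well (assumption: the original
# notebook only seeded Python's random module).
np.random.seed(5296)
tf.compat.v1.set_random_seed(5296)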
print('TensorFlow version:', tf.__version__)
print('TensorFlow Probability version:', tfp.__version__)
def neg_log_likelihood_with_logits(y_true, y_pred):
    y_pred_dist = tfp.distributions.Categorical(logits=y_pred)
    return -tf.reduce_mean(y_pred_dist.log_prob(tf.argmax(y_true, axis=-1)))
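# Illustrative sanity check of the loss (not in the original notebook): a confident,
# correct prediction gives a small negative log-likelihood, a confident wrong one a large value.
_y_true_toy = tf.constant([[0.0, 1.0, 0.0]])
print(float(neg_log_likelihood_with_logits(_y_true_toy, tf.constant([[0.0, 5.0, 0.0]]))))  # small loss
print(float(neg_log_likelihood_with_logits(_y_true_toy, tf.constant([[5.0, 0.0, 0.0]]))))  # large loss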
n_class = 10
batch_size = 128
n_epochs = 15
lr = 1e-3
print('Loading MNIST dataset')
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = np.expand_dims(X_train, -1)
n_train = X_train.shape[0]
X_test = np.expand_dims(X_test, -1)
y_train = tf.keras.utils.to_categorical(y_train, n_class)
y_test = tf.keras.utils.to_categorical(y_test, n_class)
# Normalize data
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
print("X_train.shape =", X_train.shape)
print("y_train.shape =", y_train.shape)
print("X_test.shape =", X_test.shape)
print("y_test.shape =", y_test.shape)
plt.imshow(X_train[0, :, :, 0], cmap='gist_gray')
def get_kernel_divergence_fn(train_size, w=1.0):
    """
    Get the kernel Kullback-Leibler divergence function.

    # Arguments
        train_size (int): size of the training dataset, used to normalize the KL term
        w (float): weight applied to the divergence

    # Returns
        kernel_divergence_fn: kernel Kullback-Leibler divergence function
    """
    def kernel_divergence_fn(q, p, _):  # the third (ignored) argument is required by TFP
        kernel_divergence = tfp.distributions.kl_divergence(q, p) / tf.cast(train_size, tf.float32)
        return w * kernel_divergence
    return kernel_divergence_fn


def add_kl_weight(layer, train_size, w_value=1.0):
    w = layer.add_weight(name=layer.name + '/kl_loss_weight', shape=(),
                         initializer=tf.initializers.constant(w_value), trainable=False)
    layer.kernel_divergence_fn = get_kernel_divergence_fn(train_size, w)
    return layer
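# Illustrative check (not in the original notebook): the returned divergence function
# scales the analytic KL between posterior q and prior p by 1 / train_size, so each
# training example carries its share of the KL penalty, as in the per-datum ELBO.
_q = tfd.Normal(loc=0.5, scale=1.0)
_p = tfd.Normal(loc=0.0, scale=1.0)
_fn = get_kernel_divergence_fn(train_size=n_train)
print('KL(q || p):', float(tfp.distributions.kl_divergence(_q, _p)),
      '  scaled by 1/train_size:', float(_fn(_q, _p, None)))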
def build_bayesian_bnn_model(input_shape, train_size):
    model_in = Input(shape=input_shape)
    x = Flatten()(model_in)
    # Each DenseFlipout layer is created without a divergence function and then given
    # a KL term scaled by 1 / train_size via add_kl_weight.
    dens_1 = DenseFlipout(512, activation='relu', kernel_divergence_fn=None)
    dens_1 = add_kl_weight(dens_1, train_size)
    x = dens_1(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    dens_2 = DenseFlipout(256, activation='relu', kernel_divergence_fn=None)
    dens_2 = add_kl_weight(dens_2, train_size)
    x = dens_2(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    dens_3 = DenseFlipout(128, activation='relu', kernel_divergence_fn=None)
    dens_3 = add_kl_weight(dens_3, train_size)
    x = dens_3(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    dens_4 = DenseFlipout(10, activation=None, kernel_divergence_fn=None)
    dens_4 = add_kl_weight(dens_4, train_size)
    model_out = dens_4(x)  # logits
    model = Model(model_in, model_out)
    return model
bnn_model = build_bayesian_bnn_model(X_train.shape[1:], n_train)
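# Illustrative check (assumption, not in the original notebook): each DenseFlipout layer
# registers its scaled KL divergence with the model through add_loss, so we expect one
# regularization term per variational layer here.
print('Number of KL terms tracked by the model:', len(bnn_model.losses))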
bnn_model.compile(loss=neg_log_likelihood_with_logits, optimizer=Adam(lr), metrics=[tf.keras.metrics.CategoricalAccuracy()],
experimental_run_tf_function=False)
bnn_model.summary()
tf.keras.utils.plot_model(bnn_model, show_shapes=True)
early_stopping_cd = tf.keras.callbacks.EarlyStopping(patience=3)
hist_bnn = bnn_model.fit(X_train, y_train, batch_size=batch_size,
                         epochs=n_epochs, verbose=1, validation_split=0.1,
                         callbacks=[early_stopping_cd])
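# Illustrative check (not in the original notebook): DenseFlipout resamples its weights
# on every forward pass, so two predictions on the same input generally differ; the
# Monte Carlo sampling further below exploits exactly this behaviour.
logits_a = bnn_model.predict(X_test[:1])
logits_b = bnn_model.predict(X_test[:1])
print('Max difference between two stochastic forward passes:', np.abs(logits_a - logits_b).max())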
def build_nn_model(input_shape):
    model_in = Input(shape=input_shape)
    x = Flatten()(model_in)
    dens_1 = tf.keras.layers.Dense(512, activation='relu')
    x = dens_1(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    dens_2 = tf.keras.layers.Dense(256, activation='relu')
    x = dens_2(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    dens_3 = tf.keras.layers.Dense(128, activation='relu')
    x = dens_3(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    dens_4 = tf.keras.layers.Dense(10, activation=None)
    model_out = dens_4(x)  # logits
    model = Model(model_in, model_out)
    return model
nn_model = build_nn_model(X_train.shape[1:])
nn_model.compile(loss=neg_log_likelihood_with_logits, optimizer=Adam(lr), metrics=[tf.keras.metrics.CategoricalAccuracy()],
experimental_run_tf_function=False)
nn_model.summary()
early_stopping_cd = tf.keras.callbacks.EarlyStopping(patience=3)
hist_nn = nn_model.fit(X_train, y_train, batch_size=batch_size,
                       epochs=n_epochs, verbose=1, validation_split=0.1,
                       callbacks=[early_stopping_cd])
# Another way of checking accuracy: Keras' evaluate on the test set
preds_nn = nn_model.evaluate(X_test, y_test)
print("Loss NN = " + str(preds_nn[0]))
print("Test Accuracy NN = " + str(preds_nn[1]))
preds_bnn = bnn_model.evaluate(X_test, y_test)
print("Loss BNN = " + str(preds_bnn[0]))
print("Test Accuracy BNN = " + str(preds_bnn[1]))
plt.figure(figsize=(15, 15))
epochs_nn = range(1, len(hist_nn.history['categorical_accuracy']) + 1)
epochs_bnn = range(1, len(hist_bnn.history['categorical_accuracy']) + 1)
plt.subplot(2, 1, 1)
plt.plot(epochs_nn, hist_nn.history['categorical_accuracy'], label='Model NN')
plt.plot(epochs_bnn, hist_bnn.history['categorical_accuracy'], label='Model BNN')
plt.title('train accuracy')
plt.legend()
plt.subplot(2, 1, 2)
plt.plot(epochs_nn, hist_nn.history['val_categorical_accuracy'], label='Model NN')
plt.plot(epochs_bnn, hist_bnn.history['val_categorical_accuracy'], label='Model BNN')
plt.title('validation accuracy')
plt.legend()
plt.show()
n_mc_run = 100
med_prob_thres = 0.2
y_pred_logits_list = [bnn_model.predict(X_test) for _ in range(n_mc_run)] # a list of predicted logits
y_pred_prob_all = np.concatenate([softmax(y, axis=-1)[:, :, np.newaxis] for y in y_pred_logits_list], axis=-1)
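# y_pred_prob_all has shape (n_test, n_class, n_mc_run): for every test image it holds
# the per-class probabilities from each of the n_mc_run stochastic forward passes.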
y_pred = [[int(np.median(y) >= med_prob_thres) for y in y_pred_prob] for y_pred_prob in y_pred_prob_all]
y_pred = np.array(y_pred)
idx_valid = [any(y) for y in y_pred]
print('Number of recognizable samples:', sum(idx_valid))
idx_invalid = [not any(y) for y in y_pred]
print('Unrecognizable samples:', np.where(idx_invalid)[0])
print('Test accuracy on MNIST (recognizable samples):',
sum(np.equal(np.argmax(y_test[idx_valid], axis=-1), np.argmax(y_pred[idx_valid], axis=-1))) / len(y_test[idx_valid]))
print('Test accuracy on MNIST (unrecognizable samples):',
sum(np.equal(np.argmax(y_test[idx_invalid], axis=-1), np.argmax(y_pred[idx_invalid], axis=-1))) / len(y_test[idx_invalid]))
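# A simple alternative to the median-threshold rule above (illustrative sketch, not in
# the original notebook): average the MC probabilities per sample and take the argmax,
# which is the usual Monte Carlo estimate of the predictive distribution.
y_pred_mean = y_pred_prob_all.mean(axis=-1)  # (n_test, n_class)
acc_mc_mean = np.mean(np.argmax(y_pred_mean, axis=-1) == np.argmax(y_test, axis=-1))
print('Test accuracy using the MC-mean predictive distribution:', acc_mc_mean)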
def plot_pred_hist(y_pred, n_class, n_mc_run, n_bins=30, med_prob_thres=0.2, n_subplot_rows=2, figsize=(25, 10)):
    bins = np.logspace(-n_bins, 0, n_bins + 1)
    fig, ax = plt.subplots(n_subplot_rows, n_class // n_subplot_rows + 1, figsize=figsize)
    for i in range(n_subplot_rows):
        for j in range(n_class // n_subplot_rows + 1):
            idx = i * (n_class // n_subplot_rows + 1) + j
            if idx < n_class:
                ax[i, j].hist(y_pred[idx], bins)
                ax[i, j].set_xscale('log')
                ax[i, j].set_ylim([0, n_mc_run])
                ax[i, j].title.set_text("{} (median prob: {:.2f}) ({})".format(
                    str(idx), np.median(y_pred[idx]), str(np.median(y_pred[idx]) >= med_prob_thres)))
            else:
                ax[i, j].axis('off')
    plt.show()
idx = 0
plt.imshow(X_test[idx, :, :, 0], cmap='gist_gray')
print("True label of the test sample {}: {}".format(idx, np.argmax(y_test[idx], axis=-1)))
plot_pred_hist(y_pred_prob_all[idx], n_class, n_mc_run, med_prob_thres=med_prob_thres)
if any(y_pred[idx]):
    print("Predicted label of the test sample {}: {}".format(idx, np.argmax(y_pred[idx], axis=-1)))
else:
    print("I don't know!")
for idx in np.where(idx_invalid)[0]:
    plt.imshow(X_test[idx, :, :, 0], cmap='gist_gray')
    print("True label of the test sample {}: {}".format(idx, np.argmax(y_test[idx], axis=-1)))
    plot_pred_hist(y_pred_prob_all[idx], n_class, n_mc_run, med_prob_thres=med_prob_thres)
    if any(y_pred[idx]):
        print("Predicted label of the test sample {}: {}".format(idx, np.argmax(y_pred[idx], axis=-1)))
    else:
        print("I don't know!")