In [ ]:
# code modified based on https://github.com/Ste29/Uncertainty-analysis/blob/master/scripts/BNN%20MNIST.ipynb
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
from tensorflow_probability.python.layers import DenseFlipout
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Flatten, BatchNormalization
from tensorflow.keras.utils import plot_model
from tensorflow.keras.optimizers import Adam
tf.compat.v1.enable_eager_execution()  # no-op in TF 2.x, where eager execution is already the default

import numpy as np
from scipy.special import softmax
import matplotlib.pyplot as plt

%matplotlib inline

import random
random.seed(5296)
np.random.seed(5296)      # NumPy is used for the data handling below
tf.random.set_seed(5296)  # TensorFlow weight initialization and sampling

print('TensorFlow version:', tf.__version__)
print('TensorFlow Probability version:', tfp.__version__)
TensorFlow version: 2.3.0
TensorFlow Probability version: 0.11.0

Define the negative log-likelihood loss function (the model outputs are logits)

In [ ]:
def neg_log_likelihood_with_logits(y_true, y_pred):
    # Treat the network outputs as logits of a categorical distribution and
    # return the mean negative log-likelihood of the true classes
    y_pred_dist = tfd.Categorical(logits=y_pred)
    return -tf.reduce_mean(y_pred_dist.log_prob(tf.argmax(y_true, axis=-1)))
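
This is the same quantity as Keras' categorical cross-entropy computed from logits; a minimal sanity check on a random batch (an illustrative addition, not part of the original notebook) could look like this:

In [ ]:
# Hypothetical sanity check: compare the custom loss against
# tf.keras.losses.CategoricalCrossentropy(from_logits=True)
logits = tf.random.normal((4, 10))
labels = tf.one_hot([3, 1, 7, 0], depth=10)
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
print(neg_log_likelihood_with_logits(labels, logits).numpy())
print(cce(labels, logits).numpy())  # should agree up to floating-point error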

Load MNIST dataset

In [ ]:
n_class = 10

batch_size = 128
n_epochs = 15
lr = 1e-3

print('Loading MNIST dataset')
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = np.expand_dims(X_train, -1)
n_train = X_train.shape[0]
X_test = np.expand_dims(X_test, -1)
y_train = tf.keras.utils.to_categorical(y_train, n_class)
y_test = tf.keras.utils.to_categorical(y_test, n_class)

# Normalize data
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255

print("X_train.shape =", X_train.shape)
print("y_train.shape =", y_train.shape)
print("X_test.shape =", X_test.shape)
print("y_test.shape =", y_test.shape)

plt.imshow(X_train[0, :, :, 0], cmap='gist_gray')
Loading MNIST dataset
X_train.shape = (60000, 28, 28, 1)
y_train.shape = (60000, 10)
X_test.shape = (10000, 28, 28, 1)
y_test.shape = (10000, 10)
Out[ ]:
<matplotlib.image.AxesImage at 0x7f9d0d465cc0>

BNN model: define the kernel KL divergence function with an attached weight

In [ ]:
def get_kernel_divergence_fn(train_size, w=1.0):
    """
    Get the kernel Kullback-Leibler divergence function

    # Arguments
        train_size (int): size of the training dataset, used to normalize the KL term
        w (float): weight applied to the divergence

    # Returns
        kernel_divergence_fn: kernel Kullback-Leibler divergence function
    """
    def kernel_divergence_fn(q, p, _):  # the third argument is required by the layer API but unused
        kernel_divergence = tfp.distributions.kl_divergence(q, p) / tf.cast(train_size, tf.float32)
        return w * kernel_divergence
    return kernel_divergence_fn
In [ ]:
def add_kl_weight(layer, train_size, w_value=1.0):
    # Attach a non-trainable scalar KL weight to the layer and install the weighted divergence function
    w = layer.add_weight(name=layer.name+'/kl_loss_weight', shape=(),
                         initializer=tf.initializers.constant(w_value), trainable=False)
    layer.kernel_divergence_fn = get_kernel_divergence_fn(train_size, w)
    return layer
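
Because the KL weight is stored as a non-trainable variable, it could in principle be annealed over the course of training (KL warm-up). A minimal sketch of such a callback, assuming the variable names assigned above; this is an illustrative addition, not something the notebook uses:

In [ ]:
# Hypothetical KL-annealing callback: ramps every '.../kl_loss_weight' variable
# linearly from 0 to 1 over the training epochs
class KLAnnealing(tf.keras.callbacks.Callback):
    def __init__(self, n_epochs):
        super().__init__()
        self.n_epochs = n_epochs

    def on_epoch_begin(self, epoch, logs=None):
        new_w = (epoch + 1) / self.n_epochs
        for var in self.model.weights:
            if 'kl_loss_weight' in var.name:
                var.assign(new_w)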

Build and train the Bayesian NN model

In [ ]:
def build_bayesian_bnn_model(input_shape, train_size):
    model_in = Input(shape=input_shape)
    model_in2 = tf.keras.layers.Flatten()(model_in)

    # Each DenseFlipout layer gets its weighted KL divergence attached via add_kl_weight
    dens_1 = DenseFlipout(512, activation='relu', kernel_divergence_fn=None)
    dens_1 = add_kl_weight(dens_1, train_size)
    x = dens_1(model_in2)
    x = BatchNormalization()(x)
    x = tf.keras.layers.Dropout(0.2)(x)

    dens_2 = DenseFlipout(256, activation='relu', kernel_divergence_fn=None)
    dens_2 = add_kl_weight(dens_2, train_size)
    x = dens_2(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.Dropout(0.2)(x)

    dens_3 = DenseFlipout(128, activation='relu', kernel_divergence_fn=None)
    dens_3 = add_kl_weight(dens_3, train_size)
    x = dens_3(x)
    x = tf.keras.layers.Dropout(0.2)(x)

    dens_4 = DenseFlipout(10, activation=None, kernel_divergence_fn=None)
    dens_4 = add_kl_weight(dens_4, train_size)
    model_out = dens_4(x)  # logits
    model = Model(model_in, model_out)
    return model
    
bnn_model = build_bayesian_bnn_model(X_train.shape[1:], n_train)
bnn_model.compile(loss=neg_log_likelihood_with_logits, optimizer=Adam(lr), metrics=[tf.keras.metrics.CategoricalAccuracy()],
                   experimental_run_tf_function=False)
bnn_model.summary()
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/layers/util.py:106: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense_flipout (DenseFlipout) (None, 512)               803329    
_________________________________________________________________
batch_normalization (BatchNo (None, 512)               2048      
_________________________________________________________________
dropout (Dropout)            (None, 512)               0         
_________________________________________________________________
dense_flipout_1 (DenseFlipou (None, 256)               262401    
_________________________________________________________________
batch_normalization_1 (Batch (None, 256)               1024      
_________________________________________________________________
dropout_1 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_flipout_2 (DenseFlipou (None, 128)               65665     
_________________________________________________________________
dropout_2 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_flipout_3 (DenseFlipou (None, 10)                2571      
=================================================================
Total params: 1,137,038
Trainable params: 1,135,498
Non-trainable params: 1,540
_________________________________________________________________
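
Note that each DenseFlipout layer carries about twice the parameters of the corresponding Dense layer, since the default mean-field posterior learns a mean and a scale for every kernel weight (the bias is a point estimate), plus the one non-trainable KL weight added by add_kl_weight: for the first layer, 2 × 784 × 512 + 512 + 1 = 803,329. The four KL weights together with the BatchNormalization moving statistics make up the 1,540 non-trainable parameters.
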
In [ ]:
tf.keras.utils.plot_model(bnn_model, show_shapes=True)
Out[ ]:
In [ ]:
# Note: this callback is defined but not passed to fit(), so training always runs the full n_epochs
early_stopping_cd = tf.keras.callbacks.EarlyStopping(patience=3)

hist_bnn = bnn_model.fit(X_train, y_train, batch_size=batch_size,
                         epochs=n_epochs, verbose=1, validation_split=0.1)
Epoch 1/15
422/422 [==============================] - 3s 8ms/step - loss: 23.8397 - categorical_accuracy: 0.8104 - val_loss: 22.6626 - val_categorical_accuracy: 0.9450
Epoch 2/15
422/422 [==============================] - 3s 7ms/step - loss: 21.9936 - categorical_accuracy: 0.9170 - val_loss: 21.0514 - val_categorical_accuracy: 0.9617
Epoch 3/15
422/422 [==============================] - 3s 7ms/step - loss: 20.3340 - categorical_accuracy: 0.9356 - val_loss: 19.4208 - val_categorical_accuracy: 0.9713
Epoch 4/15
422/422 [==============================] - 3s 7ms/step - loss: 18.6874 - categorical_accuracy: 0.9469 - val_loss: 17.8144 - val_categorical_accuracy: 0.9685
Epoch 5/15
422/422 [==============================] - 3s 7ms/step - loss: 17.0996 - categorical_accuracy: 0.9513 - val_loss: 16.2669 - val_categorical_accuracy: 0.9733
Epoch 6/15
422/422 [==============================] - 3s 7ms/step - loss: 15.6081 - categorical_accuracy: 0.9558 - val_loss: 14.8674 - val_categorical_accuracy: 0.9735
Epoch 7/15
422/422 [==============================] - 3s 7ms/step - loss: 14.2791 - categorical_accuracy: 0.9578 - val_loss: 13.6138 - val_categorical_accuracy: 0.9765
Epoch 8/15
422/422 [==============================] - 3s 7ms/step - loss: 13.1182 - categorical_accuracy: 0.9596 - val_loss: 12.5585 - val_categorical_accuracy: 0.9712
Epoch 9/15
422/422 [==============================] - 3s 7ms/step - loss: 12.1399 - categorical_accuracy: 0.9593 - val_loss: 11.6554 - val_categorical_accuracy: 0.9750
Epoch 10/15
422/422 [==============================] - 3s 7ms/step - loss: 11.3139 - categorical_accuracy: 0.9604 - val_loss: 10.8959 - val_categorical_accuracy: 0.9747
Epoch 11/15
422/422 [==============================] - 3s 7ms/step - loss: 10.6036 - categorical_accuracy: 0.9610 - val_loss: 10.2416 - val_categorical_accuracy: 0.9735
Epoch 12/15
422/422 [==============================] - 3s 7ms/step - loss: 9.9743 - categorical_accuracy: 0.9611 - val_loss: 9.6411 - val_categorical_accuracy: 0.9727
Epoch 13/15
422/422 [==============================] - 3s 7ms/step - loss: 9.4086 - categorical_accuracy: 0.9590 - val_loss: 9.0924 - val_categorical_accuracy: 0.9748
Epoch 14/15
422/422 [==============================] - 3s 7ms/step - loss: 8.8797 - categorical_accuracy: 0.9604 - val_loss: 8.5829 - val_categorical_accuracy: 0.9752
Epoch 15/15
422/422 [==============================] - 3s 7ms/step - loss: 8.3950 - categorical_accuracy: 0.9586 - val_loss: 8.1238 - val_categorical_accuracy: 0.9708
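
The BNN loss is much larger than a plain cross-entropy because every DenseFlipout layer adds its weighted KL divergence to the objective as a Keras layer loss. A rough decomposition (a sketch, not part of the original notebook, assuming the KL terms are exposed through bnn_model.losses after a forward pass):

In [ ]:
# Rough sketch: the compiled objective is approximately data NLL + sum of per-layer KL terms
_ = bnn_model(X_test[:1])  # one eager forward pass so the layer losses are populated
kl_total = tf.reduce_sum(bnn_model.losses).numpy()
print("Total weighted KL term:", kl_total)
print("Approximate data NLL in the final training loss:", hist_bnn.history['loss'][-1] - kl_total)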

Regular NN model

In [ ]:
def build_nn_model(input_shape):
    model_in = Input(shape=input_shape)
    model_in2 = tf.keras.layers.Flatten()(model_in)

    dens_1 = tf.keras.layers.Dense(512, activation='relu')
    x = dens_1(model_in2)
    x = BatchNormalization()(x)
    x = tf.keras.layers.Dropout(0.2)(x)

    dens_2 = tf.keras.layers.Dense(256, activation='relu')
    x = dens_2(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.Dropout(0.2)(x)

    dens_3 = tf.keras.layers.Dense(128, activation='relu')
    x = dens_3(x)
    x = tf.keras.layers.Dropout(0.2)(x)

    dens_4 = tf.keras.layers.Dense(10, activation=None)
    model_out = dens_4(x)  # logits
    model = Model(model_in, model_out)
    return model
    
nn_model = build_nn_model(X_train.shape[1:])
nn_model.compile(loss=neg_log_likelihood_with_logits, optimizer=Adam(lr), metrics=[tf.keras.metrics.CategoricalAccuracy()],
                   experimental_run_tf_function=False)
nn_model.summary()
Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 512)               401920    
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048      
_________________________________________________________________
dropout_3 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               131328    
_________________________________________________________________
batch_normalization_3 (Batch (None, 256)               1024      
_________________________________________________________________
dropout_4 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 128)               32896     
_________________________________________________________________
dropout_5 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                1290      
=================================================================
Total params: 570,506
Trainable params: 568,970
Non-trainable params: 1,536
_________________________________________________________________
In [ ]:
# Note: as above, this callback is defined but not passed to fit(), so training runs all n_epochs
early_stopping_cd = tf.keras.callbacks.EarlyStopping(patience=3)

hist_nn = nn_model.fit(X_train, y_train, batch_size=batch_size,
                       epochs=n_epochs, verbose=1, validation_split=0.1)
Epoch 1/15
422/422 [==============================] - 2s 4ms/step - loss: 0.2767 - categorical_accuracy: 0.9147 - val_loss: 0.0955 - val_categorical_accuracy: 0.9713
Epoch 2/15
422/422 [==============================] - 2s 4ms/step - loss: 0.1322 - categorical_accuracy: 0.9598 - val_loss: 0.0849 - val_categorical_accuracy: 0.9738
Epoch 3/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0963 - categorical_accuracy: 0.9700 - val_loss: 0.0842 - val_categorical_accuracy: 0.9757
Epoch 4/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0793 - categorical_accuracy: 0.9753 - val_loss: 0.0789 - val_categorical_accuracy: 0.9783
Epoch 5/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0682 - categorical_accuracy: 0.9786 - val_loss: 0.0848 - val_categorical_accuracy: 0.9760
Epoch 6/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0598 - categorical_accuracy: 0.9805 - val_loss: 0.0692 - val_categorical_accuracy: 0.9807
Epoch 7/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0518 - categorical_accuracy: 0.9830 - val_loss: 0.0755 - val_categorical_accuracy: 0.9812
Epoch 8/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0500 - categorical_accuracy: 0.9839 - val_loss: 0.0739 - val_categorical_accuracy: 0.9813
Epoch 9/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0434 - categorical_accuracy: 0.9858 - val_loss: 0.0652 - val_categorical_accuracy: 0.9815
Epoch 10/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0453 - categorical_accuracy: 0.9854 - val_loss: 0.0640 - val_categorical_accuracy: 0.9833
Epoch 11/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0388 - categorical_accuracy: 0.9867 - val_loss: 0.0692 - val_categorical_accuracy: 0.9823
Epoch 12/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0322 - categorical_accuracy: 0.9896 - val_loss: 0.0715 - val_categorical_accuracy: 0.9837
Epoch 13/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0343 - categorical_accuracy: 0.9888 - val_loss: 0.0698 - val_categorical_accuracy: 0.9830
Epoch 14/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0310 - categorical_accuracy: 0.9895 - val_loss: 0.0644 - val_categorical_accuracy: 0.9857
Epoch 15/15
422/422 [==============================] - 2s 4ms/step - loss: 0.0293 - categorical_accuracy: 0.9904 - val_loss: 0.0786 - val_categorical_accuracy: 0.9813

Find the test accuracy

In [ ]:
# another way of finding the accuracy
preds_nn = nn_model.evaluate(X_test, y_test)
print ("Loss NN = {}" + str(preds_nn[0]))
print ("Test Accuracy NN = " + str(preds_nn[1]))
313/313 [==============================] - 1s 2ms/step - loss: 0.0782 - categorical_accuracy: 0.9800
Loss NN = 0.07821856439113617
Test Accuracy NN = 0.9800000190734863
In [ ]:
preds_bnn = bnn_model.evaluate(X_test, y_test)
print ("Loss BNN = {}" + str(preds_bnn[0]))
print ("Test Accuracy BNN = " + str(preds_bnn[1]))
313/313 [==============================] - 1s 4ms/step - loss: 8.1287 - categorical_accuracy: 0.9689
Loss BNN = 8.128716468811035
Test Accuracy BNN = 0.9689000248908997
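
Because DenseFlipout samples a fresh set of weights on every forward pass, a single evaluate() call is one Monte Carlo estimate and will vary slightly between runs. A minimal sketch that averages a few such runs (the number of repeats is an arbitrary illustrative choice, not from the original notebook):

In [ ]:
# Average the stochastic evaluation over a few weight samples for a steadier estimate
n_eval = 5  # illustrative choice
accs = [bnn_model.evaluate(X_test, y_test, verbose=0)[1] for _ in range(n_eval)]
print("Mean BNN test accuracy over {} runs: {:.4f} (+/- {:.4f})".format(n_eval, np.mean(accs), np.std(accs)))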

Accuracy plots

In [ ]:
plt.figure(figsize=(15, 15))
epochs_range = range(1, n_epochs + 1)

plt.subplot(2, 1, 1)
plt.plot(epochs_range, hist_nn.history['categorical_accuracy'], label='Model NN')
plt.plot(epochs_range, hist_bnn.history['categorical_accuracy'], label='Model BNN')
plt.title('train accuracy')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(epochs_range, hist_nn.history['val_categorical_accuracy'], label='Model NN')
plt.plot(epochs_range, hist_bnn.history['val_categorical_accuracy'], label='Model BNN')
plt.title('validation accuracy')
plt.legend()

plt.show()

Quantify the uncertainty in the BNN's predictions

In [ ]:
n_mc_run = 100        # number of stochastic forward passes per test sample
med_prob_thres = 0.2  # a class counts as "recognized" if its median probability reaches this threshold

# Run the BNN n_mc_run times, convert each set of logits to probabilities, and stack
# them along a new trailing axis: y_pred_prob_all has shape (n_test, n_class, n_mc_run)
y_pred_logits_list = [bnn_model.predict(X_test) for _ in range(n_mc_run)]
y_pred_prob_all = np.concatenate([softmax(y, axis=-1)[:, :, np.newaxis] for y in y_pred_logits_list], axis=-1)

# For every sample and class, mark whether the median probability over the MC runs clears the threshold
y_pred = [[int(np.median(y) >= med_prob_thres) for y in y_pred_prob] for y_pred_prob in y_pred_prob_all]
y_pred = np.array(y_pred)

# A sample is "recognizable" if at least one class clears the threshold
idx_valid = [any(y) for y in y_pred]
print('Number of recognizable samples:', sum(idx_valid))

idx_invalid = [not any(y) for y in y_pred]
print('Unrecognizable samples:', np.where(idx_invalid)[0])

print('Test accuracy on MNIST (recognizable samples):',
      sum(np.equal(np.argmax(y_test[idx_valid], axis=-1), np.argmax(y_pred[idx_valid], axis=-1))) / len(y_test[idx_valid]))

print('Test accuracy on MNIST (unrecognizable samples):',
      sum(np.equal(np.argmax(y_test[idx_invalid], axis=-1), np.argmax(y_pred[idx_invalid], axis=-1))) / len(y_test[idx_invalid]))
Number of recognizable samples: 9996
Unrecognizable samples: [1941 2266 6625 9634]
Test accuracy on MNIST (recognizable samples): 0.9746898759503801
Test accuracy on MNIST (unrecognizable samples): 0.25
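
A complementary way to summarize the same Monte Carlo samples is the entropy of the mean predictive distribution, which is highest for exactly the kind of samples flagged as unrecognizable above. A minimal sketch reusing y_pred_prob_all (an illustrative addition, not part of the original notebook):

In [ ]:
# Mean predictive distribution over the MC runs, then its entropy per test sample
mean_probs = y_pred_prob_all.mean(axis=-1)  # shape (n_test, n_class)
pred_entropy = -np.sum(mean_probs * np.log(mean_probs + 1e-12), axis=-1)

# Rank the test samples by predictive entropy (most uncertain first)
most_uncertain = np.argsort(pred_entropy)[::-1][:10]
print("Highest-entropy test samples:", most_uncertain)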
In [ ]:
def plot_pred_hist(y_pred, n_class, n_mc_run, n_bins=30, med_prob_thres=0.2, n_subplot_rows=2, figsize=(25, 10)):
    bins = np.logspace(-n_bins, 0, n_bins+1)
    fig, ax = plt.subplots(n_subplot_rows, n_class // n_subplot_rows + 1, figsize=figsize)
    for i in range(n_subplot_rows):
        for j in range(n_class // n_subplot_rows + 1):
            idx = i * (n_class // n_subplot_rows + 1) + j
            if idx < n_class:
                ax[i, j].hist(y_pred[idx], bins)
                ax[i, j].set_xscale('log')
                ax[i, j].set_ylim([0, n_mc_run])
                ax[i, j].title.set_text("{} (median prob: {:.2f}) ({})".format(str(idx),
                                                                               np.median(y_pred[idx]),
                                                                               str(np.median(y_pred[idx]) >= med_prob_thres)))
            else:
                ax[i, j].axis('off')
    plt.show()

A recognizable example

In [ ]:
idx = 0
plt.imshow(X_test[idx, :, :, 0], cmap='gist_gray')
print("True label of the test sample {}: {}".format(idx, np.argmax(y_test[idx], axis=-1)))

plot_pred_hist(y_pred_prob_all[idx], n_class, n_mc_run, med_prob_thres=med_prob_thres)

if any(y_pred[idx]):
    print("Predicted label of the test sample {}: {}".format(idx, np.argmax(y_pred[idx], axis=-1)))
else:
    print("I don't know!")
True label of the test sample 0: 7
Predicted label of the test sample 0: 7

Unrecognizable examples

In [ ]:
for idx in np.where(idx_invalid)[0]:
    plt.imshow(X_test[idx, :, :, 0], cmap='gist_gray')
    print("True label of the test sample {}: {}".format(idx, np.argmax(y_test[idx], axis=-1)))

    plot_pred_hist(y_pred_prob_all[idx], n_class, n_mc_run, med_prob_thres=med_prob_thres)

    if any(y_pred[idx]):
        print("Predicted label of the test sample {}: {}".format(idx, np.argmax(y_pred[idx], axis=-1)))
    else:
        print("I don't know!")
True label of the test sample 1941: 7
I don't know!
True label of the test sample 2266: 1
I don't know!
True label of the test sample 6625: 8
I don't know!
True label of the test sample 9634: 0
I don't know!