Skip to content

TF2 0 RNN MNIST

================ by Jawad Haider

RNN MNIST

# Install TensorFlow
# !pip install -q tensorflow-gpu==2.0.0-beta1

try:
  %tensorflow_version 2.x  # Colab only.
except Exception:
  pass

import tensorflow as tf
print(tf.__version__)
     |████████████████████████████████| 348.9MB 70kB/s 
     |████████████████████████████████| 3.1MB 33.3MB/s 
     |████████████████████████████████| 501kB 41.2MB/s 
2.0.0-beta1
# Load in the data
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
print("x_train.shape:", x_train.shape)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
x_train.shape: (60000, 28, 28)
# Build the model
i = Input(shape=x_train[0].shape)
x = LSTM(128)(i)
x = Dense(10, activation='softmax')(x)

model = Model(i, x)
# Compile and train
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
r = model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10)
WARNING: Logging before flag parsing goes to stderr.
W0803 17:03:38.451040 140542585788288 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 25s 411us/sample - loss: 0.3334 - accuracy: 0.8929 - val_loss: 0.1187 - val_accuracy: 0.9629
Epoch 2/10
60000/60000 [==============================] - 21s 354us/sample - loss: 0.1046 - accuracy: 0.9692 - val_loss: 0.0792 - val_accuracy: 0.9759
Epoch 3/10
60000/60000 [==============================] - 21s 357us/sample - loss: 0.0753 - accuracy: 0.9774 - val_loss: 0.0770 - val_accuracy: 0.9756
Epoch 4/10
60000/60000 [==============================] - 22s 359us/sample - loss: 0.0576 - accuracy: 0.9827 - val_loss: 0.0601 - val_accuracy: 0.9803
Epoch 5/10
60000/60000 [==============================] - 21s 358us/sample - loss: 0.0453 - accuracy: 0.9862 - val_loss: 0.0525 - val_accuracy: 0.9833
Epoch 6/10
60000/60000 [==============================] - 21s 353us/sample - loss: 0.0395 - accuracy: 0.9877 - val_loss: 0.0448 - val_accuracy: 0.9864
Epoch 7/10
60000/60000 [==============================] - 21s 352us/sample - loss: 0.0331 - accuracy: 0.9897 - val_loss: 0.0476 - val_accuracy: 0.9857
Epoch 8/10
60000/60000 [==============================] - 21s 353us/sample - loss: 0.0292 - accuracy: 0.9908 - val_loss: 0.0416 - val_accuracy: 0.9884
Epoch 9/10
60000/60000 [==============================] - 21s 357us/sample - loss: 0.0246 - accuracy: 0.9923 - val_loss: 0.0397 - val_accuracy: 0.9887
Epoch 10/10
60000/60000 [==============================] - 21s 356us/sample - loss: 0.0228 - accuracy: 0.9927 - val_loss: 0.0409 - val_accuracy: 0.9881
# Plot loss per iteration
import matplotlib.pyplot as plt
plt.plot(r.history['loss'], label='loss')
plt.plot(r.history['val_loss'], label='val_loss')
plt.legend()
<matplotlib.legend.Legend at 0x7fd1e8624a58>

# Plot accuracy per iteration
plt.plot(r.history['accuracy'], label='acc')
plt.plot(r.history['val_accuracy'], label='val_acc')
plt.legend()
<matplotlib.legend.Legend at 0x7fd1e5dd0358>

# Plot confusion matrix
from sklearn.metrics import confusion_matrix
import numpy as np
import itertools

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
  """
  This function prints and plots the confusion matrix.
  Normalization can be applied by setting `normalize=True`.
  """
  if normalize:
      cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
      print("Normalized confusion matrix")
  else:
      print('Confusion matrix, without normalization')

  print(cm)

  plt.imshow(cm, interpolation='nearest', cmap=cmap)
  plt.title(title)
  plt.colorbar()
  tick_marks = np.arange(len(classes))
  plt.xticks(tick_marks, classes, rotation=45)
  plt.yticks(tick_marks, classes)

  fmt = '.2f' if normalize else 'd'
  thresh = cm.max() / 2.
  for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
      plt.text(j, i, format(cm[i, j], fmt),
               horizontalalignment="center",
               color="white" if cm[i, j] > thresh else "black")

  plt.tight_layout()
  plt.ylabel('True label')
  plt.xlabel('Predicted label')
  plt.show()


p_test = model.predict(x_test).argmax(axis=1)
cm = confusion_matrix(y_test, p_test)
plot_confusion_matrix(cm, list(range(10)))

# Do these results make sense?
# It's easy to confuse 9 <--> 4, 9 <--> 7, 2 <--> 7, etc. 
Confusion matrix, without normalization
[[ 975    0    2    0    0    0    2    1    0    0]
 [   0 1129    1    1    0    0    1    3    0    0]
 [   0    0 1019    1    1    0    0    8    3    0]
 [   0    0    2 1001    0    3    0    2    0    2]
 [   0    0    1    0  965    0    2    3    1   10]
 [   0    0    0    8    0  881    1    2    0    0]
 [   3    2    0    1    2    2  948    0    0    0]
 [   0    4    1    1    0    0    0 1019    1    2]
 [   2    0    2    5    0    9    0    2  954    0]
 [   0    0    0    6    5    3    0    2    3  990]]

# Show some misclassified examples
misclassified_idx = np.where(p_test != y_test)[0]
i = np.random.choice(misclassified_idx)
plt.imshow(x_test[i], cmap='gray')
plt.title("True label: %s Predicted: %s" % (y_test[i], p_test[i]));

Logo

Copyright Qalmaqihir
For more information, visit us at www.github.com/qalmaqihir/