
by Jawad Haider

MNIST Code Along with ANN

Before we start working with Convolutional Neural Networks (CNNs), let’s model the MNIST dataset using only linear layers.
In this exercise we’ll use the same logic laid out in the ANN notebook. We’ll reshape the MNIST data from a 28x28 image to a flattened 1x784 vector to mimic a single row of 784 features.

Perform standard imports

Torchvision should have been installed by the environment file during setup. If not, you can install it now. At the terminal, with your virtual environment activated, run either

conda install torchvision -c pytorch

or

pip install torchvision

import torch
import torch.nn as nn
import torch.nn.functional as F          # adds some efficiency
from torch.utils.data import DataLoader  # lets us load data in batches
from torchvision import datasets, transforms

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix  # for evaluating results
import matplotlib.pyplot as plt
%matplotlib inline

Load the MNIST dataset

PyTorch makes the MNIST dataset available through torchvision. The first time it’s called, the dataset will be downloaded onto your computer to the path specified. From that point, torchvision will always look for a local copy before attempting another download.

Define transform

As part of the loading process, we can apply multiple transformations (reshape, convert to tensor, normalize, etc.) to the incoming data.
For this exercise we only need to convert images to tensors.

transform = transforms.ToTensor()

Load the training set

train_data = datasets.MNIST(root='../Data', train=True, download=True, transform=transform)
Dataset MNIST
    Number of datapoints: 60000
    Split: train
    Root Location: ../Data
    Transforms (if any): ToTensor()
    Target Transforms (if any): None

Load the test set

There’s a companion set of MNIST data containing 10,000 records accessible by setting train=False. As before, torchvision will only download this once, and in the future will look for the local copy.

test_data = datasets.MNIST(root='../Data', train=False, download=True, transform=transform)
Dataset MNIST
    Number of datapoints: 10000
    Split: test
    Root Location: ../Data
    Transforms (if any): ToTensor()
    Target Transforms (if any): None

Examine a training record

train_data[0]

(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0706, 0.0706, 0.0706,
           0.4941, 0.5333, 0.6863, 0.1020, 0.6510, 1.0000, 0.9686, 0.4980,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.1176, 0.1412, 0.3686, 0.6039, 0.6667, 0.9922, 0.9922, 0.9922,
           0.9922, 0.9922, 0.8824, 0.6745, 0.9922, 0.9490, 0.7647, 0.2510,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1922,
           0.9333, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922,
           0.9922, 0.9843, 0.3647, 0.3216, 0.3216, 0.2196, 0.1529, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706,
           0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.7137,
           0.9686, 0.9451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.3137, 0.6118, 0.4196, 0.9922, 0.9922, 0.8039, 0.0431, 0.0000,
           0.1686, 0.6039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0549, 0.0039, 0.6039, 0.9922, 0.3529, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.5451, 0.9922, 0.7451, 0.0078, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0431, 0.7451, 0.9922, 0.2745, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.9451, 0.8824, 0.6275,
           0.4235, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3176, 0.9412, 0.9922,
           0.9922, 0.4667, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,
           0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
           0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,
           0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,
           0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,
           0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,
           0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.6706,
           0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.7647, 0.3137, 0.0353,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.2157, 0.6745, 0.8863, 0.9922,
           0.9922, 0.9922, 0.9922, 0.9569, 0.5216, 0.0431, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.5333, 0.9922, 0.9922, 0.9922,
           0.8314, 0.5294, 0.5176, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000]]]), 5)

Calling the first record from train_data returns a two-item tuple. The first item is a 1x28x28 tensor (one channel of 28x28 pixels) representing the image. The second is a label, in this case the number “5”.

image, label = train_data[0]
print('Shape:', image.shape, '\nLabel:', label)
Shape: torch.Size([1, 28, 28]) 
Label: 5

View the image

Matplotlib can interpret pixel values through a variety of colormaps.

plt.imshow(train_data[0][0].reshape((28,28)), cmap="gray");

plt.imshow(train_data[0][0].reshape((28,28)), cmap="gist_yarg");

Batch loading with DataLoader

Our training set contains 60,000 records. If we look ahead to our model we have 784 incoming features, hidden layers of 120 and 84 neurons, and 10 output features. Including the bias terms for each layer, the total number of parameters being trained is:

\(\begin{split}\quad(784\times120)+120+(120\times84)+84+(84\times10)+10 &=\\ 94080+120+10080+84+840+10 &= 105,214\end{split}\)

For this reason it makes sense to load training data in batches using DataLoader.
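
The parameter arithmetic above can be checked with a few lines of plain Python. This is only a sketch: the helper name `count_linear_parameters` is made up for this example and is not part of PyTorch.

```python
def count_linear_parameters(layer_sizes):
    """Return weights plus biases for a chain of fully connected layers."""
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out + n_out   # weight matrix plus bias vector
    return total

print(count_linear_parameters([784, 120, 84, 10]))  # 105214
```

Changing the list, for example to `[784, 256, 128, 10]`, gives a quick estimate of how other hidden-layer sizes would affect the model size.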

torch.manual_seed(101)  # for consistent results

train_loader = DataLoader(train_data, batch_size=100, shuffle=True)

test_loader = DataLoader(test_data, batch_size=500, shuffle=False)

In the cell above, train_data is a PyTorch Dataset object (an object that supports data loading and sampling).
The batch_size is the number of records to be processed at a time. If it’s not evenly divisible into the dataset, then the final batch contains the remainder.
Setting shuffle to True means that the dataset will be reshuffled at the start of every epoch.

NOTE: DataLoader takes an optional num_workers parameter that sets up how many subprocesses to use for data loading. This behaves differently with different operating systems so we’ve omitted it here. See the docs for more information.
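
The remainder behaviour described above is easy to see on a toy dataset. The sketch below uses a hypothetical 10-record TensorDataset (the names `toy_data` and `toy_loader` are assumptions for this example) with a batch size that does not divide it evenly.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 10 records with 3 features each (the values are arbitrary)
features = torch.arange(30, dtype=torch.float32).view(10, 3)
targets = torch.arange(10)
toy_data = TensorDataset(features, targets)

toy_loader = DataLoader(toy_data, batch_size=4, shuffle=False)
batch_sizes = [len(x) for x, y in toy_loader]
print(batch_sizes)  # [4, 4, 2]
```

With MNIST the batch sizes used here (100 and 500) divide 60,000 and 10,000 evenly, so no short batch occurs. If that ever changes, DataLoader's `drop_last=True` option discards the short final batch.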

View a batch of images

Once we’ve defined a DataLoader, we can create a grid of images using torchvision.utils.make_grid

from torchvision.utils import make_grid
np.set_printoptions(formatter=dict(int=lambda x: f'{x:4}')) # to widen the printed array

# Grab the first batch of images
for images, labels in train_loader:
    break

# Print the first 12 labels
print('Labels: ', labels[:12].numpy())

# Print the first 12 images
im = make_grid(images[:12], nrow=12)  # the default nrow is 8
# We need to transpose the images from CHW to HWC
plt.imshow(np.transpose(im.numpy(), (1, 2, 0)));
Labels:  [   0    5    7    8    6    7    9    7    1    3    8    4]

Define the model

For this exercise we’ll use fully connected layers to develop a multilayer perceptron.
Our input size is 784 once we flatten the incoming 28x28 tensors.
Our output size represents the 10 possible digits.
We’ll set our hidden layers to [120, 84] for now. Once you’ve completed the exercise feel free to come back and try different values.

class MultilayerPerceptron(nn.Module):
    def __init__(self, in_sz=784, out_sz=10, layers=[120,84]):
        super().__init__()  # required so that nn.Module can register the layers
        self.fc1 = nn.Linear(in_sz,layers[0])
        self.fc2 = nn.Linear(layers[0],layers[1])
        self.fc3 = nn.Linear(layers[1],out_sz)

    def forward(self,X):
        X = F.relu(self.fc1(X))
        X = F.relu(self.fc2(X))
        X = self.fc3(X)
        return F.log_softmax(X, dim=1)

model = MultilayerPerceptron()
model
MultilayerPerceptron(
  (fc1): Linear(in_features=784, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

NOTE: You may have noticed our shortcut for adding ReLU to the linear layer. In the last section this was done under the __init__ section as
layerlist = []
for i in layers:
    layerlist.append(nn.Linear(n_in,i))
    layerlist.append(nn.ReLU(inplace=True))
self.layers = nn.Sequential(*layerlist)
Here we’re calling F.relu() as a functional wrapper on the linear layer directly:
def forward(self,X):
    X = F.relu(self.fc1(X))
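
As a quick check that the two styles in the note compute the same thing, the sketch below compares F.relu applied to a linear layer with an nn.Sequential that contains the same layer followed by nn.ReLU. The layer sizes and the input `x` are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
linear = nn.Linear(4, 3)
x = torch.randn(2, 4)

functional_out = F.relu(linear(x))                  # functional call, as in forward()
module_out = nn.Sequential(linear, nn.ReLU())(x)    # module style, built in __init__

print(torch.equal(functional_out, module_out))
```

ReLU has no trainable parameters, so the choice between the two is mostly a matter of style.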

Count the model parameters

This optional step shows that the number of trainable parameters in our model matches the equation above.

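def count_parameters(model):
    params = [p.numel() for p in model.parameters() if p.requires_grad]
    for item in params:
        print(f'{item:>6}')  # parameters in each weight matrix or bias vector
    print(f'______\n{sum(params):>6}')  # total trainable parameters

count_parameters(model)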

Define loss function & optimizer

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
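
One detail worth noting: the model's forward method already returns log_softmax, and nn.CrossEntropyLoss applies log_softmax again internally. Applying log_softmax twice gives the same result as applying it once, so the loss is unaffected. The sketch below checks this on random stand-in values; `logits` and `targets` are made-up tensors, not model outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)             # stand-in for raw fc3 outputs
targets = torch.randint(0, 10, (5,))    # stand-in for digit labels
log_probs = F.log_softmax(logits, dim=1)

ce_on_logits = F.cross_entropy(logits, targets)
ce_on_log_probs = F.cross_entropy(log_probs, targets)   # the pairing used in this notebook
nll_on_log_probs = F.nll_loss(log_probs, targets)

print(ce_on_logits.item(), ce_on_log_probs.item(), nll_on_log_probs.item())
```

All three values agree up to floating-point error, so pairing log_softmax outputs with nn.NLLLoss, or raw outputs with nn.CrossEntropyLoss, would be equivalent alternatives.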

Flatten the training data

The batch tensors fed in by DataLoader have a shape of [100, 1, 28, 28]:

# Load the first batch, print its shape
for images, labels in train_loader:
    print('Batch shape:', images.size())
    break

# dataiter = iter(train_loader)
# images, labels = next(dataiter)
# print('Batch shape:', images.size())
Batch shape: torch.Size([100, 1, 28, 28])

We can flatten them using .view()

images.view(100,-1).size()
torch.Size([100, 784])

We’ll do this just before applying the model to our data.
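
A small sketch of the flattening step on placeholder tensors follows. Using the batch's own size instead of a hard-coded 100 is an optional variation, shown here as an assumption about how one might guard against a short final batch; the tensors are zeros, not real images.

```python
import torch

batch = torch.zeros(100, 1, 28, 28)        # same shape as one training batch
flat = batch.view(batch.size(0), -1)       # keep the batch dimension, merge the rest
print(flat.shape)

partial = torch.zeros(37, 1, 28, 28)       # hypothetical short final batch
print(partial.view(partial.size(0), -1).shape)
```

Calling `partial.view(100, -1)` instead would raise an error, because 37 x 784 elements cannot be arranged into 100 equal rows.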

Train the model

This time we’ll run the test data through the model during each epoch, so that we can compare loss & accuracy on the same plot.

A QUICK NOTE: In the section below marked  # Tally the number of correct predictions  we include the code
predicted = torch.max(y_pred.data, 1)[1]
This uses the torch.max() function. torch.max() returns a tensor of maximum values, and a tensor of the indices where the max values were found. In our code we’re asking for the index positions of the maximum values along dimension 1. In this way we can match predictions up to image labels.
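
Here is a tiny worked example of that call. The `scores` and `labels` tensors are invented for illustration: two records scored against three classes.

```python
import torch

scores = torch.tensor([[0.1, 2.5, 0.3],
                       [4.0, 1.0, 3.9]])

values, indices = torch.max(scores, 1)    # reduce along dimension 1 (across the classes)
print(values)     # the largest score in each row
print(indices)    # the position of that score, i.e. the predicted class

labels = torch.tensor([1, 2])
print((indices == labels).sum().item())   # number of correct predictions
```

The indices returned here are the same ones `scores.argmax(dim=1)` would give.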
import time
start_time = time.time()

epochs = 10
train_losses = []
test_losses = []
train_correct = []
test_correct = []

for i in range(epochs):
    trn_corr = 0
    tst_corr = 0

    # Run the training batches
    for b, (X_train, y_train) in enumerate(train_loader):
        b += 1  # count batches from 1 so the interim accuracy never divides by zero

        # Apply the model
        y_pred = model(X_train.view(100, -1))  # Here we flatten X_train
        loss = criterion(y_pred, y_train)

        # Tally the number of correct predictions
        predicted = torch.max(y_pred.data, 1)[1]
        batch_corr = (predicted == y_train).sum()
        trn_corr += batch_corr

        # Update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print interim results
        if b%200 == 0:
            print(f'epoch: {i:2}  batch: {b:4} [{100*b:6}/60000]  loss: {loss.item():10.8f}  \
accuracy: {trn_corr.item()*100/(100*b):7.3f}%')

    # Update train loss & accuracy for the epoch
    train_losses.append(loss.item())
    train_correct.append(trn_corr)

    # Run the testing batches
    with torch.no_grad():
        for b, (X_test, y_test) in enumerate(test_loader):

            # Apply the model
            y_val = model(X_test.view(500, -1))  # Here we flatten X_test

            # Tally the number of correct predictions
            predicted = torch.max(y_val.data, 1)[1]
            tst_corr += (predicted == y_test).sum()

    # Update test loss & accuracy for the epoch
    loss = criterion(y_val, y_test)
    test_losses.append(loss.item())
    test_correct.append(tst_corr)

print(f'\nDuration: {time.time() - start_time:.0f} seconds') # print the time elapsed            
epoch:  0  batch:  200 [ 20000/60000]  loss: 0.35221729  accuracy:  82.695%
epoch:  0  batch:  400 [ 40000/60000]  loss: 0.32761699  accuracy:  87.340%
epoch:  0  batch:  600 [ 60000/60000]  loss: 0.31156573  accuracy:  89.490%
epoch:  1  batch:  200 [ 20000/60000]  loss: 0.20120722  accuracy:  94.800%
epoch:  1  batch:  400 [ 40000/60000]  loss: 0.14656080  accuracy:  95.185%
epoch:  1  batch:  600 [ 60000/60000]  loss: 0.12691295  accuracy:  95.478%
epoch:  2  batch:  200 [ 20000/60000]  loss: 0.13621402  accuracy:  96.815%
epoch:  2  batch:  400 [ 40000/60000]  loss: 0.07235763  accuracy:  96.790%
epoch:  2  batch:  600 [ 60000/60000]  loss: 0.04241359  accuracy:  96.878%
epoch:  3  batch:  200 [ 20000/60000]  loss: 0.09474990  accuracy:  97.635%
epoch:  3  batch:  400 [ 40000/60000]  loss: 0.06394162  accuracy:  97.600%
epoch:  3  batch:  600 [ 60000/60000]  loss: 0.07836709  accuracy:  97.562%
epoch:  4  batch:  200 [ 20000/60000]  loss: 0.05509195  accuracy:  98.135%
epoch:  4  batch:  400 [ 40000/60000]  loss: 0.06395346  accuracy:  98.125%
epoch:  4  batch:  600 [ 60000/60000]  loss: 0.05392118  accuracy:  98.105%
epoch:  5  batch:  200 [ 20000/60000]  loss: 0.03487724  accuracy:  98.515%
epoch:  5  batch:  400 [ 40000/60000]  loss: 0.03120600  accuracy:  98.433%
epoch:  5  batch:  600 [ 60000/60000]  loss: 0.03449132  accuracy:  98.402%
epoch:  6  batch:  200 [ 20000/60000]  loss: 0.04473587  accuracy:  98.770%
epoch:  6  batch:  400 [ 40000/60000]  loss: 0.05389304  accuracy:  98.770%
epoch:  6  batch:  600 [ 60000/60000]  loss: 0.04762774  accuracy:  98.685%
epoch:  7  batch:  200 [ 20000/60000]  loss: 0.01370908  accuracy:  98.885%
epoch:  7  batch:  400 [ 40000/60000]  loss: 0.01426961  accuracy:  98.945%
epoch:  7  batch:  600 [ 60000/60000]  loss: 0.04490321  accuracy:  98.902%
epoch:  8  batch:  200 [ 20000/60000]  loss: 0.02279496  accuracy:  99.150%
epoch:  8  batch:  400 [ 40000/60000]  loss: 0.03816750  accuracy:  99.060%
epoch:  8  batch:  600 [ 60000/60000]  loss: 0.02311455  accuracy:  99.055%
epoch:  9  batch:  200 [ 20000/60000]  loss: 0.01244260  accuracy:  99.330%
epoch:  9  batch:  400 [ 40000/60000]  loss: 0.00740430  accuracy:  99.340%
epoch:  9  batch:  600 [ 60000/60000]  loss: 0.01638621  accuracy:  99.280%

Duration: 275 seconds
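
To isolate the parameter update at the heart of the training loop, the sketch below repeats it on synthetic data. Everything here is a stand-in: `toy_model` is a single linear layer rather than the MLP, and `X` and `y` are random tensors rather than MNIST batches.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_model = nn.Linear(784, 10)                  # hypothetical stand-in for the MLP
toy_criterion = nn.CrossEntropyLoss()
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.001)

X = torch.randn(100, 784)                       # synthetic "flattened images"
y = torch.randint(0, 10, (100,))                # synthetic labels

losses = []
for step in range(50):
    loss = toy_criterion(toy_model(X), y)
    toy_optimizer.zero_grad()   # clear gradients left over from the previous step
    loss.backward()             # backpropagate to fill .grad on every parameter
    toy_optimizer.step()        # let Adam adjust the weights
    losses.append(loss.item())

print(f'first loss: {losses[0]:.4f}  last loss: {losses[-1]:.4f}')
```

Because the same batch is reused on every step, the loss should fall steadily. This only demonstrates the mechanics of the update, not generalization.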

Plot the loss and accuracy comparisons

plt.plot(train_losses, label='training loss')
plt.plot(test_losses, label='validation loss')
plt.title('Loss at the end of each epoch')
plt.legend();

This shows some evidence of overfitting the training data.

plt.plot([t/600 for t in train_correct], label='training accuracy')
plt.plot([t/100 for t in test_correct], label='validation accuracy')
plt.title('Accuracy at the end of each epoch')
plt.legend();

Evaluate Test Data

We retained the test scores during our training session:

print(test_correct) # contains the results of all 10 epochs
print(f'Test accuracy: {test_correct[-1].item()*100/10000:.3f}%') # print the most recent result as a percent
[tensor(9439), tensor(9635), tensor(9666), tensor(9726), tensor(9746), tensor(9758), tensor(9737), tensor(9749), tensor(9746), tensor(9725)]

Test accuracy: 97.250%
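
The final epoch is not necessarily the best one. As a small sketch, the counts printed above can be copied into a plain list (named `test_correct_counts` here for illustration) to find the epoch with the most correct test predictions.

```python
# Correct-prediction counts per epoch, copied from the printed test_correct list
test_correct_counts = [9439, 9635, 9666, 9726, 9746, 9758, 9737, 9749, 9746, 9725]

best_epoch = max(range(len(test_correct_counts)), key=test_correct_counts.__getitem__)
best_accuracy = test_correct_counts[best_epoch] * 100 / 10000

print(f'best epoch: {best_epoch}  accuracy: {best_accuracy:.2f}%')  # best epoch: 5  accuracy: 97.58%
```

Test accuracy peaks at epoch 5 and drifts slightly lower afterwards while training accuracy keeps rising, which is consistent with mild overfitting.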

However, we’d like to compare the predicted values to the ground truth (the y_test labels), so we’ll run the test set through the trained model all at once.

# Extract the data all at once, not in batches
test_load_all = DataLoader(test_data, batch_size=10000, shuffle=False)
with torch.no_grad():
    correct = 0
    for X_test, y_test in test_load_all:
        y_val = model(X_test.view(len(X_test), -1))  # pass in a flattened view of X_test
        predicted = torch.max(y_val,1)[1]
        correct += (predicted == y_test).sum()
print(f'Test accuracy: {correct.item()}/{len(test_data)} = {correct.item()*100/(len(test_data)):7.3f}%')
Test accuracy: 9725/10000 =  97.250%

Not bad considering that a random guess gives only 10% accuracy!

Display the confusion matrix

This uses scikit-learn, and the predicted values obtained above.

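# print a row of values for reference
np.set_printoptions(formatter=dict(int=lambda x: f'{x:4}'))
print(np.arange(10).reshape(1,10))
print()

# print the confusion matrix
# (predicted is passed first, so rows are predictions and columns are true labels)
print(confusion_matrix(predicted.view(-1), y_test.view(-1)))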
[[   0    1    2    3    4    5    6    7    8    9]]

[[ 968    0    1    0    1    2    5    0    4    1]
 [   0 1126    4    0    0    0    2    5    0    4]
 [   2    1 1007    4    3    0    0   10    5    0]
 [   1    0    5  985    0    2    1    3    4    7]
 [   1    0    1    0  962    1    3    1    3   12]
 [   2    1    1   13    0  882   33    1   23   16]
 [   1    2    2    0    5    1  913    0    1    0]
 [   1    1    5    5    3    2    0 1003    7    7]
 [   2    4    6    2    1    2    1    1  925    8]
 [   2    0    0    1    7    0    0    4    2  954]]

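By raw counts on the diagonal, the model had the most correct predictions for ones, twos and sevens, and the fewest for fives, sixes and eights. Keep in mind that the digit classes are not equally represented in the test set, so raw counts are only a rough guide.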
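
Since the digit classes differ in size, a per-digit success rate can be more informative than the raw diagonal. The sketch below copies the matrix printed above into a NumPy array (`cm` is a name chosen for this example) and divides each diagonal entry by its column total, which is the number of test images that truly show that digit.

```python
import numpy as np

# Confusion matrix copied from the printed output: rows = predicted, columns = true label
cm = np.array([
    [968,    0,    1,   0,   1,   2,   5,    0,   4,   1],
    [  0, 1126,    4,   0,   0,   0,   2,    5,   0,   4],
    [  2,    1, 1007,   4,   3,   0,   0,   10,   5,   0],
    [  1,    0,    5, 985,   0,   2,   1,    3,   4,   7],
    [  1,    0,    1,   0, 962,   1,   3,    1,   3,  12],
    [  2,    1,    1,  13,   0, 882,  33,    1,  23,  16],
    [  1,    2,    2,   0,   5,   1, 913,    0,   1,   0],
    [  1,    1,    5,   5,   3,   2,   0, 1003,   7,   7],
    [  2,    4,    6,   2,   1,   2,   1,    1, 925,   8],
    [  2,    0,    0,   1,   7,   0,   0,    4,   2, 954],
])

true_counts = cm.sum(axis=0)              # number of test images of each digit
recall = cm.diagonal() / true_counts      # fraction of each digit classified correctly

for digit, (n, r) in enumerate(zip(true_counts, recall)):
    print(f'digit {digit}: {n:4d} images, {100*r:6.2f}% correct')
```

Measured this way, nines, eights and sixes have the lowest rates and ones the highest, which differs from the ranking by raw counts.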

Examine the misses

We can track the index positions of “missed” predictions, and extract the corresponding image and label. We’ll do this in batches to save screen space.

misses = np.array([])
for i in range(len(predicted.view(-1))):
    if predicted[i] != y_test[i]:
        misses = np.append(misses,i).astype('int64')
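
The loop above can also be written without explicit iteration. This is an optional alternative, sketched on made-up tensors (`toy_predicted` and `toy_labels` are hypothetical stand-ins for `predicted` and `y_test`).

```python
import torch

# Hypothetical predictions and labels for eight test images
toy_predicted = torch.tensor([7, 2, 1, 0, 4, 1, 4, 9])
toy_labels    = torch.tensor([7, 2, 1, 0, 9, 1, 4, 5])

# nonzero() returns the positions where the comparison is True
toy_misses = torch.nonzero(toy_predicted != toy_labels).flatten().numpy()
print(toy_misses)        # index positions where prediction and label disagree
print(len(toy_misses))   # number of misses
```

On 10,000 records the difference in speed is small, so the explicit loop remains a perfectly reasonable choice for readability.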

# Display the number of misses
len(misses)

# Display the first 10 index positions
misses[:10]
array([  61,   62,   81,  104,  115,  151,  193,  217,  247,  259],
      dtype=int64)

# Set up an iterator to feed batched rows
r = 12   # row size
row = iter(np.array_split(misses,len(misses)//r+1))
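
To see what this iterator yields, the sketch below runs the same splitting logic on placeholder indices. `fake_misses` holds 275 consecutive integers, matching the 10000 - 9725 misses reported earlier, and the other names are likewise made up for this example.

```python
import numpy as np

fake_misses = np.arange(275)   # placeholder index positions, not the real misses
row_size = 12
chunks = np.array_split(fake_misses, len(fake_misses)//row_size + 1)
print(len(chunks), [len(c) for c in chunks[:3]], len(chunks[-1]))

fake_rows = iter(chunks)
shown = 0
while True:
    try:
        nextrow = next(fake_rows)   # the same call the viewing cell makes
    except StopIteration:           # raised once every chunk has been consumed
        break
    shown += len(nextrow)
print(shown)
```

np.array_split spreads the items as evenly as possible, so chunks never exceed the row size, though they can come out shorter than it when the count divides evenly.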

Now that everything is set up, run and re-run the cell below to view all of the missed predictions.
Use Ctrl+Enter to remain on the cell between runs. You’ll see a StopIteration once all the misses have been seen.

nextrow = next(row)
print("Index:", nextrow)
print("Label:", y_test.index_select(0,torch.tensor(nextrow)).numpy())
print("Guess:", predicted.index_select(0,torch.tensor(nextrow)).numpy())

images = X_test.index_select(0,torch.tensor(nextrow))
im = make_grid(images, nrow=r)
plt.imshow(np.transpose(im.numpy(), (1, 2, 0)));
Index: [  61   62   81  104  115  151  193  217  247  259  264  320]
Label: [   8    9    6    9    4    9    9    6    4    6    9    9]
Guess: [   2    5    5    5    9    8    8    5    6    0    4    8]

Great job!


Copyright Qalmaqihir