Aim
To load the MNIST handwritten digit dataset using Keras/TensorFlow, explore its structure, and visualize sample images along with the distribution of digit classes.
Prerequisites
Theory
The MNIST (Modified National Institute of Standards and Technology) dataset is one of the most widely used benchmarks in machine learning. It contains 70,000 grayscale images of handwritten digits (0–9), each of size 28×28 pixels. The dataset is split into 60,000 training samples and 10,000 test samples. Each pixel value ranges from 0 (black) to 255 (white), representing the intensity of that pixel.
The dataset is organized as a 3D NumPy array of shape (N, 28, 28) where N is the number of samples. Each 28×28 matrix is a 2D representation of a digit image. For fully connected neural networks, this 2D matrix must be flattened to a 1D vector of size 784 (28×28=784). For CNNs, a channel dimension is added, making the shape (N, 28, 28, 1).
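The two reshaping operations described above can be sketched with NumPy alone. This is a minimal illustration, assuming a small all-zero array (the hypothetical name `images`) as a stand-in for x_train so it runs without downloading anything.

```python
import numpy as np

# Hypothetical stand-in for x_train: 5 blank images with MNIST's per-image shape.
images = np.zeros((5, 28, 28), dtype=np.uint8)

# Flatten each 28x28 image into a 784-element vector for a fully connected network.
flat = images.reshape(len(images), -1)

# Append a single channel axis for a CNN that expects (N, height, width, channels).
with_channel = images[..., np.newaxis]

print(flat.shape)          # (5, 784)
print(with_channel.shape)  # (5, 28, 28, 1)
```

Neither operation changes any pixel value; only the array's shape is reinterpreted.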
Visualization is critical for understanding data quality, class distribution, and identifying anomalies. Using Matplotlib's imshow() function with the 'gray' colormap, we can render each 28×28 array as a human-readable digit image. Bar charts of class distributions help reveal if the dataset is well balanced across all categories, which informs decisions about model training.
The labels are integers 0–9 stored as a 1D array of shape (N,). For multi-class classification, these must often be one-hot encoded — converting an integer label into a binary vector where only the index corresponding to the class is 1 and all others are 0. For example, label 3 becomes [0,0,0,1,0,0,0,0,0,0].
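One-hot encoding can be sketched in a few lines of NumPy. The helper name `one_hot` is an assumption made for this illustration; Keras also offers tf.keras.utils.to_categorical for the same purpose.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Return a (len(labels), num_classes) array with a single 1 per row."""
    labels = np.asarray(labels)
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes, dtype=np.uint8)[labels]

encoded = one_hot([3, 0, 9])
print(encoded[0])  # [0 0 0 1 0 0 0 0 0 0]
```

Indexing the identity matrix with the label array selects one row per label, which avoids an explicit loop.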
Algorithm / Step-by-Step
- Import required libraries: tensorflow, matplotlib.pyplot, and numpy.
- Load the MNIST dataset using tf.keras.datasets.mnist.load_data(), which returns (x_train, y_train), (x_test, y_test).
- Print the shapes of training and test arrays to verify the tensor dimensions.
- Display the first 9 images from the training set in a 3×3 grid using Matplotlib subplots, setting titles to their corresponding labels.
- Calculate the class distribution (count per digit 0–9) using NumPy's np.unique() with return_counts=True.
- Plot a bar chart of the class frequencies to visualize the dataset balance.
Key Code Concepts
Snippet 1 — Loading and Inspecting MNIST

    import tensorflow as tf
    import matplotlib.pyplot as plt
    import numpy as np

    # Load the MNIST dataset
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Verify the dimensions of the tensors
    print(f'Training data shape: {x_train.shape}')
    print(f'Training labels shape: {y_train.shape}')
    print(f'Test data shape: {x_test.shape}')
    print(f'Test labels shape: {y_test.shape}')

The load_data() call automatically downloads the MNIST dataset (if not cached) and returns it as NumPy arrays. We then print the shapes of these tensors to verify their dimensions.
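Beyond the shapes, it can be useful to check the data type and value ranges of each split. The sketch below is one possible way to do that; `describe_split` is a hypothetical helper, and the random arrays are synthetic stand-ins so the example runs without downloading MNIST.

```python
import numpy as np

def describe_split(name, images, labels):
    """Summarize one split: shape, dtype, pixel range, and label range."""
    summary = {
        "name": name,
        "shape": images.shape,
        "dtype": str(images.dtype),
        "pixel_min": int(images.min()),
        "pixel_max": int(images.max()),
        "label_min": int(labels.min()),
        "label_max": int(labels.max()),
    }
    print(summary)
    return summary

# Synthetic stand-in data (not MNIST) with the same shapes and value ranges.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
fake_labels = rng.integers(0, 10, size=100, dtype=np.uint8)
info = describe_split("fake_train", fake_images, fake_labels)
```

With the real data, calling describe_split("train", x_train, y_train) should report a uint8 array with pixel values between 0 and 255 and labels between 0 and 9.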
Snippet 2 — Visualizing a Grid of Images

    # Set up a matplotlib figure and axis grid
    plt.figure(figsize=(10, 10))

    # Iterate through the first 9 images in the training set
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        plt.imshow(x_train[i], cmap='gray')
        plt.title(f'Label: {y_train[i]}')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

We use a for loop and plt.subplot(3, 3, i + 1) to arrange the first 9 images into a 3×3 grid. The cmap='gray' argument renders the single-channel image in grayscale, and plt.axis('off') hides the coordinate axes.
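The same grid can also be built with plt.subplots(), which returns the figure and all axes at once. The following is a sketch of that alternative, not the method used in the snippet above; `plot_image_grid` is a hypothetical helper and the random images are synthetic stand-ins for x_train.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_image_grid(images, labels, rows=3, cols=3):
    """Draw the first rows*cols images in a grid and return the figure."""
    # squeeze=False guarantees a 2D array of axes, even for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, squeeze=False,
                             figsize=(2 * cols, 2 * rows))
    for ax, image, label in zip(axes.ravel(), images, labels):
        # Fixing vmin/vmax keeps brightness comparable across images.
        ax.imshow(image, cmap="gray", vmin=0, vmax=255)
        ax.set_title(f"Label: {label}")
        ax.axis("off")
    fig.tight_layout()
    return fig

# Synthetic stand-in images (random noise, not MNIST digits).
rng = np.random.default_rng(0)
demo_images = rng.integers(0, 256, size=(9, 28, 28), dtype=np.uint8)
demo_labels = np.arange(9)
fig = plot_image_grid(demo_images, demo_labels)
```

With the real data, plot_image_grid(x_train, y_train) followed by plt.show() should display the same nine digits as the loop-based version. Wrapping the logic in a function makes it easy to change the grid size or to plot the test set.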
Snippet 3 — Class Distribution Analysis
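
    # Count samples per digit class
    classes, counts = np.unique(y_train, return_counts=True)
    # Convert NumPy scalars to plain ints so the dictionary prints cleanly
    print("Class Distribution:", dict(zip(classes.tolist(), counts.tolist())))

    # Bar chart of class frequencies
    plt.bar(classes, counts, color='steelblue')
    plt.xticks(classes)  # label every digit 0-9 on the x-axis
    plt.xlabel("Digit Class")
    plt.ylabel("Count")
    plt.title("MNIST Training Set Class Distribution")
    plt.show()
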
This reveals whether the dataset is balanced (approximately equal counts per class) or imbalanced. MNIST is nearly balanced: the training set averages 6,000 samples per digit, ranging from 5,421 (digit 5) to 6,742 (digit 1).
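Balance can also be summarized numerically rather than judged by eye. The sketch below computes per-class fractions and the ratio of the largest to the smallest class; `class_balance` is a hypothetical helper, and the toy labels are made up for illustration.

```python
import numpy as np

def class_balance(labels, num_classes=10):
    """Return per-class counts, per-class fractions, and the max/min count ratio."""
    counts = np.bincount(np.asarray(labels), minlength=num_classes)
    fractions = counts / counts.sum()
    # A ratio near 1.0 indicates balance; it is infinite if any class is absent.
    ratio = counts.max() / counts.min() if counts.min() > 0 else float("inf")
    return counts, fractions, ratio

# Toy labels (not MNIST): class 1 appears three times, class 2 once.
toy_labels = [0, 0, 1, 1, 1, 2]
counts, fractions, ratio = class_balance(toy_labels, num_classes=3)
print(counts)  # [2 3 1]
print(ratio)   # 3.0
```

For the MNIST training labels, where the most and least frequent digits have 6,742 and 5,421 samples, this ratio comes out at about 1.24.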
Expected Output
Console Output: Shape prints confirming x_train = (60000, 28, 28), y_train = (60000,), x_test = (10000, 28, 28), and y_test = (10000,), followed by the Class Distribution dictionary.
Figure 1: A 3×3 grid of the first 9 grayscale training images with their labels shown as titles (in the standard Keras copy of MNIST these are 5, 0, 4, 1, 9, 2, 1, 3, 1). The digits appear as light strokes on a black background, without coordinate axes.
Figure 2: A bar chart showing roughly 5,400 to 6,700 samples per digit class, confirming a nearly balanced dataset with no significant class imbalance.
