TensorFlow Guide: Guided Project 1 - Image Classification with CNNs

7. Guided Project 1: Image Classification with CNNs

This project guides you through building a Convolutional Neural Network (CNN) to classify images from the CIFAR-10 dataset. CIFAR-10 consists of 60,000 32x32 color images in 10 classes (e.g., airplane, automobile, bird, cat), split into 50,000 training and 10,000 test images. Completing it will solidify your understanding of data pipelines, model building with Keras, and training strategies.

Project Objective

Build and train a CNN model capable of classifying CIFAR-10 images with reasonable accuracy; with the architecture and training setup below, test accuracy on the order of 70-80% is a realistic target.

Step 1: Set up the Environment and Load Data

First, we’ll import necessary libraries and load the CIFAR-10 dataset.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import numpy as np

# Load the CIFAR-10 dataset
# The dataset returns (train_images, train_labels) and (test_images, test_labels)
(train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()

# Define class names for visualization
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Display basic information about the dataset
print(f"Train images shape: {train_images.shape}")
print(f"Train labels shape: {train_labels.shape}")
print(f"Test images shape: {test_images.shape}")
print(f"Test labels shape: {test_labels.shape}\n")

# Display a few sample images
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i])
    # The labels are arrays, e.g., [6], so access the element with [0]
    plt.xlabel(class_names[train_labels[i][0]])
plt.suptitle("Sample CIFAR-10 Images")
plt.show()

Self-Check: Confirm that the shapes match the expected (50000, 32, 32, 3) for training images and (10000, 32, 32, 3) for test images, and that the sample images display correctly.
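
If you prefer a programmatic check over reading the printed shapes, here is a minimal sketch using the variable names defined above:

# Quick programmatic sanity check of the dataset shapes
assert train_images.shape == (50000, 32, 32, 3), "Unexpected training image shape"
assert test_images.shape == (10000, 32, 32, 3), "Unexpected test image shape"
assert train_labels.shape == (50000, 1) and test_labels.shape == (10000, 1)
print("Dataset shapes look correct.")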

Step 2: Preprocessing and Data Augmentation Pipeline with tf.data

We’ll normalize the pixel values and apply data augmentation to the training set to help reduce overfitting, using the tf.data API to build an efficient input pipeline.

Try It Yourself: Before looking at the solution, try to create a tf.data pipeline that:

  1. Normalizes pixel values from [0, 255] to [0, 1].
  2. Shuffles the training data.
  3. Applies random horizontal flipping and random brightness adjustments to training images.
  4. Batches the data.
  5. Prefetches data for optimized loading.

One possible implementation follows.

# Convert labels to one-hot encoding if using 'categorical_crossentropy' later,
# or keep as integers if using 'sparse_categorical_crossentropy'.
# For this project, we'll stick to 'sparse_categorical_crossentropy' for simplicity,
# so no need to one-hot encode labels at this stage.
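# For reference, the one-hot alternative (not used below) would be:
#   train_labels_onehot = keras.utils.to_categorical(train_labels, num_classes=10)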

# Normalize pixel values to be between 0 and 1
def normalize_img(image, label):
    return tf.cast(image, tf.float32) / 255.0, label

# Data augmentation for training images
def augment_img(image, label):
    # Random horizontal flip
    image = tf.image.random_flip_left_right(image)
    # Random brightness adjustment
    image = tf.image.random_brightness(image, max_delta=0.2)
    # Random contrast adjustment
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    return image, label

BATCH_SIZE = 64
AUTOTUNE = tf.data.AUTOTUNE

# Create tf.data.Dataset objects
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

# Apply preprocessing and augmentation for training data
train_dataset = train_dataset.map(normalize_img, num_parallel_calls=AUTOTUNE)
train_dataset = train_dataset.cache() # Cache data in memory after first pass
train_dataset = train_dataset.shuffle(buffer_size=50000) # Shuffle entire dataset
train_dataset = train_dataset.map(augment_img, num_parallel_calls=AUTOTUNE) # Augment after cache() so each epoch sees fresh random augmentations
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(AUTOTUNE)

# Apply only preprocessing for test data (no augmentation)
test_dataset = test_dataset.map(normalize_img, num_parallel_calls=AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.cache() # Cache test data too
test_dataset = test_dataset.prefetch(AUTOTUNE)

print("\nData pipelines created and optimized.")
print(f"Sample batch from training dataset (images shape): {next(iter(train_dataset))[0].shape}")
print(f"Sample batch from training dataset (labels shape): {next(iter(train_dataset))[1].shape}")

Self-Check: The sample batch images shape should be (64, 32, 32, 3) and labels shape (64, 1).
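
To verify the augmentation visually as well, you can pull one batch from the pipeline and plot a few images. A minimal sketch reusing the names defined above (random_brightness can push values slightly outside [0, 1], hence the clipping):

# Visualize a few augmented training images from a single batch
aug_images, aug_labels = next(iter(train_dataset))
plt.figure(figsize=(8, 4))
for i in range(8):
    plt.subplot(2, 4, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(np.clip(aug_images[i].numpy(), 0.0, 1.0))
    plt.xlabel(class_names[int(aug_labels[i][0])])
plt.suptitle("Augmented Training Samples")
plt.show()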

Step 3: Define the CNN Model Architecture

We’ll build a standard CNN architecture using the Keras Sequential API, consisting of convolutional layers, pooling layers, and dense layers.

Try It Yourself: Design a CNN architecture for image classification before reading on. Consider using:

  1. Conv2D layers with ‘relu’ activation.
  2. MaxPooling2D layers to reduce dimensionality.
  3. Flatten to convert 2D feature maps to 1D vectors.
  4. Dense layers for classification.
  5. A final Dense layer with softmax activation for 10 classes.
  6. Dropout layers for regularization.

One possible architecture follows.

# Define the CNN model
cnn_model = models.Sequential([
    # First Conv2D Block
    keras.Input(shape=(32, 32, 3)), # Explicit Input layer; passing input_shape to the first layer is deprecated in recent Keras versions
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25), # Added dropout after pooling

    # Second Conv2D Block
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),

    # Third Conv2D Block (optional, can experiment with depth)
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),

    # Flatten and Dense layers
    layers.Flatten(),
    layers.Dense(256, activation='relu'), # A larger Dense layer
    layers.BatchNormalization(),
    layers.Dropout(0.5), # Higher dropout before output layer
    layers.Dense(10, activation='softmax') # Output layer for 10 classes
])

# Display the model summary
cnn_model.summary()

Self-Check: Review the model summary. Pay attention to the output shapes of each layer and the total number of parameters.
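
As a sanity check on the summary, you can compute a Conv2D layer's parameter count by hand: kernel height x kernel width x input channels x filters, plus one bias per filter. A minimal sketch:

# Verify Conv2D parameter counts from the summary by hand
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    return kernel_h * kernel_w * in_channels * filters + filters  # weights + biases

print(conv2d_params(3, 3, 3, 32))    # first block: 896
print(conv2d_params(3, 3, 32, 64))   # second block: 18496
print(conv2d_params(3, 3, 64, 128))  # third block: 73856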

Step 4: Compile and Train the Model

Now, we’ll compile the model, define callbacks for better training management, and then train it using the prepared tf.data pipeline. For simplicity, we pass the test set as validation data during training; in a real project you would hold out a separate validation split so the test set stays untouched until the final evaluation.

# Compile the model
cnn_model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy', # Use sparse_categorical_crossentropy for integer labels
                  metrics=['accuracy'])
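
# (Optional) For finer control over training, you can pass an optimizer instance
# instead of the 'adam' string; the default Adam learning rate is 1e-3:
#   cnn_model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),
#                     loss='sparse_categorical_crossentropy',
#                     metrics=['accuracy'])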

# Define Callbacks
# EarlyStopping: Stop training if validation accuracy doesn't improve for 5 epochs
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor='val_accuracy', patience=5, verbose=1, mode='max', restore_best_weights=True
)

# ModelCheckpoint: Save the best model weights based on validation accuracy
checkpoint_path = "cnn_cifar10_best_model.keras" # New Keras format
model_checkpoint_cb = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=False, # Save the entire model
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1
)

# ReduceLROnPlateau: Reduce learning rate if val_loss plateaus
reduce_lr_cb = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=3, min_lr=0.00001, verbose=1
)

callbacks_list = [early_stopping_cb, model_checkpoint_cb, reduce_lr_cb]

# Train the model
EPOCHS = 50 # Upper bound; EarlyStopping will usually stop training sooner

print("\nStarting CNN model training...\n")
history = cnn_model.fit(train_dataset,
                        epochs=EPOCHS,
                        validation_data=test_dataset,
                        callbacks=callbacks_list,
                        verbose=1)

print("\nCNN model training finished!")

Self-Check: Observe the training logs. Note how the validation accuracy changes and if any callbacks are triggered (e.g., learning rate reduction or early stopping).
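
To confirm that ReduceLROnPlateau actually fired, you can inspect the per-epoch learning rate recorded in the training history. The key name varies across Keras versions ('lr' in older releases, 'learning_rate' in newer ones), so a defensive lookup is safest:

# Inspect the learning-rate schedule recorded during training
lr_key = 'learning_rate' if 'learning_rate' in history.history else 'lr'
if lr_key in history.history:
    for epoch, lr in enumerate(history.history[lr_key]):
        print(f"Epoch {epoch + 1}: learning rate = {float(lr):.6f}")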

Step 5: Evaluate the Model and Make Predictions

After training, we’ll evaluate the model’s performance on the test set and make some predictions.

# Evaluate the model on the unseen test data
print("\nEvaluating model on test data...")
test_loss, test_acc = cnn_model.evaluate(test_dataset, verbose=2)
print(f"\nTest accuracy: {test_acc:.4f}")
print(f"Test loss: {test_loss:.4f}")

# Plot training history
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.ylim([0, 1])

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Make predictions on a few test images
num_prediction_samples = 10
sample_images, sample_labels = next(iter(test_dataset)) # Get one batch from the test set
sample_images = sample_images[:num_prediction_samples]
sample_labels = sample_labels[:num_prediction_samples]

predictions = cnn_model.predict(sample_images)
predicted_classes = np.argmax(predictions, axis=1)
true_classes = sample_labels.numpy().flatten() # Flatten the (N,1) labels to (N,)

print(f"\nPredictions for {num_prediction_samples} test images:")
for i in range(num_prediction_samples):
    print(f"  Image {i}: True: {class_names[true_classes[i]]} ({true_classes[i]}), "
          f"Predicted: {class_names[predicted_classes[i]]} ({predicted_classes[i]})")

# Visualize some predictions
plt.figure(figsize=(15, 6))
for i in range(num_prediction_samples):
    plt.subplot(2, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(sample_images[i].numpy()) # Plot the normalized image
    color = 'green' if predicted_classes[i] == true_classes[i] else 'red'
    plt.xlabel(f"True: {class_names[true_classes[i]]}\nPred: {class_names[predicted_classes[i]]}", color=color)
plt.suptitle("Sample Predictions (Green=Correct, Red=Incorrect)", fontsize=16)
plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent suptitle overlap
plt.show()

# Load the best saved model (ModelCheckpoint saved the full model, not just weights)
try:
    best_model = keras.models.load_model(checkpoint_path)
    print(f"\nLoaded best model from: {checkpoint_path}")
    test_loss_best, test_acc_best = best_model.evaluate(test_dataset, verbose=2)
    print(f"Test accuracy (best saved model): {test_acc_best:.4f}")
except Exception as e:
    print(f"Could not load best model (perhaps no model was saved yet or path is incorrect): {e}")

Congratulations! You’ve completed your first guided project, building and training a CNN for image classification. This project integrated concepts from tf.data pipelines, Keras model building, and training with callbacks. You’ve also gained experience in evaluating and visualizing model performance.