TensorFlow Guide: Intermediate Topics - Custom Training Loops and Callbacks

5. Intermediate Topics

While model.fit() is incredibly convenient, sometimes you need more control over the training process. This chapter introduces two powerful intermediate topics: Custom Training Loops for ultimate flexibility and Keras Callbacks for customizing model.fit() behavior.

5.1 Custom Training Loops with tf.GradientTape

A custom training loop gives you full control over every aspect of the training process, from calculating gradients to updating model weights. This is particularly useful for:

  • Implementing custom loss functions.
  • Building generative adversarial networks (GANs) or other multi-model architectures.
  • Applying complex regularization or gradient manipulation techniques.
  • Debugging complex training dynamics.

The core component of a custom training loop is tf.GradientTape. It records operations for automatic differentiation. When you execute operations inside a tf.GradientTape context, the tape “watches” the variables (your model’s trainable weights) and operations applied to them. After the forward pass, you can use the tape to compute the gradients of a loss with respect to these watched variables.
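The following minimal sketch shows this watching behavior on its own, outside any model. It assumes only a TensorFlow 2.x installation, and the variable names are illustrative. Trainable tf.Variable objects are watched automatically. A plain tensor such as a tf.constant has to be registered with tape.watch(); otherwise tape.gradient() returns None for it.

```python
import tensorflow as tf

# 1. Trainable tf.Variable objects are watched automatically.
w = tf.Variable(3.0)
with tf.GradientTape() as tape:
    y = w * w                      # y = w^2
dy_dw = tape.gradient(y, w)        # dy/dw = 2w = 6.0

# 2. Plain tensors (e.g. tf.constant) are NOT watched by default,
#    so tape.watch() must be called explicitly.
x = tf.constant(2.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    z = x * x * x                  # z = x^3
dz_dx = tape.gradient(z, x)        # dz/dx = 3x^2 = 12.0

# 3. Without tape.watch(), the gradient with respect to a constant is None.
c = tf.constant(2.0)
with tf.GradientTape() as tape:
    q = c * c
dq_dc = tape.gradient(q, c)        # None

print(float(dy_dw), float(dz_dx), dq_dc)
```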

Basic Custom Training Loop Structure

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# 1. Prepare Data (using a simple synthetic dataset for clarity)
X = tf.constant(np.random.rand(100, 2), dtype=tf.float32) # 100 samples, 2 features
y_true = tf.constant(X[:, 0] * 3 + X[:, 1] * 2 + 5 + np.random.rand(100) * 0.1, dtype=tf.float32) # Simple linear relationship with noise
y_true = tf.reshape(y_true, (-1, 1)) # Reshape to (100, 1)

# Create a tf.data dataset
BATCH_SIZE = 16
train_dataset = tf.data.Dataset.from_tensor_slices((X, y_true)).shuffle(100).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# 2. Define the Model (Simple Dense network)
model = keras.Sequential([
    layers.Dense(units=1, name='output_layer') # For linear regression, 1 output neuron, no activation
])

# 3. Define Loss Function and Optimizer
loss_fn = keras.losses.MeanSquaredError() # MSE for regression
optimizer = keras.optimizers.Adam(learning_rate=0.01)

# 4. Define Metrics (Optional, but good practice)
train_mae_metric = keras.metrics.MeanAbsoluteError()

# 5. The Training Step Function
# We use @tf.function to compile this into a graph for performance
@tf.function
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        # Forward pass: Compute predictions
        predictions = model(inputs, training=True) # `training=True` for correct batch normalization/dropout behavior
        # Compute the loss
        loss = loss_fn(targets, predictions)

    # Compute gradients
    gradients = tape.gradient(loss, model.trainable_variables)

    # Update weights
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Update metrics
    train_mae_metric.update_state(targets, predictions)

    return loss

# 6. The Full Training Loop
EPOCHS = 20

print("Starting Custom Training Loop...\n")
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    total_loss = 0.0
    num_batches = 0
    train_mae_metric.reset_state() # Reset metrics at the start of each epoch

    for step, (batch_X, batch_y) in enumerate(train_dataset):
        batch_loss = train_step(batch_X, batch_y)
        total_loss += batch_loss
        num_batches += 1

        if step % 10 == 0:
            print(f"  Batch {step}: Loss = {batch_loss:.4f}, MAE = {train_mae_metric.result():.4f}")

    avg_loss_epoch = total_loss / num_batches
    print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss_epoch:.4f}, Average MAE: {train_mae_metric.result():.4f}")

print("\nCustom Training Loop Finished!")

# Make a prediction with the trained model
sample_input = tf.constant([[1.0, 2.0], [5.0, 1.0]], dtype=tf.float32)
sample_prediction = model.predict(sample_input, verbose=0) # predict() already returns a NumPy array
print(f"\nPrediction for [[1.0, 2.0], [5.0, 1.0]]:\n{sample_prediction}")

In this example, tf.GradientTape is the core mechanism that enables automatic differentiation. It records operations performed during the forward pass and then uses this record to compute gradients during the backward pass.
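A common reason to write a custom loop is to modify gradients between tape.gradient() and apply_gradients(). The sketch below is a hypothetical variation of the train_step above that applies gradient clipping with tf.clip_by_global_norm. The threshold CLIP_NORM = 1.0 and the name train_step_clipped are illustrative assumptions, not values from this guide, and the block rebuilds a small model and dataset so it runs on its own.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Synthetic regression data with the same linear relationship as the example above
X_np = np.random.rand(100, 2).astype(np.float32)
y_np = (3 * X_np[:, 0] + 2 * X_np[:, 1] + 5).astype(np.float32).reshape(-1, 1)
dataset = tf.data.Dataset.from_tensor_slices((X_np, y_np)).shuffle(100).batch(16)

model = keras.Sequential([keras.Input(shape=(2,)), layers.Dense(1)])
loss_fn = keras.losses.MeanSquaredError()
optimizer = keras.optimizers.Adam(learning_rate=0.01)
CLIP_NORM = 1.0  # hypothetical threshold; tune it for your own problem

@tf.function
def train_step_clipped(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(targets, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    # Rescale all gradients together so their combined L2 norm is at most CLIP_NORM.
    # global_norm is the norm *before* clipping, useful for monitoring.
    clipped, global_norm = tf.clip_by_global_norm(gradients, CLIP_NORM)
    optimizer.apply_gradients(zip(clipped, model.trainable_variables))
    return loss, global_norm

for epoch in range(5):
    for batch_X, batch_y in dataset:
        loss, norm = train_step_clipped(batch_X, batch_y)

print(f"Last batch loss = {float(loss):.4f}, pre-clip gradient norm = {float(norm):.4f}")
```

Clipping by global norm multiplies every gradient by the same factor, so the overall update direction is preserved. tf.clip_by_value, which clips each element independently, is a simpler but cruder alternative.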

Exercise 5.1: Custom Classification Training Loop

  1. Objective: Implement a custom training loop for a binary classification problem using tf.GradientTape.
  2. Instructions:
    • Generate Synthetic Data: Create a binary classification dataset.
      • X = np.random.rand(1000, 2).astype(np.float32) (1000 samples, 2 features)
      • y = (X[:, 0] + X[:, 1] > 1.0).astype(np.int32) (a simple decision boundary, 0s and 1s)
      • Reshape y to (1000, 1).
    • Data Pipeline: Create a tf.data.Dataset from X and y. Shuffle, batch (e.g., 32), and prefetch.
    • Model: Define a keras.Sequential model with two Dense hidden layers (e.g., 32 units, ‘relu’) and an output Dense layer with 1 unit and sigmoid activation (for binary classification).
    • Loss and Optimizer: Use tf.keras.losses.BinaryCrossentropy() and tf.keras.optimizers.Adam(learning_rate=0.001).
    • Metrics: Use tf.keras.metrics.BinaryAccuracy() and tf.keras.metrics.Precision().
    • train_step Function: Implement a @tf.function-decorated train_step(inputs, targets) that:
      • Performs the forward pass (model(inputs, training=True)).
      • Calculates the BinaryCrossentropy loss.
      • Computes gradients using tf.GradientTape.
      • Applies gradients using the optimizer.
      • Updates the BinaryAccuracy and Precision metrics.
    • Training Loop: Run the training for a few epochs (e.g., 10-20). Print epoch-level average loss, accuracy, and precision.
  3. Expected Output: Printed metrics for each epoch showing the model learning and improving, and a final accuracy/precision score.
# Your solution for Exercise 5.1 here
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# 1. Generate Synthetic Data for Binary Classification
NUM_SAMPLES = 1000
X_data = np.random.rand(NUM_SAMPLES, 2).astype(np.float32)
# Simple decision boundary: if sum of features > 1.0, class is 1, else 0
y_data = (X_data[:, 0] + X_data[:, 1] > 1.0).astype(np.int32)
y_data = y_data.reshape(-1, 1) # Reshape labels to (NUM_SAMPLES, 1)

print(f"X_data shape: {X_data.shape}, y_data shape: {y_data.shape}\n")

# 2. Data Pipeline
BATCH_SIZE = 32
train_dataset_classification = tf.data.Dataset.from_tensor_slices((X_data, y_data))\
                                .shuffle(NUM_SAMPLES)\
                                .batch(BATCH_SIZE)\
                                .prefetch(tf.data.AUTOTUNE)

# 3. Define the Model for Binary Classification
classification_model = keras.Sequential([
    layers.Dense(units=32, activation='relu', input_shape=(2,)),
    layers.Dense(units=32, activation='relu'),
    layers.Dense(units=1, activation='sigmoid') # Sigmoid for binary classification output (probability 0-1)
])

classification_model.summary()

# 4. Define Loss Function, Optimizer, and Metrics
loss_fn_cls = keras.losses.BinaryCrossentropy()
optimizer_cls = keras.optimizers.Adam(learning_rate=0.001)

train_accuracy_metric = keras.metrics.BinaryAccuracy()
train_precision_metric = keras.metrics.Precision()

# 5. The Training Step Function
@tf.function
def train_step_classification(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = classification_model(inputs, training=True)
        loss = loss_fn_cls(targets, predictions)

    gradients = tape.gradient(loss, classification_model.trainable_variables)
    optimizer_cls.apply_gradients(zip(gradients, classification_model.trainable_variables))

    train_accuracy_metric.update_state(targets, predictions)
    train_precision_metric.update_state(targets, predictions)

    return loss

# 6. The Full Training Loop
EPOCHS_CLS = 20

print("\nStarting Custom Classification Training Loop...\n")
for epoch in range(EPOCHS_CLS):
    print(f"\nEpoch {epoch+1}/{EPOCHS_CLS}")
    total_loss_cls = 0.0
    num_batches_cls = 0
    train_accuracy_metric.reset_state() # Reset metrics at the start of each epoch
    train_precision_metric.reset_state()

    for step, (batch_X, batch_y) in enumerate(train_dataset_classification):
        batch_loss_cls = train_step_classification(batch_X, batch_y)
        total_loss_cls += batch_loss_cls
        num_batches_cls += 1

        if step % 10 == 0:
            print(f"  Batch {step}: Loss = {batch_loss_cls:.4f}, Acc = {train_accuracy_metric.result():.4f}, Prec = {train_precision_metric.result():.4f}")

    avg_loss_epoch_cls = total_loss_cls / num_batches_cls
    print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss_epoch_cls:.4f}, Average Accuracy: {train_accuracy_metric.result():.4f}, Average Precision: {train_precision_metric.result():.4f}")

print("\nCustom Classification Training Loop Finished!")

# Make a prediction with the trained model
sample_input_cls = tf.constant([[0.1, 0.2], [0.9, 0.8], [0.5, 0.5]], dtype=tf.float32)
sample_prediction_cls = classification_model.predict(sample_input_cls, verbose=0) # predict() already returns a NumPy array
print(f"\nPrediction probabilities for [[0.1, 0.2], [0.9, 0.8], [0.5, 0.5]]:\n{sample_prediction_cls}")
print(f"Predicted classes: {(sample_prediction_cls > 0.5).astype(int)}")
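The loop above only reports training metrics. A minimal sketch of a matching evaluation step follows, assuming you hold out part of the same synthetic distribution as a validation set; the 1000/200 split and the names val_step and val_acc are illustrative. The key differences from the training step are training=False, no GradientTape, and no weight update.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical train/validation split drawn from the same synthetic distribution
X_all = np.random.rand(1200, 2).astype(np.float32)
y_all = (X_all[:, 0] + X_all[:, 1] > 1.0).astype(np.float32).reshape(-1, 1)
X_train, y_train = X_all[:1000], y_all[:1000]
X_val, y_val = X_all[1000:], y_all[1000:]

train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(1000).batch(32)
val_ds = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(32)

model = keras.Sequential([
    keras.Input(shape=(2,)),
    layers.Dense(32, activation="relu"),
    layers.Dense(1, activation="sigmoid"),
])
loss_fn = keras.losses.BinaryCrossentropy()
optimizer = keras.optimizers.Adam(learning_rate=0.01)
val_acc = keras.metrics.BinaryAccuracy()

@tf.function
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        loss = loss_fn(targets, model(inputs, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

@tf.function
def val_step(inputs, targets):
    # training=False selects inference behavior (e.g. Dropout disabled);
    # no tape and no optimizer call, so the weights are left unchanged.
    predictions = model(inputs, training=False)
    val_acc.update_state(targets, predictions)

for epoch in range(10):
    for x, y in train_ds:
        train_step(x, y)
    val_acc.reset_state()  # start each validation pass from a clean metric
    for x, y in val_ds:
        val_step(x, y)
    print(f"Epoch {epoch + 1}: validation accuracy = {float(val_acc.result()):.4f}")
```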

5.2 Keras Callbacks

Keras Callbacks are objects you pass to model.fit() (and also to model.evaluate() and model.predict()) to customize their behavior. They can perform actions at specific stages of training, for example at the beginning or end of an epoch, or before or after each batch.
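To see when these hooks fire, here is a minimal sketch of a custom callback that simply counts calls to three methods defined on keras.callbacks.Callback: on_train_begin, on_epoch_end, and on_train_batch_end. The tiny random dataset and the class name HookCounter are hypothetical, chosen only so the example runs quickly.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

class HookCounter(keras.callbacks.Callback):
    """Counts how often fit() invokes a few callback hooks."""

    def __init__(self):
        super().__init__()
        self.counts = {"train_begin": 0, "epoch_end": 0, "train_batch_end": 0}

    def on_train_begin(self, logs=None):
        self.counts["train_begin"] += 1

    def on_epoch_end(self, epoch, logs=None):
        self.counts["epoch_end"] += 1
        # logs holds this epoch's aggregated metrics, e.g. {'loss': ...}
        print(f"epoch {epoch + 1}: loss = {float(logs['loss']):.4f}")

    def on_train_batch_end(self, batch, logs=None):
        self.counts["train_batch_end"] += 1

X = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

counter = HookCounter()
model.fit(X, y, epochs=2, batch_size=16, verbose=0, callbacks=[counter])
print(counter.counts)
```

With 64 samples and batch_size=16 there are 4 batches per epoch, so after 2 epochs the counter should record 8 batch-end calls, 2 epoch-end calls, and 1 train-begin call.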

Commonly used callbacks include:

  • ModelCheckpoint: Save the model (or just its weights) at some frequency during training.
  • EarlyStopping: Stop training when a monitored metric has stopped improving.
  • ReduceLROnPlateau: Reduce the learning rate when a metric has stopped improving.
  • TensorBoard: Visualize training metrics, model graphs, and more in TensorBoard.
  • CSVLogger: Stream epoch results to a CSV file.

Example: Using Multiple Callbacks

Let’s modify our MNIST classification example to use several important callbacks.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import datetime
import matplotlib.pyplot as plt

# 1. Load and preprocess the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0
train_images = train_images.reshape((60000, 784))
test_images = test_images.reshape((10000, 784))

# 2. Define the model
model_callbacks = keras.Sequential([
    layers.Dense(units=128, activation='relu', input_shape=(784,)),
    layers.Dropout(0.2), # Adding dropout for regularization
    layers.Dense(units=64, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(units=10, activation='softmax')
])

# 3. Compile the model
model_callbacks.compile(optimizer='adam',
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])

# 4. Define Callbacks

# TensorBoard Callback: Logs events for TensorBoard visualization
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# EarlyStopping Callback: Stop training if validation accuracy doesn't improve for 3 epochs
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_accuracy', # Metric to monitor
    patience=3,             # Number of epochs with no improvement after which training will be stopped
    verbose=1,
    mode='max',             # 'max' because we want to maximize accuracy
    restore_best_weights=True # Restore model weights from the epoch with the best value of the monitored metric.
)

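# ModelCheckpoint Callback: Save best model weights
# With save_weights_only=True, Keras 3 requires the filename to end in '.weights.h5'
tf.io.gfile.makedirs('best_model_weights') # Create the output directory if it does not exist
checkpoint_filepath = 'best_model_weights/mnist_classifier_epoch_{epoch:02d}-val_accuracy_{val_accuracy:.2f}.weights.h5'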
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True, # Only save model weights, not the whole model
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,    # Only save when val_accuracy improves
    verbose=1
)

# ReduceLROnPlateau Callback: Reduce learning rate when a metric has stopped improving
reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', # Monitor validation loss
    factor=0.2,         # Reduce learning rate by a factor of 0.2 (i.e., new_lr = lr * 0.2)
    patience=2,         # Number of epochs with no improvement after which learning rate will be reduced.
    min_lr=0.00001,     # Minimum learning rate
    verbose=1
)

# Combine all callbacks into a list
my_callbacks = [
    tensorboard_callback,
    early_stopping_callback,
    model_checkpoint_callback,
    reduce_lr_callback
]

# 5. Train the model with callbacks
print("\nStarting training with Callbacks...\n")
history_callbacks = model_callbacks.fit(train_images, train_labels,
                                        epochs=20, # Set a higher number of epochs, EarlyStopping will stop it
                                        batch_size=32,
                                        validation_split=0.1,
                                        callbacks=my_callbacks)

print("\nTraining with Callbacks finished!")

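# 6. Evaluate on the test set
# Note: depending on your Keras version, restore_best_weights may only take effect
# if EarlyStopping actually stops training before the last epoch.
test_loss_cb, test_acc_cb = model_callbacks.evaluate(test_images, test_labels, verbose=2)
print(f"\nTest accuracy: {test_acc_cb:.4f}")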

# To view TensorBoard logs, open your terminal in the 'tensorflow_projects' directory
# and run: tensorboard --logdir logs/fit
# Then navigate to the URL provided by TensorBoard (usually http://localhost:6006/)

After running this code, open a new terminal, navigate to the directory you ran the script from (e.g., your tensorflow_projects directory, since log_dir is a relative path), and run tensorboard --logdir logs/fit to visualize the training process. You'll see graphs for loss, accuracy, and the other logged metrics; because histogram_freq=1 is set here, TensorBoard also shows histograms of the layer weights for every epoch.
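TensorBoard is not limited to model.fit(): it reads any event files written under the --logdir directory. A minimal sketch, assuming you want to log your own values (for example, the epoch losses from the custom loop in Section 5.1), uses tf.summary.create_file_writer and tf.summary.scalar. The directory and the decreasing placeholder values below are illustrative only.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical log directory; point `tensorboard --logdir` at its parent to view it
log_dir = os.path.join(tempfile.mkdtemp(), "custom_scalars")
writer = tf.summary.create_file_writer(log_dir)

with writer.as_default():
    for step in range(5):
        # Placeholder value; in practice this could be avg_loss_epoch from a custom loop
        fake_loss = 1.0 / (step + 1)
        tf.summary.scalar("custom_loss", fake_loss, step=step)
writer.flush()  # make sure the events are written to disk

print(os.listdir(log_dir))
```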

Exercise 5.2: Custom Callback for Visualizing Predictions

  1. Objective: Create a custom Keras callback to visualize model predictions on a fixed set of validation images at the end of each epoch.
  2. Instructions:
    • Data Prep: Use the MNIST dataset (already loaded in the example above). Extract a small fixed subset of test images and labels (e.g., 5-10 images) to monitor.
    • Custom Callback Class: Create a class ImagePredictionCallback that inherits from keras.callbacks.Callback.
      • The constructor __init__ should take the fixed validation images and their true labels.
      • Override the on_epoch_end(self, epoch, logs=None) method.
      • Inside on_epoch_end:
        • Use self.model.predict() on your fixed validation images.
        • Convert predictions to class labels (e.g., np.argmax(predictions, axis=1)).
        • Print the epoch number, true labels, and predicted labels for your fixed subset.
        • (Bonus/Advanced): To visualize the results, use matplotlib.pyplot to plot the images with their predicted/true labels and save the plot to a file, or log them to TensorBoard (this requires a writer created with tf.summary.create_file_writer plus tf.summary.image, which is more advanced for a basic custom callback). For this exercise, printing is sufficient.
    • Integrate and Train:
      • Define and compile the same MNIST model as in the example.
      • Create an instance of your ImagePredictionCallback.
      • Train the model using model.fit(), including your custom callback in the callbacks list.
  3. Expected Output: At the end of each epoch, the console should print the true and predicted labels for your chosen subset of validation images, showing how the model’s predictions improve over epochs.
# Your solution for Exercise 5.2 here
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

# 1. Data Prep: Load MNIST and extract a small fixed subset for monitoring
(train_images_cb, train_labels_cb), (test_images_cb, test_labels_cb) = keras.datasets.mnist.load_data()
train_images_cb = train_images_cb.astype('float32') / 255.0
test_images_cb = test_images_cb.astype('float32') / 255.0
train_images_cb = train_images_cb.reshape((60000, 784))
test_images_cb = test_images_cb.reshape((10000, 784))

# Take a small fixed subset (e.g., first 5) of test images for consistent monitoring
monitor_images = test_images_cb[:5]
monitor_labels = test_labels_cb[:5]

# 2. Custom Callback Class
class ImagePredictionCallback(keras.callbacks.Callback):
    def __init__(self, images, labels):
        super().__init__()
        self.images = images
        self.labels = labels

    def on_epoch_end(self, epoch, logs=None):
        predictions = self.model.predict(self.images, verbose=0) # Predict probabilities
        predicted_classes = np.argmax(predictions, axis=1) # Get the class with highest probability

        print(f"\n--- Epoch {epoch+1} Predictions ---")
        print(f"True labels:      {self.labels}")
        print(f"Predicted labels: {predicted_classes}")
        print("-----------------------------------\n")

        # Bonus: Visualize a single image's prediction (simple text output for brevity)
        if epoch % 5 == 0 or epoch == self.params['epochs'] - 1: # Visualize less frequently or on last epoch
            print(f"Sample prediction for image 0: True={self.labels[0]}, Pred={predicted_classes[0]}")
            # To actually plot:
            # plt.figure(figsize=(2,2))
            # plt.imshow(self.images[0].reshape(28,28), cmap='gray')
            # plt.title(f"Epoch {epoch+1}: True {self.labels[0]}, Pred {predicted_classes[0]}")
            # plt.axis('off')
            # plt.savefig(f'epoch_{epoch+1}_pred_sample_0.png')
            # plt.close()


# 3. Define and Compile the Model
model_custom_cb = keras.Sequential([
    layers.Dense(units=128, activation='relu', input_shape=(784,)),
    layers.Dropout(0.2),
    layers.Dense(units=64, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(units=10, activation='softmax')
])

model_custom_cb.compile(optimizer='adam',
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])

# Create an instance of the custom callback
prediction_callback = ImagePredictionCallback(monitor_images, monitor_labels)

# 4. Train the model with the custom callback
print("\nStarting training with Custom Image Prediction Callback...\n")
history_custom_cb = model_custom_cb.fit(train_images_cb, train_labels_cb,
                                        epochs=10, # Run for fewer epochs for faster demonstration
                                        batch_size=32,
                                        validation_split=0.1,
                                        callbacks=[prediction_callback])

print("\nTraining with Custom Callback finished!")

# Evaluate the model
test_loss_final, test_acc_final = model_custom_cb.evaluate(test_images_cb, test_labels_cb, verbose=2)
print(f"\nFinal Test Accuracy: {test_acc_final:.4f}")
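For the bonus part of Exercise 5.2, one possible approach is to write the monitored images to TensorBoard instead of printing labels. The sketch below is a hypothetical TensorBoardImageCallback built on tf.summary.create_file_writer, tf.summary.image, and tf.summary.scalar. It uses random stand-in data instead of MNIST so it runs without a download, and the log directory is an arbitrary temporary path; to use it with real digits, pass monitor_images and monitor_labels from the solution above.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class TensorBoardImageCallback(keras.callbacks.Callback):
    """Logs monitored images and their prediction accuracy to TensorBoard each epoch."""

    def __init__(self, images, labels, log_dir):
        super().__init__()
        self.images = images  # flattened 28x28 images, float32 in [0, 1]
        self.labels = labels
        self.writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        predicted = np.argmax(self.model.predict(self.images, verbose=0), axis=1)
        # tf.summary.image expects a 4-D batch: [k, height, width, channels]
        images_4d = self.images.reshape(-1, 28, 28, 1)
        with self.writer.as_default():
            tf.summary.image("monitor_images", images_4d, step=epoch,
                             max_outputs=len(images_4d))
            tf.summary.scalar("monitor_accuracy",
                              float(np.mean(predicted == self.labels)), step=epoch)
        self.writer.flush()


# Random stand-in data (NOT real digits) so the sketch runs without downloading MNIST
rng = np.random.default_rng(0)
x_demo = rng.random((256, 784), dtype=np.float32)
y_demo = rng.integers(0, 10, size=256)

demo_model = keras.Sequential([
    keras.Input(shape=(784,)),
    layers.Dense(32, activation="relu"),
    layers.Dense(10, activation="softmax"),
])
demo_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
                   metrics=["accuracy"])

image_log_dir = os.path.join(tempfile.mkdtemp(), "images")  # hypothetical path
image_callback = TensorBoardImageCallback(x_demo[:5], y_demo[:5], image_log_dir)
demo_model.fit(x_demo, y_demo, epochs=2, batch_size=32, verbose=0,
               callbacks=[image_callback])
print(os.listdir(image_log_dir))
```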

You’ve now explored advanced ways to control and monitor your TensorFlow models, giving you the power to tackle more complex machine learning problems with greater insight and efficiency. In the next chapter, we’ll dive into advanced topics like model distribution and deployment.