5. Intermediate Topics
While model.fit() is incredibly convenient, sometimes you need more control over the training process. This chapter introduces two powerful intermediate topics: Custom Training Loops for ultimate flexibility and Keras Callbacks for customizing model.fit() behavior.
5.1 Custom Training Loops with tf.GradientTape
A custom training loop gives you full control over every aspect of the training process, from calculating gradients to updating model weights. This is particularly useful for:
- Implementing custom loss functions.
- Building generative adversarial networks (GANs) or other multi-model architectures.
- Applying complex regularization or gradient manipulation techniques.
- Debugging complex training dynamics.
The core component of a custom training loop is tf.GradientTape. It records operations for automatic differentiation. When you execute operations inside a tf.GradientTape context, the tape “watches” the variables (your model’s trainable weights) and operations applied to them. After the forward pass, you can use the tape to compute the gradients of a loss with respect to these watched variables.
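Before assembling a full training loop, it helps to see the tape in isolation. The snippet below is a minimal sketch (assuming TensorFlow 2.x eager execution; the variable and input values are arbitrary):
import tensorflow as tf
# Minimal GradientTape sketch: trainable tf.Variables are watched automatically.
w = tf.Variable(3.0)
x = tf.constant(2.0)
with tf.GradientTape() as tape:
    y = w * x + w ** 2  # forward pass, recorded by the tape
dy_dw = tape.gradient(y, w)  # dy/dw = x + 2w = 2.0 + 6.0 = 8.0
print(dy_dw.numpy())  # 8.0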
Basic Custom Training Loop Structure
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
# 1. Prepare Data (using a simple synthetic dataset for clarity)
X = tf.constant(np.random.rand(100, 2), dtype=tf.float32) # 100 samples, 2 features
y_true = tf.constant(X[:, 0] * 3 + X[:, 1] * 2 + 5 + np.random.rand(100).astype(np.float32) * 0.1, dtype=tf.float32) # Simple linear relationship with noise (noise cast to float32 to match X)
y_true = tf.reshape(y_true, (-1, 1)) # Reshape to (100, 1)
# Create a tf.data dataset
BATCH_SIZE = 16
train_dataset = tf.data.Dataset.from_tensor_slices((X, y_true)).shuffle(100).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
# 2. Define the Model (Simple Dense network)
model = keras.Sequential([
    layers.Dense(units=1, name='output_layer')  # For linear regression, 1 output neuron, no activation
])
# 3. Define Loss Function and Optimizer
loss_fn = keras.losses.MeanSquaredError() # MSE for regression
optimizer = keras.optimizers.Adam(learning_rate=0.01)
# 4. Define Metrics (Optional, but good practice)
train_mae_metric = keras.metrics.MeanAbsoluteError()
# 5. The Training Step Function
# We use @tf.function to compile this into a graph for performance
@tf.function
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        # Forward pass: Compute predictions
        predictions = model(inputs, training=True)  # `training=True` for correct batch normalization/dropout behavior
        # Compute the loss
        loss = loss_fn(targets, predictions)
    # Compute gradients
    gradients = tape.gradient(loss, model.trainable_variables)
    # Update weights
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Update metrics
    train_mae_metric.update_state(targets, predictions)
    return loss
# 6. The Full Training Loop
EPOCHS = 20
print("Starting Custom Training Loop...\n")
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    total_loss = 0.0
    num_batches = 0
    train_mae_metric.reset_states()  # Reset metrics at the start of each epoch
    for step, (batch_X, batch_y) in enumerate(train_dataset):
        batch_loss = train_step(batch_X, batch_y)
        total_loss += float(batch_loss)  # convert the scalar loss tensor to a Python float
        num_batches += 1
        if step % 10 == 0:
            print(f" Batch {step}: Loss = {float(batch_loss):.4f}, MAE = {float(train_mae_metric.result()):.4f}")
    avg_loss_epoch = total_loss / num_batches
    print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss_epoch:.4f}, Average MAE: {float(train_mae_metric.result()):.4f}")
print("\nCustom Training Loop Finished!")
# Make a prediction with the trained model
sample_input = tf.constant([[1.0, 2.0], [5.0, 1.0]], dtype=tf.float32)
sample_prediction = model.predict(sample_input)  # predict() returns a NumPy array
print(f"\nPrediction for [[1.0, 2.0], [5.0, 1.0]]:\n{sample_prediction}")
In this example, tf.GradientTape is the core mechanism that enables automatic differentiation. It records operations performed during the forward pass and then uses this record to compute gradients during the backward pass.
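The same structure also accommodates the gradient-manipulation techniques mentioned at the start of this section: any transformation can be applied between tape.gradient() and apply_gradients(). Below is a hedged sketch, not part of the original example, that clips gradients by their global norm (it reuses model, loss_fn, and optimizer from above; clip_norm=1.0 is an arbitrary illustrative value):
# Sketch: a train_step variant with global-norm gradient clipping.
@tf.function
def train_step_clipped(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(targets, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    clipped_gradients, _ = tf.clip_by_global_norm(gradients, clip_norm=1.0)  # cap the combined gradient norm (illustrative value)
    optimizer.apply_gradients(zip(clipped_gradients, model.trainable_variables))
    return loss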
Exercise 5.1: Custom Classification Training Loop
- Objective: Implement a custom training loop for a binary classification problem using tf.GradientTape.
- Instructions:
  - Generate Synthetic Data: Create a binary classification dataset with X = np.random.rand(1000, 2).astype(np.float32) (1000 samples, 2 features) and y = (X[:, 0] + X[:, 1] > 1.0).astype(np.int32) (a simple decision boundary, 0s and 1s). Reshape y to (1000, 1).
  - Data Pipeline: Create a tf.data.Dataset from X and y. Shuffle, batch (e.g., 32), and prefetch.
  - Model: Define a keras.Sequential model with two Dense hidden layers (e.g., 32 units, 'relu') and an output Dense layer with 1 unit and sigmoid activation (for binary classification).
  - Loss and Optimizer: Use tf.keras.losses.BinaryCrossentropy() and tf.keras.optimizers.Adam(learning_rate=0.001).
  - Metrics: Use tf.keras.metrics.BinaryAccuracy() and tf.keras.metrics.Precision().
  - train_step Function: Implement a @tf.function-decorated train_step(inputs, targets) that:
    - Performs the forward pass (model(inputs, training=True)).
    - Calculates the BinaryCrossentropy loss.
    - Computes gradients using tf.GradientTape.
    - Applies gradients using the optimizer.
    - Updates the BinaryAccuracy and Precision metrics.
  - Training Loop: Run the training for a few epochs (e.g., 10-20). Print epoch-level average loss, accuracy, and precision.
- Expected Output: Printed metrics for each epoch showing the model learning and improving, and a final accuracy/precision score.
# Your solution for Exercise 5.1 here
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
# 1. Generate Synthetic Data for Binary Classification
NUM_SAMPLES = 1000
X_data = np.random.rand(NUM_SAMPLES, 2).astype(np.float32)
# Simple decision boundary: if sum of features > 1.0, class is 1, else 0
y_data = (X_data[:, 0] + X_data[:, 1] > 1.0).astype(np.int32)
y_data = y_data.reshape(-1, 1) # Reshape labels to (NUM_SAMPLES, 1)
print(f"X_data shape: {X_data.shape}, y_data shape: {y_data.shape}\n")
# 2. Data Pipeline
BATCH_SIZE = 32
train_dataset_classification = tf.data.Dataset.from_tensor_slices((X_data, y_data))\
.shuffle(NUM_SAMPLES)\
.batch(BATCH_SIZE)\
.prefetch(tf.data.AUTOTUNE)
# 3. Define the Model for Binary Classification
classification_model = keras.Sequential([
    layers.Dense(units=32, activation='relu', input_shape=(2,)),
    layers.Dense(units=32, activation='relu'),
    layers.Dense(units=1, activation='sigmoid')  # Sigmoid for binary classification output (probability 0-1)
])
classification_model.summary()
# 4. Define Loss Function, Optimizer, and Metrics
loss_fn_cls = keras.losses.BinaryCrossentropy()
optimizer_cls = keras.optimizers.Adam(learning_rate=0.001)
train_accuracy_metric = keras.metrics.BinaryAccuracy()
train_precision_metric = keras.metrics.Precision()
# 5. The Training Step Function
@tf.function
def train_step_classification(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = classification_model(inputs, training=True)
        loss = loss_fn_cls(targets, predictions)
    gradients = tape.gradient(loss, classification_model.trainable_variables)
    optimizer_cls.apply_gradients(zip(gradients, classification_model.trainable_variables))
    train_accuracy_metric.update_state(targets, predictions)
    train_precision_metric.update_state(targets, predictions)
    return loss
# 6. The Full Training Loop
EPOCHS_CLS = 20
print("\nStarting Custom Classification Training Loop...\n")
for epoch in range(EPOCHS_CLS):
    print(f"\nEpoch {epoch+1}/{EPOCHS_CLS}")
    total_loss_cls = 0.0
    num_batches_cls = 0
    train_accuracy_metric.reset_states()  # Reset metrics at the start of each epoch
    train_precision_metric.reset_states()
    for step, (batch_X, batch_y) in enumerate(train_dataset_classification):
        batch_loss_cls = train_step_classification(batch_X, batch_y)
        total_loss_cls += float(batch_loss_cls)  # convert the scalar loss tensor to a Python float
        num_batches_cls += 1
        if step % 10 == 0:
            print(f" Batch {step}: Loss = {float(batch_loss_cls):.4f}, Acc = {float(train_accuracy_metric.result()):.4f}, Prec = {float(train_precision_metric.result()):.4f}")
    avg_loss_epoch_cls = total_loss_cls / num_batches_cls
    print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss_epoch_cls:.4f}, Average Accuracy: {float(train_accuracy_metric.result()):.4f}, Average Precision: {float(train_precision_metric.result()):.4f}")
print("\nCustom Classification Training Loop Finished!")
# Make a prediction with the trained model
sample_input_cls = tf.constant([[0.1, 0.2], [0.9, 0.8], [0.5, 0.5]], dtype=tf.float32)
sample_prediction_cls = classification_model.predict(sample_input_cls)  # predict() returns a NumPy array
print(f"\nPrediction probabilities for [[0.1, 0.2], [0.9, 0.8], [0.5, 0.5]]:\n{sample_prediction_cls}")
print(f"Predicted classes: {(sample_prediction_cls > 0.5).astype(int)}")
5.2 Keras Callbacks
Keras Callbacks are powerful tools to customize the behavior of your model.fit() method. They are objects that can perform actions at various stages of training (e.g., at the beginning/end of an epoch, before/after a batch).
Commonly used callbacks include:
- ModelCheckpoint: Save the model (or just its weights) at some frequency during training.
- EarlyStopping: Stop training when a monitored metric has stopped improving.
- ReduceLROnPlateau: Reduce the learning rate when a metric has stopped improving.
- TensorBoard: Visualize training metrics, model graphs, and more in TensorBoard.
- CSVLogger: Stream epoch results to a CSV file.
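All of these callbacks are built on the same hook interface, and writing your own is just a matter of subclassing keras.callbacks.Callback and overriding the hooks you need. The following is a minimal sketch (the class name and print statements are purely illustrative):
from tensorflow import keras
# Minimal custom-callback sketch: every hook is optional and receives a `logs` dict of the metrics computed so far.
class VerboseCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        print("Training started")

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        print(f"Epoch {epoch + 1} ended, loss = {logs.get('loss')}")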
Example: Using Multiple Callbacks
Let’s modify our MNIST classification example to use several important callbacks.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import datetime
import matplotlib.pyplot as plt
# 1. Load and preprocess the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0
train_images = train_images.reshape((60000, 784))
test_images = test_images.reshape((10000, 784))
# 2. Define the model
model_callbacks = keras.Sequential([
    layers.Dense(units=128, activation='relu', input_shape=(784,)),
    layers.Dropout(0.2),  # Adding dropout for regularization
    layers.Dense(units=64, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(units=10, activation='softmax')
])
# 3. Compile the model
model_callbacks.compile(optimizer='adam',
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
# 4. Define Callbacks
# TensorBoard Callback: Logs events for TensorBoard visualization
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
# EarlyStopping Callback: Stop training if validation accuracy doesn't improve for 3 epochs
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_accuracy',     # Metric to monitor
    patience=3,                 # Number of epochs with no improvement after which training will be stopped
    verbose=1,
    mode='max',                 # 'max' because we want to maximize accuracy
    restore_best_weights=True   # Restore model weights from the epoch with the best value of the monitored metric
)
# ModelCheckpoint Callback: Save best model weights
checkpoint_filepath = 'best_model_weights/mnist_classifier_epoch_{epoch:02d}-val_accuracy_{val_accuracy:.2f}.weights.h5'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,   # Only save model weights, not the whole model (hence the .weights.h5 extension)
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,      # Only save when val_accuracy improves
    verbose=1
)
# ReduceLROnPlateau Callback: Reduce learning rate when a metric has stopped improving
reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',   # Monitor validation loss
    factor=0.2,           # Reduce learning rate by a factor of 0.2 (i.e., new_lr = lr * 0.2)
    patience=2,           # Number of epochs with no improvement after which the learning rate will be reduced
    min_lr=0.00001,       # Minimum learning rate
    verbose=1
)
# Combine all callbacks into a list
my_callbacks = [
    tensorboard_callback,
    early_stopping_callback,
    model_checkpoint_callback,
    reduce_lr_callback
]
# 5. Train the model with callbacks
print("\nStarting training with Callbacks...\n")
history_callbacks = model_callbacks.fit(train_images, train_labels,
                                        epochs=20,  # Set a higher number of epochs; EarlyStopping will cut training short
                                        batch_size=32,
                                        validation_split=0.1,
                                        callbacks=my_callbacks)
print("\nTraining with Callbacks finished!")
# 6. Evaluate and Predict
test_loss_cb, test_acc_cb = model_callbacks.evaluate(test_images, test_labels, verbose=2)
print(f"\nTest accuracy (with best weights restored): {test_acc_cb:.4f}")
# To view TensorBoard logs, open your terminal in the 'tensorflow_projects' directory
# and run: tensorboard --logdir logs/fit
# Then navigate to the URL provided by TensorBoard (usually http://localhost:6006/)
After running this code, you can open a new terminal, navigate to your tensorflow_projects directory, and run tensorboard --logdir logs/fit to visualize the training process. You’ll see graphs for loss, accuracy, and other metrics, as well as histograms of weights if histogram_freq is set.
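The CSVLogger listed earlier is not used in this example, but it plugs into the same callbacks list. A small sketch (the file name is an arbitrary choice):
# Sketch: stream epoch-level results to a CSV file; add it to the callbacks list before calling fit().
csv_logger = keras.callbacks.CSVLogger('training_log.csv', append=False)
# my_callbacks = [tensorboard_callback, early_stopping_callback, model_checkpoint_callback, reduce_lr_callback, csv_logger]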
Exercise 5.2: Custom Callback for Visualizing Predictions
- Objective: Create a custom Keras callback to visualize model predictions on a fixed set of validation images at the end of each epoch.
- Instructions:
  - Data Prep: Use the MNIST dataset (already loaded in the example above). Extract a small fixed subset of test images and labels (e.g., 5-10 images) to monitor.
  - Custom Callback Class: Create a class ImagePredictionCallback that inherits from tf.keras.callbacks.Callback.
    - The constructor __init__ should take the fixed validation images and their true labels.
    - Override the on_epoch_end(self, epoch, logs=None) method.
    - Inside on_epoch_end:
      - Use self.model.predict() on your fixed validation images.
      - Convert predictions to class labels (e.g., np.argmax(predictions, axis=1)).
      - Print the epoch number, true labels, and predicted labels for your fixed subset.
      - (Bonus/Advanced): If you want to visualize, use matplotlib.pyplot to plot images with predicted/true labels and save the plot to a file or log it to TensorBoard (this requires tf.summary and tf.summary.image, which is more advanced for a basic custom callback). For this exercise, printing is sufficient.
  - Integrate and Train:
    - Define and compile the same MNIST model as in the example.
    - Create an instance of your ImagePredictionCallback.
    - Train the model using model.fit(), including your custom callback in the callbacks list.
- Expected Output: At the end of each epoch, the console should print the true and predicted labels for your chosen subset of validation images, showing how the model's predictions improve over epochs.
# Your solution for Exercise 5.2 here
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
# 1. Data Prep: Load MNIST and extract a small fixed subset for monitoring
(train_images_cb, train_labels_cb), (test_images_cb, test_labels_cb) = keras.datasets.mnist.load_data()
train_images_cb = train_images_cb.astype('float32') / 255.0
test_images_cb = test_images_cb.astype('float32') / 255.0
train_images_cb = train_images_cb.reshape((60000, 784))
test_images_cb = test_images_cb.reshape((10000, 784))
# Take a small fixed subset (e.g., first 5) of test images for consistent monitoring
monitor_images = test_images_cb[:5]
monitor_labels = test_labels_cb[:5]
# 2. Custom Callback Class
class ImagePredictionCallback(keras.callbacks.Callback):
    def __init__(self, images, labels):
        super().__init__()
        self.images = images
        self.labels = labels

    def on_epoch_end(self, epoch, logs=None):
        predictions = self.model.predict(self.images, verbose=0)  # Predict probabilities
        predicted_classes = np.argmax(predictions, axis=1)  # Get the class with the highest probability
        print(f"\n--- Epoch {epoch+1} Predictions ---")
        print(f"True labels: {self.labels}")
        print(f"Predicted labels: {predicted_classes}")
        print("-----------------------------------\n")
        # Bonus: Visualize a single image's prediction (simple text output for brevity)
        if epoch % 5 == 0 or epoch == self.params['epochs'] - 1:  # Visualize less frequently or on the last epoch
            print(f"Sample prediction for image 0: True={self.labels[0]}, Pred={predicted_classes[0]}")
            # To actually plot:
            # plt.figure(figsize=(2, 2))
            # plt.imshow(self.images[0].reshape(28, 28), cmap='gray')
            # plt.title(f"Epoch {epoch+1}: True {self.labels[0]}, Pred {predicted_classes[0]}")
            # plt.axis('off')
            # plt.savefig(f'epoch_{epoch+1}_pred_sample_0.png')
            # plt.close()
# 3. Define and Compile the Model
model_custom_cb = keras.Sequential([
    layers.Dense(units=128, activation='relu', input_shape=(784,)),
    layers.Dropout(0.2),
    layers.Dense(units=64, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(units=10, activation='softmax')
])
model_custom_cb.compile(optimizer='adam',
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
# Create an instance of the custom callback
prediction_callback = ImagePredictionCallback(monitor_images, monitor_labels)
# 4. Train the model with the custom callback
print("\nStarting training with Custom Image Prediction Callback...\n")
history_custom_cb = model_custom_cb.fit(train_images_cb, train_labels_cb,
                                        epochs=10,  # Run for fewer epochs for faster demonstration
                                        batch_size=32,
                                        validation_split=0.1,
                                        callbacks=[prediction_callback])
print("\nTraining with Custom Callback finished!")
# Evaluate the model
test_loss_final, test_acc_final = model_custom_cb.evaluate(test_images_cb, test_labels_cb, verbose=2)
print(f"\nFinal Test Accuracy: {test_acc_final:.4f}")
You’ve now explored advanced ways to control and monitor your TensorFlow models, giving you the power to tackle more complex machine learning problems with greater insight and efficiency. In the next chapter, we’ll dive into advanced topics like model distribution and deployment.