6. Advanced Topics and Best Practices
As you move beyond basic model building, two crucial aspects come into play for real-world applications: scaling your training to leverage powerful hardware and deploying your models to various environments, especially resource-constrained ones. This chapter covers TensorFlow’s Distribution Strategies and TensorFlow Lite.
6.1 Distribution Strategies: Scaling Your Training
Training large models on massive datasets can be time-consuming. TensorFlow’s tf.distribute.Strategy API allows you to easily distribute your training across multiple GPUs, multiple machines, or even Google’s TPUs (Tensor Processing Units) with minimal changes to your code.
The core idea is to replicate the model on multiple devices (or parts of a device) and then combine the gradients computed on each replica.
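A few lines of NumPy show why combining gradients works. The sketch below is not TensorFlow code. It simulates one synchronous data-parallel step for a hypothetical linear model with a mean-squared-error loss, where the two "replicas" are just halves of one batch. When the shards are the same size, averaging the per-replica gradients reproduces the full-batch gradient exactly. That is why every replica can apply the same update and keep identical weights.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))   # one global batch: 8 examples, 3 features
y = rng.normal(size=8)
w = np.zeros(3)               # current weights, identical on every replica

def mse_grad(X_batch, y_batch, w):
    """Gradient of 0.5 * mean((X_batch @ w - y_batch) ** 2) with respect to w."""
    residual = X_batch @ w - y_batch
    return X_batch.T @ residual / len(y_batch)

# Reference: the gradient a single device would compute on the whole batch.
full_grad = mse_grad(X, y, w)

# Two simulated replicas, each receiving an equal slice of the batch.
num_replicas = 2
X_shards = np.array_split(X, num_replicas)
y_shards = np.array_split(y, num_replicas)
replica_grads = [mse_grad(Xs, ys, w) for Xs, ys in zip(X_shards, y_shards)]

# All-reduce step: average the per-replica gradients.
averaged_grad = np.mean(replica_grads, axis=0)

# Every replica applies the same update, so the weights stay in sync.
learning_rate = 0.1
w_new = w - learning_rate * averaged_grad
print(np.allclose(averaged_grad, full_grad))  # True
```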
Common Distribution Strategies
- MirroredStrategy: Synchronous training on multiple GPUs on a single machine. Each replica processes a different slice of the input data and computes gradients. The gradients are then aggregated and averaged across all replicas, and the model weights are updated identically on each replica. This is the most common strategy for multi-GPU training on a single host.
- MultiWorkerMirroredStrategy: Similar to MirroredStrategy, but for synchronous training across multiple machines, each with one or more GPUs.
- TPUStrategy: For distributed training on TPUs, which are specialized ASICs designed by Google to accelerate machine learning workloads.
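MultiWorkerMirroredStrategy needs each machine to know the cluster layout and its own role in it. A common way to provide this is the TF_CONFIG environment variable, which holds a JSON string with a cluster entry and a task entry. The sketch below builds one for a hypothetical two-machine cluster. The hostnames and port are placeholders, and the variable must be set before the strategy is created.

```python
import json
import os

# Hypothetical two-machine cluster: replace these host:port pairs with your own.
cluster = {"worker": ["worker0.example.com:12345", "worker1.example.com:12345"]}

def make_tf_config(task_index):
    """Build the TF_CONFIG JSON string for the worker at position `task_index`."""
    return json.dumps({
        "cluster": cluster,
        "task": {"type": "worker", "index": task_index},
    })

# On the first machine (worker index 0):
os.environ["TF_CONFIG"] = make_tf_config(0)
# Then, in the same process, create the strategy:
#   strategy = tf.distribute.MultiWorkerMirroredStrategy()
print(os.environ["TF_CONFIG"])
```

Each machine runs the same script with its own task index: 0 on the first worker and 1 on the second.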
Using MirroredStrategy (Single Host, Multi-GPU)
Let’s adapt our MNIST example to use MirroredStrategy. For this to show actual benefits, you need at least two GPUs on your system. If you only have one GPU or are running on CPU, the code will still run, but you won’t see parallelization benefits.
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# Ensure TensorFlow detects GPUs (if available)
physical_devices = tf.config.list_physical_devices('GPU')
if len(physical_devices) > 0:
    print(f"Detected GPUs: {physical_devices}")
    for device in physical_devices:
        # Grow GPU memory on demand; this must be set before the GPUs are initialized
        tf.config.experimental.set_memory_growth(device, True)
else:
    print("No GPUs detected. Running on CPU.")

# 1. Choose a strategy
# MirroredStrategy will use all detected GPUs on the current machine.
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}\n")

# 2. Load and preprocess the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0
train_images = train_images.reshape((60000, 784))
test_images = test_images.reshape((10000, 784))

# Create tf.data datasets for efficient input pipelines
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))\
    .shuffle(10000).batch(GLOBAL_BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))\
    .batch(GLOBAL_BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# 3. Create the model, optimizer, and metrics within the strategy's scope
# THIS IS CRUCIAL: All variables (model weights) must be created under the strategy scope.
with strategy.scope():
    model_distributed = keras.Sequential([
        layers.Dense(units=128, activation='relu', input_shape=(784,)),
        layers.Dropout(0.2),
        layers.Dense(units=64, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(units=10, activation='softmax')
    ])
    model_distributed.compile(optimizer='adam',
                              loss='sparse_categorical_crossentropy',
                              metrics=['accuracy'])

# (Optional) Verify that the model's variables were created
print("Model built within strategy scope.")
print(f"Number of trainable variables: {len(model_distributed.trainable_variables)}")
# Each variable object internally manages one copy per replica.

# 4. Train the model using the `.fit()` method (no changes needed here!)
print("\nStarting distributed training...\n")
history_distributed = model_distributed.fit(train_dataset,
                                            epochs=5,  # Reduced epochs for faster demonstration
                                            validation_data=test_dataset)
print("\nDistributed training finished!")

# 5. Evaluate the model
test_loss_dist, test_acc_dist = model_distributed.evaluate(test_dataset, verbose=2)
print(f"\nDistributed Test accuracy: {test_acc_dist:.4f}")
```
The beauty of tf.distribute.Strategy is that once you define the strategy and create your model within its scope, the rest of your tf.keras workflow (.compile(), .fit(), .evaluate(), .predict()) works seamlessly without further modifications.
Best Practices for Distribution Strategies:
- GLOBAL_BATCH_SIZE: When using multiple replicas, the total batch size is BATCH_SIZE_PER_REPLICA * num_replicas. Make sure GLOBAL_BATCH_SIZE is adjusted appropriately for your hardware.
- Data (tf.data): Prefer tf.data datasets with distribution strategies. When you pass a dataset to fit(), each global batch is split across the replicas. With MultiWorkerMirroredStrategy, the dataset is also automatically sharded across workers.
- Scope: Always create your model and compile() it inside strategy.scope().
- TPUs: TPUStrategy requires special setup, typically in Google Colab or on Google Cloud (for example, Vertex AI).
Exercise 6.1: Experimenting with Batch Size and Distributed Training
- Objective: Observe the effect of GLOBAL_BATCH_SIZE when using MirroredStrategy.
- Instructions:
  - Set up the MirroredStrategy as shown in the example.
  - Define two GLOBAL_BATCH_SIZE values:
    - GLOBAL_BATCH_SIZE_SMALL = 32 * strategy.num_replicas_in_sync
    - GLOBAL_BATCH_SIZE_LARGE = 256 * strategy.num_replicas_in_sync
  - Create two separate tf.data training datasets (e.g., train_dataset_small_batch and train_dataset_large_batch) using these different batch sizes. Keep the other dataset parameters (shuffle buffer, prefetch) the same.
  - Crucially, create a new model instance for each training run inside its own strategy.scope() block. Compile it and train it for a fixed small number of epochs (e.g., 3).
  - Compare the training speed (time per epoch, measured with a timer callback) and the final accuracy of the models trained with small vs. large global batch sizes.
- Expected Outcome: You should see differences in training speed and possibly slight differences in final accuracy. Larger batches usually finish an epoch faster because there are fewer steps per epoch, but they may need more epochs to reach similar accuracy. This exercise highlights the impact of batch size in a distributed context.
```python
# Your solution for Exercise 6.1 here
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import time

# Ensure TensorFlow detects GPUs (if available)
physical_devices = tf.config.list_physical_devices('GPU')
if len(physical_devices) > 0:
    print(f"Detected GPUs: {physical_devices}")
    for device in physical_devices:
        tf.config.experimental.set_memory_growth(device, True)
else:
    print("No GPUs detected. Running on CPU.")

strategy = tf.distribute.MirroredStrategy()
num_replicas = strategy.num_replicas_in_sync
print(f"Number of devices for MirroredStrategy: {num_replicas}\n")

# Load and preprocess the MNIST dataset
(train_images_ex, train_labels_ex), (test_images_ex, test_labels_ex) = keras.datasets.mnist.load_data()
train_images_ex = train_images_ex.astype('float32') / 255.0
test_images_ex = test_images_ex.astype('float32') / 255.0
train_images_ex = train_images_ex.reshape((60000, 784))
test_images_ex = test_images_ex.reshape((10000, 784))

# Define GLOBAL_BATCH_SIZE for small and large
BATCH_SIZE_PER_REPLICA_SMALL = 32
GLOBAL_BATCH_SIZE_SMALL = BATCH_SIZE_PER_REPLICA_SMALL * num_replicas
BATCH_SIZE_PER_REPLICA_LARGE = 256  # Matches the exercise: 256 * num_replicas
GLOBAL_BATCH_SIZE_LARGE = BATCH_SIZE_PER_REPLICA_LARGE * num_replicas
print(f"Global Batch Size (Small): {GLOBAL_BATCH_SIZE_SMALL}")
print(f"Global Batch Size (Large): {GLOBAL_BATCH_SIZE_LARGE}\n")

# Create tf.data datasets
train_dataset_small_batch = tf.data.Dataset.from_tensor_slices((train_images_ex, train_labels_ex))\
    .shuffle(10000).batch(GLOBAL_BATCH_SIZE_SMALL).prefetch(tf.data.AUTOTUNE)
test_dataset_small_batch = tf.data.Dataset.from_tensor_slices((test_images_ex, test_labels_ex))\
    .batch(GLOBAL_BATCH_SIZE_SMALL).prefetch(tf.data.AUTOTUNE)
train_dataset_large_batch = tf.data.Dataset.from_tensor_slices((train_images_ex, train_labels_ex))\
    .shuffle(10000).batch(GLOBAL_BATCH_SIZE_LARGE).prefetch(tf.data.AUTOTUNE)
test_dataset_large_batch = tf.data.Dataset.from_tensor_slices((test_images_ex, test_labels_ex))\
    .batch(GLOBAL_BATCH_SIZE_LARGE).prefetch(tf.data.AUTOTUNE)

EPOCHS_EX = 3  # Fixed small number of epochs for comparison

# Custom Callback to measure epoch time
class TimeHistory(keras.callbacks.Callback):
    """Records the wall-clock duration of each training epoch."""

    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.time() - self.epoch_time_start)
        print(f"Epoch {epoch+1} took {self.times[-1]:.2f} seconds.")

# --- Training with SMALL Global Batch Size ---
print("--- Training with SMALL Global Batch Size ---")
with strategy.scope():
    model_small_batch = keras.Sequential([
        layers.Dense(units=128, activation='relu', input_shape=(784,)),
        layers.Dropout(0.2),
        layers.Dense(units=64, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(units=10, activation='softmax')
    ])
    model_small_batch.compile(optimizer='adam',
                              loss='sparse_categorical_crossentropy',
                              metrics=['accuracy'])

time_callback_small = TimeHistory()
history_small = model_small_batch.fit(train_dataset_small_batch,
                                      epochs=EPOCHS_EX,
                                      validation_data=test_dataset_small_batch,
                                      callbacks=[time_callback_small],
                                      verbose=1)
test_loss_small, test_acc_small = model_small_batch.evaluate(test_dataset_small_batch, verbose=0)
print(f"\nFinal Test Accuracy (Small Batch): {test_acc_small:.4f}")
print(f"Average Epoch Time (Small Batch): {np.mean(time_callback_small.times):.2f} seconds\n")

# --- Training with LARGE Global Batch Size ---
print("--- Training with LARGE Global Batch Size ---")
with strategy.scope():
    model_large_batch = keras.Sequential([
        layers.Dense(units=128, activation='relu', input_shape=(784,)),
        layers.Dropout(0.2),
        layers.Dense(units=64, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(units=10, activation='softmax')
    ])
    model_large_batch.compile(optimizer='adam',
                              loss='sparse_categorical_crossentropy',
                              metrics=['accuracy'])

time_callback_large = TimeHistory()
history_large = model_large_batch.fit(train_dataset_large_batch,
                                      epochs=EPOCHS_EX,
                                      validation_data=test_dataset_large_batch,
                                      callbacks=[time_callback_large],
                                      verbose=1)
test_loss_large, test_acc_large = model_large_batch.evaluate(test_dataset_large_batch, verbose=0)
print(f"\nFinal Test Accuracy (Large Batch): {test_acc_large:.4f}")
print(f"Average Epoch Time (Large Batch): {np.mean(time_callback_large.times):.2f} seconds\n")

print("\n--- Comparison ---")
print(f"Small Batch Final Accuracy: {test_acc_small:.4f}, Avg. Epoch Time: {np.mean(time_callback_small.times):.2f}s")
print(f"Large Batch Final Accuracy: {test_acc_large:.4f}, Avg. Epoch Time: {np.mean(time_callback_large.times):.2f}s")
print("\nObservations:")
print("- Larger batches often lead to faster epoch times because fewer gradient updates are needed per epoch.")
print("- However, very large batches might sometimes result in slightly lower final accuracy or require more epochs to converge.")
print("- The optimal batch size can depend on the model, dataset, and available hardware.")
```
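Much of the epoch-time difference comes from the number of optimizer steps per epoch. The arithmetic below assumes two replicas and the 60,000-image MNIST training set, and it uses the per-replica sizes from the exercise. The datasets are batched without drop_remainder, so a final partial batch counts as one extra step.

```python
import math

NUM_TRAIN_EXAMPLES = 60000  # size of the MNIST training set
num_replicas = 2            # assumption: a machine with two GPUs

steps_per_epoch = {}
for per_replica in (32, 256):
    global_batch = per_replica * num_replicas
    # The last, partial batch still costs one step.
    steps_per_epoch[per_replica] = math.ceil(NUM_TRAIN_EXAMPLES / global_batch)
    print(f"per-replica {per_replica:>3} -> global {global_batch:>3} "
          f"-> {steps_per_epoch[per_replica]} steps per epoch")
```

With 938 versus 118 steps, the large-batch run makes about 8x fewer weight updates per epoch. This is also why it may need more epochs to reach the same accuracy.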
6.2 TensorFlow Lite: Deploying on Mobile and Edge Devices
TensorFlow Lite is a set of tools that enables on-device machine learning inference. It’s designed for mobile, embedded, and IoT devices that have limited computational resources, memory, and battery.
Key benefits of TensorFlow Lite:
- Low Latency: Inference happens directly on the device, avoiding network roundtrips.
- Small Binary Size: The runtime is lightweight, and models can be optimized (for example, quantized) to be much smaller.
- High Performance: Optimized for various device hardware (CPUs, GPUs, NPUs).
- Privacy: User data doesn’t leave the device for inference.
The typical workflow involves:
- Train a TensorFlow model (as you’ve learned).
- Convert the model to TensorFlow Lite format (.tflite) using the TensorFlow Lite Converter. During conversion, you can apply optimizations like quantization.
- Deploy the .tflite model to your mobile or edge device.
- Run inference using the TensorFlow Lite Interpreter.
6.2.1 Converting a Model to TensorFlow Lite
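The TensorFlow Lite Converter is the central tool. It can convert tf.keras models (tf.lite.TFLiteConverter.from_keras_model) or models saved in the SavedModel format (tf.lite.TFLiteConverter.from_saved_model).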
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# 1. Train a simple model (reusing our MNIST classifier)
(train_images_tf, train_labels_tf), (test_images_tf, test_labels_tf) = keras.datasets.mnist.load_data()
train_images_tf = train_images_tf.astype('float32') / 255.0
test_images_tf = test_images_tf.astype('float32') / 255.0
train_images_tf = train_images_tf.reshape((60000, 784))
test_images_tf = test_images_tf.reshape((10000, 784))

model_to_convert = keras.Sequential([
    layers.Dense(units=128, activation='relu', input_shape=(784,)),
    layers.Dense(units=10, activation='softmax')
])
model_to_convert.compile(optimizer='adam',
                         loss='sparse_categorical_crossentropy',
                         metrics=['accuracy'])
model_to_convert.fit(train_images_tf, train_labels_tf, epochs=3, verbose=0)
print("Model trained for conversion.\n")

# 2. Convert the Keras model to TensorFlow Lite format
converter = tf.lite.TFLiteConverter.from_keras_model(model_to_convert)
tflite_model = converter.convert()

# Save the TFLite model to a file
with open('mnist_model.tflite', 'wb') as f:
    f.write(tflite_model)
print(f"TensorFlow Lite model saved to mnist_model.tflite. Size: {len(tflite_model) / 1024:.2f} KB\n")
```
6.2.2 Quantization: Further Optimizing for Edge Devices
Quantization reduces the precision of the numbers used to represent a model’s parameters (and, in some modes, its activations), typically from 32-bit floating point to 8-bit integers. Going from 32-bit floats to 8-bit integers can shrink a model by up to about 4x and speed up inference, usually with only a small loss of accuracy.
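The arithmetic behind 8-bit quantization is an affine mapping, real_value = (quantized_value - zero_point) * scale. The TFLite interpreter reports this same relationship for each quantized tensor. The NumPy sketch below applies the mapping per tensor to a random stand-in weight matrix and measures the round-trip error, which rounding keeps within half a quantization step. It only illustrates the arithmetic. The TFLite converter makes its own choices, for example using symmetric quantization (zero point 0) for weights.

```python
import numpy as np

rng = np.random.default_rng(42)
# Random float32 stand-in for a weight tensor (shape of the MNIST model's first kernel).
weights = rng.normal(0.0, 0.1, size=(784, 128)).astype(np.float32)

# Per-tensor affine int8 quantization: real_value = (q - zero_point) * scale
qmin, qmax = -128, 127
w_min, w_max = float(weights.min()), float(weights.max())
scale = (w_max - w_min) / (qmax - qmin)        # real-valued width of one int8 step
zero_point = int(round(qmin - w_min / scale))  # int8 value that (approximately) represents 0.0

# Quantize: scale, shift, round to the nearest integer, and clip to the int8 range.
q = np.clip(np.round(weights / scale + zero_point), qmin, qmax).astype(np.int8)
# Dequantize to see how much information survived.
dequantized = (q.astype(np.float32) - zero_point) * scale

max_error = float(np.abs(weights - dequantized).max())
print(f"scale={scale:.6f}, zero_point={zero_point}")
print(f"max abs error={max_error:.6f} (half a step = {scale / 2:.6f})")
```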
There are different types of quantization:
- Post-training dynamic range quantization: Quantizes weights to 8-bit integers ahead of time and dynamically quantizes activations during inference. This is the simplest option because it needs no calibration data.
- Post-training float16 quantization: Quantizes weights to 16-bit floating-point. Good balance of size and accuracy.
- Post-training integer quantization: Quantizes weights and activations to 8-bit integers. Requires a “representative dataset” for calibration. Best for memory and integer-only hardware.
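Float16 quantization is the easiest option to picture, because the same weights are simply stored in half precision. The sketch below uses random stand-in weights with the shape of the MNIST model's first Dense kernel (784 inputs, 128 units). It shows that storage halves while the per-weight rounding error stays small. By default, TFLite dequantizes float16 weights back to float32 when running on a CPU, so on a CPU the main gain is file size. The GPU delegate can operate on float16 data directly.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in float32 weights with the shape of the MNIST model's first Dense kernel.
w = rng.normal(0.0, 0.1, size=(784, 128)).astype(np.float32)

# float16 quantization: store the same values in half precision.
w16 = w.astype(np.float16)

# Compare against the original after upcasting back to float32.
max_abs_error = float(np.abs(w16.astype(np.float32) - w).max())
print(f"float32: {w.nbytes / 1024:.0f} KB, float16: {w16.nbytes / 1024:.0f} KB")
print(f"max abs rounding error: {max_abs_error:.2e}")
```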
Let’s apply post-training dynamic range quantization:
```python
import tensorflow as tf

# Assuming `model_to_convert` is already trained and `tflite_model` was
# converted in the previous step

# 1. Convert with dynamic range quantization
converter_quant = tf.lite.TFLiteConverter.from_keras_model(model_to_convert)
converter_quant.optimizations = [tf.lite.Optimize.DEFAULT]  # Enable default optimizations (dynamic range)
tflite_quant_model = converter_quant.convert()

# Save the quantized TFLite model
with open('mnist_model_quant.tflite', 'wb') as f:
    f.write(tflite_quant_model)
print(f"Quantized TensorFlow Lite model saved to mnist_model_quant.tflite. Size: {len(tflite_quant_model) / 1024:.2f} KB\n")

# Compare sizes
print(f"Original .tflite size: {len(tflite_model) / 1024:.2f} KB")
print(f"Quantized .tflite size: {len(tflite_quant_model) / 1024:.2f} KB")
```
You should see a noticeable reduction in file size for the quantized model.
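A back-of-the-envelope estimate shows what "noticeable" should mean here. The sketch below counts the parameters of model_to_convert from its layer shapes and estimates raw tensor storage. It assumes that dynamic range quantization stores the two Dense kernels as int8 while the biases stay float32. The real .tflite files are somewhat larger because they also hold the FlatBuffer graph and metadata.

```python
# Dense layer shapes of model_to_convert: 784 -> 128 and 128 -> 10.
dense_layers = [(784, 128), (128, 10)]

kernel_params = sum(n_in * n_out for n_in, n_out in dense_layers)  # 101,632 weights
bias_params = sum(n_out for _, n_out in dense_layers)              # 138 biases

# Everything stored as float32 (4 bytes per value).
float32_bytes = (kernel_params + bias_params) * 4
# Assumption: kernels stored as int8 (1 byte each), biases left in float32.
dynamic_range_bytes = kernel_params * 1 + bias_params * 4
ratio = float32_bytes / dynamic_range_bytes

print(f"float32 tensors:       ~{float32_bytes / 1024:.1f} KB")
print(f"dynamic range tensors: ~{dynamic_range_bytes / 1024:.1f} KB")
print(f"size ratio:            ~{ratio:.2f}x")
```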
6.2.3 Running Inference with TensorFlow Lite Interpreter
Once you have a .tflite model, you can load it with the tf.lite.Interpreter and run inference.
```python
import tensorflow as tf
import numpy as np

# Load the TFLite model
interpreter = tf.lite.Interpreter(model_path="mnist_model_quant.tflite")
interpreter.allocate_tensors()  # Allocate tensors for input, output, and intermediate arrays

# Get input and output details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(f"Input details: {input_details}\n")
print(f"Output details: {output_details}\n")

# Load a sample test image
# (Using test_images_tf and test_labels_tf loaded in 6.2.1)
sample_image = test_images_tf[0:1]  # First image as a batch of 1, shape (1, 784)
true_label = test_labels_tf[0]

# TFLite models expect specific input shapes and data types.
# Ensure your input data matches what the model expects.
input_shape = input_details[0]['shape']
input_dtype = input_details[0]['dtype']
print(f"Interpreter expects input shape: {input_shape}, dtype: {input_dtype}")

# Dynamic range quantization keeps float32 inputs, so a plain cast is enough here.
# Models with integer inputs instead need the data quantized with the input
# tensor's scale and zero point (see Exercise 6.2).
input_data = sample_image.astype(input_dtype)

# Set the tensor to be the input
interpreter.set_tensor(input_details[0]['index'], input_data)

# Run inference
interpreter.invoke()

# Get the output tensor
output_data = interpreter.get_tensor(output_details[0]['index'])

# Process the output (softmax probabilities)
predicted_class = np.argmax(output_data)
print(f"True Label: {true_label}")
print(f"Predicted Class (Quantized Model): {predicted_class}")
print(f"Raw Output (softmax probabilities): {output_data.flatten()}\n")

# Verify with original Keras model for comparison
original_prediction = model_to_convert.predict(sample_image, verbose=0)
original_predicted_class = np.argmax(original_prediction)
print(f"Predicted Class (Original Keras Model): {original_predicted_class}")
```
You should observe the same predicted class from the original Keras model and the quantized TFLite model. A single image is only a spot check, though. To judge the effect of quantization, compare predictions or accuracy over the whole test set.
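A single image is only a spot check. One way to compare the two models more thoroughly is to collect their outputs over many test images and count how often the top class agrees. top1_agreement below is a hypothetical helper, not a TensorFlow API. In practice, probs_a would come from model_to_convert.predict(test_images_tf), and probs_b from looping the interpreter over the same images (its input shape here is one image per call) and stacking output_data. The toy arrays only show the mechanics.

```python
import numpy as np

def top1_agreement(probs_a, probs_b):
    """Fraction of rows where two (N, num_classes) score arrays share the same argmax."""
    probs_a = np.asarray(probs_a)
    probs_b = np.asarray(probs_b)
    if probs_a.shape != probs_b.shape:
        raise ValueError(f"shape mismatch: {probs_a.shape} vs {probs_b.shape}")
    return float(np.mean(np.argmax(probs_a, axis=1) == np.argmax(probs_b, axis=1)))

# Toy example with 3 images and 3 classes: the models disagree on the last image.
keras_probs = np.array([[0.90, 0.05, 0.05], [0.10, 0.80, 0.10], [0.40, 0.35, 0.25]])
tflite_probs = np.array([[0.88, 0.06, 0.06], [0.12, 0.78, 0.10], [0.33, 0.42, 0.25]])
print(top1_agreement(keras_probs, tflite_probs))
```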
Best Practices for TensorFlow Lite:
- Quantization-Aware Training (QAT): For even higher accuracy with integer quantization, you can integrate quantization directly into the training process. This allows the model to “learn” to be robust to quantization noise.
- Representative Dataset: For full integer quantization, provide a small, diverse subset of your training data to the converter for calibration.
- Performance Benchmarking: Always benchmark your .tflite model on the target device to ensure it meets your performance requirements.
- Delegate Integration: For specialized hardware (like NPUs and DSPs), use TensorFlow Lite delegates to offload computation and accelerate inference.
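Calibration ranges come only from the samples the converter sees, so a representative dataset should cover the variety of real inputs. Taking the first N training images works for MNIST, but sampling at random avoids any ordering bias. make_representative_data_gen is a hypothetical helper. It returns a zero-argument generator function, which is the form converter.representative_dataset expects, e.g. converter.representative_dataset = make_representative_data_gen(train_images_tf). The placeholder array below stands in for train_images_tf.

```python
import numpy as np

def make_representative_data_gen(images, num_samples=100, seed=0):
    """Return a generator function that yields randomly chosen single-image batches."""
    rng = np.random.default_rng(seed)
    indices = rng.choice(len(images), size=num_samples, replace=False)

    def representative_data_gen():
        for i in indices:
            # One model input -> a list holding one float32 array of shape (1, 784).
            yield [images[i:i + 1].astype(np.float32)]

    return representative_data_gen

# Placeholder data standing in for train_images_tf (values in [0, 1], shape (N, 784)).
fake_images = np.random.rand(1000, 784).astype(np.float32)
gen = make_representative_data_gen(fake_images, num_samples=100)
batches = list(gen())
print(len(batches), batches[0][0].shape)  # 100 (1, 784)
```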
Exercise 6.2: Full Integer Quantization with Representative Dataset
- Objective: Convert a trained model to TensorFlow Lite using full integer quantization, leveraging a representative dataset.
- Instructions:
  - Keep Model and Data: Use the same model_to_convert (trained MNIST classifier) and train_images_tf from the previous example.
  - Representative Dataset Generator: Create a small function representative_data_gen() that yields batches of your input data. The TFLite converter will use this to calibrate the integer ranges for quantization. Each yield should be a list (or tuple) containing one array per model input:
    ```python
    def representative_data_gen():
        for i in range(100):  # Use a small subset (e.g., 100 samples)
            yield [train_images_tf[i:i+1]]  # A batch of 1 image, shape (1, 784)
    ```
  - Converter Configuration:
    - Instantiate tf.lite.TFLiteConverter.from_keras_model(model_to_convert).
    - Set converter.optimizations = [tf.lite.Optimize.DEFAULT].
    - Set converter.representative_dataset = representative_data_gen.
    - Specify converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] to ensure full integer quantization (conversion fails if an op has no int8 implementation).
    - Optionally, set converter.inference_input_type = tf.int8 and converter.inference_output_type = tf.int8 to make the model entirely integer-based. This simplifies deployment on integer-only hardware but requires converting input data before inference.
  - Convert and Save: Convert the model and save it as mnist_model_full_int_quant.tflite.
  - Compare Size: Print the file size of this fully integer-quantized model and compare it to the dynamic range quantized model.
  - Run Inference: Load mnist_model_full_int_quant.tflite with tf.lite.Interpreter.
    - When preparing sample_image for inference, you must quantize it to int8 using the input tensor's scale and zero point if inference_input_type was set to tf.int8. A plain cast is not enough. If inference_input_type was not explicitly set, the model will likely still expect float32 input and quantize it internally. Check input_details[0]['dtype'].
    - Run inference and print the predicted class, comparing it to the original model.
- Expected Outcome: The .tflite file should be much smaller than the float32 model and similar in size to the dynamic range model, since both store weights as 8-bit integers. Predictions should be correct, demonstrating full integer quantization.
```python
# Your solution for Exercise 6.2 here
import tensorflow as tf
import numpy as np
import os

# Assuming model_to_convert is already trained from the previous step
# (If not, run the 'Train a simple model' section above)

# 1. Representative Dataset Generator
# The converter uses this data to calibrate the ranges for integer quantization.
def representative_data_gen():
    for i in range(100):  # Use a small, diverse subset for calibration
        # The generator must yield a list of input tensors (one per model input).
        # Our model expects input shape (None, 784), so each yield is a batch of 1.
        yield [tf.constant(train_images_tf[i:i+1], dtype=tf.float32)]

# 2. Converter Configuration for Full Integer Quantization
converter_full_int = tf.lite.TFLiteConverter.from_keras_model(model_to_convert)
converter_full_int.optimizations = [tf.lite.Optimize.DEFAULT]  # Enable default optimizations
converter_full_int.representative_dataset = representative_data_gen
converter_full_int.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Make the input and output tensors int8 too, so the model is entirely integer-based
# (optimal for some embedded devices). Without these two lines, input and output
# stay float32 and the model quantizes/dequantizes internally.
converter_full_int.inference_input_type = tf.int8
converter_full_int.inference_output_type = tf.int8

# 3. Convert and Save
tflite_full_int_quant_model = converter_full_int.convert()
with open('mnist_model_full_int_quant.tflite', 'wb') as f:
    f.write(tflite_full_int_quant_model)
print(f"Full Integer Quantized TFLite model saved to mnist_model_full_int_quant.tflite. Size: {len(tflite_full_int_quant_model) / 1024:.2f} KB\n")

# Compare sizes with the files written in 6.2.1 and 6.2.2 (if present)
if os.path.exists('mnist_model.tflite'):
    print(f"Original (float32) .tflite size: {os.path.getsize('mnist_model.tflite') / 1024:.2f} KB")
if os.path.exists('mnist_model_quant.tflite'):
    print(f"Dynamic Range Quantized .tflite size: {os.path.getsize('mnist_model_quant.tflite') / 1024:.2f} KB")
print(f"Full Integer Quantized .tflite size: {len(tflite_full_int_quant_model) / 1024:.2f} KB")
print("Both quantized models store weights as int8, so their sizes should be similar;\n"
      "the full integer model's advantage is integer-only execution.\n")

# 4. Run Inference with Full Integer Quantized Model
interpreter_full_int = tf.lite.Interpreter(model_path="mnist_model_full_int_quant.tflite")
interpreter_full_int.allocate_tensors()
input_details_fi = interpreter_full_int.get_input_details()
output_details_fi = interpreter_full_int.get_output_details()
print(f"Input details (Full Integer): {input_details_fi}\n")
print(f"Output details (Full Integer): {output_details_fi}\n")

sample_image_fi = test_images_tf[10:11]  # Pick a different sample for variety
true_label_fi = test_labels_tf[10]

# IMPORTANT: the model expects int8 input, so quantize the float32 pixels (0-1)
# with the input tensor's scale and zero point, as computed by the converter:
#   real_value = (quantized_value - zero_point) * scale
#   quantized_value = round(real_value / scale + zero_point)
input_scale, input_zero_point = input_details_fi[0]['quantization']
if input_scale == 0:
    raise ValueError("Input tensor has no quantization parameters; check the converter settings.")
input_data_fi = np.round(sample_image_fi / input_scale + input_zero_point)
input_data_fi = np.clip(input_data_fi, -128, 127).astype(np.int8)  # Stay inside the int8 range

interpreter_full_int.set_tensor(input_details_fi[0]['index'], input_data_fi)
interpreter_full_int.invoke()
output_data_fi = interpreter_full_int.get_tensor(output_details_fi[0]['index'])

# The output is also int8. De-quantize it to recover the softmax probabilities.
# The softmax is part of the model, so it must not be applied a second time here.
output_scale, output_zero_point = output_details_fi[0]['quantization']
dequantized_output = (output_data_fi.astype(np.float32) - output_zero_point) * output_scale

predicted_class_fi = np.argmax(dequantized_output)
print(f"True Label: {true_label_fi}")
print(f"Predicted Class (Full Integer Quantized Model): {predicted_class_fi}")
print(f"De-quantized Output (probabilities): {dequantized_output.flatten()}\n")

# Verify with original Keras model for comparison
original_prediction_fi = model_to_convert.predict(test_images_tf[10:11], verbose=0)
original_predicted_class_fi = np.argmax(original_prediction_fi)
print(f"Predicted Class (Original Keras Model): {original_predicted_class_fi}")
```
Handling int8 inputs and outputs requires careful attention to the quantization parameters (scale and zero-point) determined by the converter. The exact scaling from float to int and vice versa depends on how the model was quantized and its original input/output ranges. This example provides a general approach, but real-world deployment on specific hardware might require more precise handling based on device documentation.
You’ve now learned how to tackle advanced challenges in TensorFlow, from scaling your training with distribution strategies to optimizing and deploying models on resource-constrained edge devices using TensorFlow Lite. These skills are invaluable for bringing your machine learning projects to production. Next, we’ll apply all your knowledge in comprehensive guided projects.