Models

Open in Colab

Note

If you’re in Colab or have a local CUDA GPU, you can follow along with the more computationally intensive training in this lesson.

For those in Colab, ensure the session is using a GPU by going to: Runtime > Change runtime type > Hardware accelerator = GPU.

# if you're using colab, then install the required modules
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install --quiet --upgrade keras-tuner pytorch-lightning "ray[tune]" lightning-bolts lightning-flash 'lightning-flash[image]' Pillow==9.0.0

Hyperparameter tuning

Tune the hyperparameters to find the best model.

TensorFlow (Keras)

KerasTuner is a library that helps you pick the best hyperparameters for your model.

There are also other options for hyperparameter tuning (see Other options in Further information below).

import os

import keras_tuner
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
print("Num GPUs Available: ", len(tf.config.list_physical_devices("GPU")))
Num GPUs Available:  0
(TensorFlow also prints warnings here about missing CUDA libraries; these can be ignored if no GPU is set up on your machine.)

Build the model with the hyperparameters argument.

This contains various objects, such as:

  • The number of units.

  • The activation function to use.

  • Whether to use dropout.

  • The optimiser learning rate.

def build_model(hyperparameters):
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Flatten())
    model.add(
        tf.keras.layers.Dense(
            units=hyperparameters.Int("units", min_value=32, max_value=512, step=32),
            activation=hyperparameters.Choice("activation", ["relu", "tanh"]),
        )
    )
    if hyperparameters.Boolean("dropout"):
        model.add(tf.keras.layers.Dropout(rate=0.25))
    model.add(tf.keras.layers.Dense(10, activation="softmax"))
    learning_rate = hyperparameters.Float(
        "lr", min_value=1e-4, max_value=1e-2, sampling="log"
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

First, prepare a directory to store all the hyperparameters:

hyperparameters_path = f"{os.getcwd()}/models/hyperparameters"
if not os.path.exists(hyperparameters_path):
    os.makedirs(hyperparameters_path)

Now, initialise the tuner.

This uses Random Search, though you could also use other methods.

tuner = keras_tuner.RandomSearch(
    hypermodel=build_model,  # the model building function
    objective="val_accuracy",  # the objective to optimise
    max_trials=3,  # the number of trials for the search
    executions_per_trial=2,  # the number of models to build and fit per trial
    overwrite=True,  # whether to overwrite previous results
    directory=hyperparameters_path,  # the path for the hyperparameter results
    project_name="hp_example",
)
2022-05-05 15:44:30.325918: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
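
Other tuners are available with a similar interface, e.g. Hyperband or Bayesian optimisation. A commented-out sketch (not run here; the project name is just an example):

# tuner = keras_tuner.Hyperband(
#     hypermodel=build_model,
#     objective="val_accuracy",
#     max_epochs=5,
#     overwrite=True,
#     directory=hyperparameters_path,
#     project_name="hp_example_hyperband",
# )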

View the search space:

tuner.search_space_summary()
Search space summary
Default search space size: 4
units (Int)
{'default': None, 'conditions': [], 'min_value': 32, 'max_value': 512, 'step': 32, 'sampling': None}
activation (Choice)
{'default': 'relu', 'conditions': [], 'values': ['relu', 'tanh'], 'ordered': False}
dropout (Boolean)
{'default': False, 'conditions': []}
lr (Float)
{'default': 0.0001, 'conditions': [], 'min_value': 0.0001, 'max_value': 0.01, 'step': None, 'sampling': 'log'}

Load in the MNIST dataset for this example:

(x, y), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x[:-10000]
x_val = x[-10000:]
y_train = y[:-10000]
y_val = y[-10000:]

x_train = np.expand_dims(x_train, -1).astype("float32") / 255.0
x_val = np.expand_dims(x_val, -1).astype("float32") / 255.0
x_test = np.expand_dims(x_test, -1).astype("float32") / 255.0

num_classes = 10
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_val = tf.keras.utils.to_categorical(y_val, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

Start the search:

if IN_COLAB:
    tuner.search(
        x_train, y_train, epochs=2, validation_data=(x_val, y_val), verbose=False
    )

View the results.

You can also visualise them using TensorBoard.

if IN_COLAB:
    tuner.results_summary()
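
For the TensorBoard view, one approach (a commented-out sketch; the log directory is an assumption) is to pass a TensorBoard callback to tuner.search, which forwards callbacks to model.fit, and then point TensorBoard at that directory:

# if IN_COLAB:
#     tuner.search(
#         x_train,
#         y_train,
#         epochs=2,
#         validation_data=(x_val, y_val),
#         callbacks=[tf.keras.callbacks.TensorBoard(log_dir=f"{hyperparameters_path}/logs")],
#     )
#     # then: %tensorboard --logdir {hyperparameters_path}/logs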

Select the best model:

if IN_COLAB:
    best_model = tuner.get_best_models()[0]

Show the best model’s summary.

Note that the Sequential model needs to be built first with an input shape.

if IN_COLAB:
    best_model.build(input_shape=(None, 28, 28))
    best_model.summary()
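
You can also retrieve the best hyperparameter values themselves, e.g. to rebuild and retrain a model from scratch (a minimal sketch using the tuner's get_best_hyperparameters):

if IN_COLAB:
    best_hyperparameters = tuner.get_best_hyperparameters(num_trials=1)[0]
    print(best_hyperparameters.values)  # the chosen units, activation, dropout, and lr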

PyTorch (Lightning)

Ray Tune is a tool for distributed hyperparameter tuning.

import math
import os

import pytorch_lightning as pl
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from ray import tune
from ray.tune import CLIReporter
from ray.tune.integration.pytorch_lightning import (
    TuneReportCallback,
    TuneReportCheckpointCallback,
)
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
PATH_DATASET = f"{os.getcwd()}/data"
AVAIL_GPUS = min(1, torch.cuda.device_count())
NUM_EPOCHS = 1

Create the model:

class MNISTModel(LightningModule):
    def __init__(self, config, data_dir=PATH_DATASET):
        super(MNISTModel, self).__init__()

        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),  # specific to MNIST
            ]
        )

        # setup the hyperparameters within the config dictionary
        self.layer_1_size = config["layer_1_size"]
        self.layer_2_size = config["layer_2_size"]
        self.learning_rate = config["learning_rate"]
        self.batch_size = config["batch_size"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)
        self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)
        self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = torch.relu(x)

        x = self.layer_2(x)
        x = torch.relu(x)

        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)

        return x

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def accuracy(self, logits, labels):
        _, predicted = torch.max(logits.data, 1)
        correct = (predicted == labels).sum().item()
        accuracy = correct / len(labels)
        return torch.tensor(accuracy)

    def training_step(self, train_batch, batch_index):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", accuracy)
        return loss

    def validation_step(self, val_batch, batch_index):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": accuracy}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)

    def prepare_data(self):  # download the data, once if distributed
        MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            ds_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.ds_train, self.ds_val = random_split(ds_full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.ds_train, batch_size=int(self.batch_size))

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=int(self.batch_size))

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

First, set up and train the model without tuning:

def mnist_without_tuning(config):
    model = MNISTModel(config)

    trainer = Trainer(max_epochs=NUM_EPOCHS, callbacks=TQDMProgressBar(refresh_rate=20))

    trainer.fit(model)

    return trainer
def run_mnist_without_tuning():
    # here, the config has one value for each hyperparameter only
    config = {
        "layer_1_size": 128,
        "layer_2_size": 256,
        "learning_rate": 1e-3,
        "batch_size": 64,
    }
    trainer = mnist_without_tuning(config)
    return trainer
if IN_COLAB:
    trainer = run_mnist_without_tuning()
if IN_COLAB:
    trainer.validate()

Now, set up and train the model with tuning.

This includes a tune reporting callback (more on callbacks later), which reports the tuning results back to Ray Tune.

tune_reporting_callback = TuneReportCallback(
    {"loss": "ptl/val_loss", "mean_accuracy": "ptl/val_accuracy"}, on="validation_end"
)
def mnist_tuning(
    config, num_epochs=NUM_EPOCHS, num_gpus=AVAIL_GPUS, data_dir=PATH_DATASET
):
    model = MNISTModel(config, data_dir)

    trainer = Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        logger=TensorBoardLogger(save_dir=tune.get_trial_dir(), name="", version="."),
        progress_bar_refresh_rate=0,
        callbacks=[tune_reporting_callback],
    )

    trainer.fit(model)
def run_mnist_tuning(
    num_samples=10,
    num_epochs=NUM_EPOCHS,
    gpus_per_trial=AVAIL_GPUS,
    data_dir=PATH_DATASET,
):
    # setup the hyperparameter space to explore
    config = {
        "layer_1_size": tune.choice([64, 128]),
        "layer_2_size": tune.choice([128, 256]),
        "learning_rate": tune.loguniform(1e-3, 1e-2),
        "batch_size": tune.choice([32, 64]),
    }

    # define the scheduler
    # the Asynchronous Hyperband scheduler stops poor trials
    # https://blog.ml.cmu.edu/2018/12/12/massively-parallel-hyperparameter-optimization/
    scheduler = ASHAScheduler(max_t=NUM_EPOCHS, grace_period=1, reduction_factor=2)

    # setup the reporting metrics
    reporter = CLIReporter(
        parameter_columns=[
            "layer_1_size",
            "layer_2_size",
            "learning_rate",
            "batch_size",
        ],
        metric_columns=["loss", "mean_accuracy", "training_iteration"],
    )

    # pass the constants to the train function
    train_function_with_parameters = tune.with_parameters(
        mnist_tuning,
        num_epochs=NUM_EPOCHS,
        num_gpus=AVAIL_GPUS,
        data_dir=PATH_DATASET,
    )
    resources_per_trial = {"cpu": 1, "gpu": AVAIL_GPUS}

    # run the tuning
    tuning_results = tune.run(
        train_function_with_parameters,
        resources_per_trial=resources_per_trial,
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        name="run_mnist_tuning",
    )

    return tuning_results
if IN_COLAB:
    tuning_results = run_mnist_tuning()

And now we can see the best hyperparameters:

if IN_COLAB:
    print(tuning_results.best_config)
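
A minimal sketch of using that result is to take the best configuration and train a final model with it, re-using the mnist_without_tuning helper defined above:

if IN_COLAB:
    # train a final model with the best hyperparameter configuration
    best_config = tuning_results.best_config
    final_trainer = mnist_without_tuning(best_config)
    final_trainer.validate()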

Transfer learning

Transfer learning is where a model that has been pre-trained on a problem is transferred to another similar problem.

For example, if a model learnt to classify images it can be transferred to other image classification problems.

This works because the pre-trained model has extracted features from the similar task, for example:

  • Lower-level: lines, curves, etc.

  • Higher-level: eyes, ears, etc.

Commonly, the steps of transfer learning are:

  • Instantiate a pre-trained model and load pre-trained weights into it.

  • Extract layers and freeze them to retain their information (e.g., for lower level features).

  • Add new trainable layers on top of these frozen layers.

  • Train the new layers on your dataset.

  • Optional: Unfreeze the frozen layers and train these on the dataset with a very low learning rate (i.e., fine tuning).

    • This is done when the training dataset is larger and very similar to the original pre-trained dataset.

Transfer learning is useful when:

  • You have a small dataset.

  • You want to take advantage of huge models without the costs of training them.

Example: Transfer learning

TensorFlow (Keras)

There is a range of pre-trained models available in Keras Applications.

Download and split the data:

if IN_COLAB:
    # tfds.disable_progress_bar()

    ds_train, ds_val, ds_test = tfds.load(
        "cats_vs_dogs",
        # Reserve 10% for validation and 10% for test
        split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
        as_supervised=True,  # Include labels
    )

    print("Number of training samples: %d" % tf.data.experimental.cardinality(ds_train))
    print("Number of validation samples: %d" % tf.data.experimental.cardinality(ds_val))
    print("Number of test samples: %d" % tf.data.experimental.cardinality(ds_test))

View a few of the training samples:

if IN_COLAB:
    labels = {0: "Cat", 1: "Dog"}

    plt.figure(figsize=(10, 10))
    for i, (image, label) in enumerate(ds_train.take(9)):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(image)
        plt.title(labels[int(label)])
        plt.axis("off")

First, preprocess the data.

This standardises the image sizes (the pixel values are rescaled later, inside the model).

if IN_COLAB:
    IMAGE_SIZE = (150, 150)

    ds_train = ds_train.map(lambda x, y: (tf.image.resize(x, IMAGE_SIZE), y))
    ds_val = ds_val.map(lambda x, y: (tf.image.resize(x, IMAGE_SIZE), y))
    ds_test = ds_test.map(lambda x, y: (tf.image.resize(x, IMAGE_SIZE), y))

Now, we can cache, batch, and prefetch the data:

if IN_COLAB:
    BATCH_SIZE = 32
    AUTOTUNE = tf.data.AUTOTUNE

    ds_train = ds_train.cache().batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
    ds_val = ds_val.cache().batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
    ds_test = ds_test.cache().batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)

Apply some data augmentation:

data_augmentation = tf.keras.Sequential(
    [
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(0.1),
    ]
)

Now, build the model.

Steps:

  • First, load the pre-trained model as the base.

    • This example uses the Xception model (other options are here).

    • The weights were pre-trained on ImageNet (a large image dataset).

    • The top layer from the Xception classifier is not included (i.e., the one that predicts the ImageNet categories).

  • Freeze the layers from the base model.

  • Add new layers on top.

Note

Batch Normalisation layers should be kept in inference mode (i.e., the base model should still be called with training=False) when the base model is unfrozen for fine-tuning.

More information here.

base_model = tf.keras.applications.Xception(
    weights="imagenet",  # use weights from ImageNet
    input_shape=(150, 150, 3),
    include_top=False,  # don't include Xception's ImageNet classifier
)
base_model.trainable = False  # freeze the pre-trained model

inputs = tf.keras.Input(shape=(150, 150, 3))  # create new model on top
x = data_augmentation(inputs)  # apply data augmentation

scale_layer = tf.keras.layers.Rescaling(
    scale=1.0 / 127.5, offset=-1
)  # rescale from [0, 255] to [-1.0, 1.0]
x = scale_layer(x)

x = base_model(
    x, training=False
)  # keep BatchNorm layers in inference mode (see note above)
x = tf.keras.layers.GlobalAveragePooling2D()(x)  # convert 2D locations to vector
x = tf.keras.layers.Dropout(0.2)(x)  # for regularisation

outputs = tf.keras.layers.Dense(1)(x)  # prediction layer

model = tf.keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83697664/83683744 [==============================] - 2s 0us/step
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 150, 150, 3)]     0         
                                                                 
 sequential_1 (Sequential)   (None, 150, 150, 3)       0         
                                                                 
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
                                                                 
 xception (Functional)       (None, 5, 5, 2048)        20861480  
                                                                 
 global_average_pooling2d (G  (None, 2048)             0         
 lobalAveragePooling2D)                                          
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense_2 (Dense)             (None, 1)                 2049      
                                                                 
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

Only about 2,000 of the roughly 20.9 million parameters (around 0.01%) are trainable; the rest come from the frozen pre-trained model.

That will save a lot of computation.

Now, train the top layers only:

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.BinaryAccuracy(name="accuracy")],
)
if IN_COLAB:
    # just 1 epoch for the demonstration; you would normally increase this to, say, 20
    NUM_EPOCHS = 1
    model.fit(ds_train, epochs=NUM_EPOCHS, validation_data=ds_val)

Now you could perform fine tuning.

This is where you unfreeze the frozen layers and train these on the dataset with a very low learning rate.
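
A sketch of that fine-tuning step (following the usual Keras transfer-learning pattern, not run here by default): unfreeze the base model, recompile with a much lower learning rate, and train for a few more epochs. Because the base model was called with training=False when the model was built, its Batch Normalisation layers stay in inference mode.

if IN_COLAB:
    base_model.trainable = True  # unfreeze the pre-trained layers

    # recompile with a very low learning rate so the pre-trained weights
    # are only adjusted slightly
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.BinaryAccuracy(name="accuracy")],
    )

    # again, just 1 epoch for the demonstration
    model.fit(ds_train, epochs=1, validation_data=ds_val)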

TensorFlow Hub

There are many pre-trained models on TensorFlow Hub for use in transfer learning.

Download the data:

data_root = tf.keras.utils.get_file(
    "flower_photos",
    "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
    untar=True,
)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228827136/228813984 [==============================] - 2s 0us/step
BATCH_SIZE = 32
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224

ds_train = tf.keras.utils.image_dataset_from_directory(
    str(data_root),
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
    batch_size=BATCH_SIZE,
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    str(data_root),
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
    batch_size=BATCH_SIZE,
)
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
class_names = np.array(ds_train.class_names)
num_classes = len(class_names)
class_names
array(['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'],
      dtype='<U10')

Pre-process the data:

normalization_layer = tf.keras.layers.Rescaling(1.0 / 255)
ds_train = ds_train.map(lambda x, y: (normalization_layer(x), y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))

Cache and prefetch the data:

AUTOTUNE = tf.data.AUTOTUNE
ds_train = ds_train.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

Download the headless pre-trained model (i.e., without the top layer):

mobilenet_v2 = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
inception_v3 = "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4"

feature_extractor_model = (
    mobilenet_v2  # could change to inception_v3, or choose a different one
)

The download location defaults to a local temporary directory.

To change this, set the following:

import os
os.environ["TFHUB_CACHE_DIR"] = "/nobackup/username/tf_hub_modules"

Create the headless pre-trained layer by setting trainable to False:

import tensorflow_hub as hub
feature_extractor_layer = hub.KerasLayer(
    feature_extractor_model,
    input_shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 3),
    trainable=False,  # remember to set this to False
)

Create the new model with the pre-trained layer as the base and a new classification layer on top:

model = tf.keras.Sequential(
    [feature_extractor_layer, tf.keras.layers.Dense(num_classes)]
)

model.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 keras_layer (KerasLayer)    (None, 1280)              2257984   
                                                                 
 dense_3 (Dense)             (None, 5)                 6405      
                                                                 
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________

Compile the new model:

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

Train the new model:

if IN_COLAB:
    NUM_EPOCHS = 3
    history = model.fit(ds_train, validation_data=val_ds, epochs=NUM_EPOCHS)

PyTorch (Lightning)

PyTorch Lightning can use a pre-trained model as a backbone in the Model definition.

import torch
import torchvision.models as models
from pytorch_lightning import LightningModule
from torch import nn
class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()

        # init a pretrained resnet
        backbone = models.resnet50(pretrained=True)

        # extract some setup details
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        return x

    # ... the rest of your LightningModule (training_step, configure_optimizers, etc.) ...

PyTorch Lightning Bolts

PyTorch Lightning Bolts has many pre-trained models ready for you to import.

from pl_bolts.models.autoencoders import VAE
model = VAE(input_height=32, pretrained="imagenet2012")
encoder = model.encoder
encoder.eval()
ResNetEncoder(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): EncoderBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): EncoderBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): EncoderBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): EncoderBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): EncoderBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): EncoderBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): EncoderBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): EncoderBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
)

PyTorch Lightning-Flash

Recall that PyTorch Lightning-Flash is an even higher-level API for PyTorch that builds abstractions on top of PyTorch Lightning for fast prototyping.

import os

import flash
import torch
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
from pytorch_lightning import seed_everything
seed_everything(42)
Global seed set to 42
42

Download and organize the data:

data_path = f"{os.getcwd()}/data"
download_data(
    "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", f"{data_path}/"
)

datamodule = ImageClassificationData.from_folders(
    train_folder=f"{data_path}/hymenoptera_data/train/",
    val_folder=f"{data_path}/hymenoptera_data/val/",
    test_folder=f"{data_path}/hymenoptera_data/test/",
    batch_size=4,
    transform_kwargs={
        "image_size": (196, 196),
        "mean": (0.485, 0.456, 0.406),
        "std": (0.229, 0.224, 0.225),
    },
)

Build the model using the desired pre-trained backbone (e.g., ResNet18):

model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/runner/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth

Create the trainer (run one epoch for demo):

trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

Finetune the model.

strategy="freeze" keeps the pre-trained model (backbone) frozen throughout.

if IN_COLAB:
    trainer.finetune(
        model,
        datamodule=datamodule,
        strategy="freeze",  # remember to have the strategy set to "freeze"
    )
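
Flash also supports other finetuning strategies, for example unfreezing the backbone after a given number of epochs (a commented-out sketch):

# if IN_COLAB:
#     trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 1))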

Save the model:

model_path = f"{os.getcwd()}/models"
if IN_COLAB:
    trainer.save_checkpoint(f"{model_path}/image_classification_model.pt")

Callbacks

Callbacks are objects that get called by the model at different points during training, often after each batch or epoch.

For example, they could be used to:

  • Save a model version at regular intervals or once a metric threshold is attained (i.e., checkpointing).

  • Monitor and profile the training progress (e.g., with TensorBoard).

  • Change the learning rate when training plateaus.

  • Fine-tune when training plateaus.
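
As a minimal sketch of the idea, a custom Keras callback just overrides the hooks it needs; the class name here is illustrative:

class PrintLossCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # called by model.fit at the end of every epoch
        logs = logs or {}
        print(f"epoch {epoch}: loss={logs.get('loss')}, val_loss={logs.get('val_loss')}")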

Checkpoints

You can save a model (or just its weights) at custom checkpoints, e.g., the best model so far in terms of validation accuracy at the end of each epoch.

The save_best_only option is useful: after training, only the best model will have been saved.

TensorFlow (Keras)

path_models = f"{os.getcwd()}/models"
callback_tf_model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=path_models + "/model_{epoch}",
    save_freq="epoch",  # save a model version at the end of each epoch
    save_best_only=True,  # only save a model if val_accuracy improved
    monitor="val_accuracy",
    # save_weights_only=True
)

PyTorch (Lightning)

Checkpointing is enabled by default in PyTorch Lightning.

Model checkpoints are saved per epoch to lightning_logs/version_X/checkpoints.
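
To customise this, you can add a ModelCheckpoint callback explicitly; a sketch, where the directory and monitored metric are assumptions based on the model defined earlier:

callback_torch_model_checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath=f"{os.getcwd()}/models/checkpoints",
    monitor="ptl/val_accuracy",  # keep only the best model by validation accuracy
    mode="max",
    save_top_k=1,
)
# Trainer(callbacks=[callback_torch_model_checkpoint])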

Fault tolerance

For longer or distributed training, it’s helpful to save the model at regular intervals in case it crashes during training.

TensorFlow (Keras)

This is done via BackupAndRestore (previously this was done using ModelCheckpoint).

callback_tf_model_backup = tf.keras.callbacks.BackupAndRestore(
    backup_dir=f"{path_models}/backup"
)

PyTorch (Lightning)

By default, if Trainer.fit() fails, it can be restarted automatically from the beginning of the epoch it failed on.

With Fault Tolerant Training, when Trainer.fit() fails in the middle of an epoch during training or validation, Lightning will restart exactly where it failed, and everything will be restored.

This is enabled via:

PL_FAULT_TOLERANT_TRAINING=1 python script.py
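
In a notebook, a sketch of the same thing is to set the environment variable before the Trainer is created:

# import os
# os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"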

Logging and Profiling

TensorBoard is a browser-based application that provides live plots of loss and metrics during training and evaluation.

TensorFlow (Keras)

callback_tf_tensorboard_with_profiling = tf.keras.callbacks.TensorBoard(
    log_dir=f"{path_models}/logs",
    profile_batch=(1, 5),  # profile batches 1 to 5
    update_freq="epoch",
)
2022-05-05 15:44:46.215027: I tensorflow/core/profiler/lib/profiler_session.cc:110] Profiler session initializing.
2022-05-05 15:44:46.215057: I tensorflow/core/profiler/lib/profiler_session.cc:125] Profiler session started.
2022-05-05 15:44:46.215858: I tensorflow/core/profiler/lib/profiler_session.cc:143] Profiler session tear down.
# may need to remove previous logs
# import shutil
# shutil.rmtree(f"{path_models}/logs")

View them with (we’ll see an example in a bit):

tensorboard --logdir=/full_path_to_your_logs

Also, in-line in Jupyter Notebooks / Google Colab:

from tensorboard import notebook
notebook.list() # list open tensorboard instances
notebook.display(port=6006, height=1000) # open tensorboard

# or

%load_ext tensorboard
%tensorboard --logdir logs

PyTorch (Lightning)

Logging

PyTorch Lightning uses TensorBoard for logging by default.

As before, you can view the logs using:

%load_ext tensorboard
%tensorboard --logdir lightning_logs/
Profiling

You can profile (time and memory) with PyTorch using the torch.autograd.profiler context manager:

import torch
import torch.autograd.profiler as profiler
import torchvision.models as models
model = models.resnet18()
inputs = torch.randn(5, 3, 244, 244)
with profiler.profile() as prof:
    with profiler.record_function("model_inference"):
        model(inputs)

This can then help you find bottlenecks in the code:

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  model_inference         2.31%       4.710ms        99.80%     203.840ms     203.840ms             1  
                     aten::conv2d         0.05%     108.000us        64.50%     131.735ms       6.587ms            20  
                aten::convolution         0.18%     366.000us        64.44%     131.627ms       6.581ms            20  
               aten::_convolution         0.34%     703.000us        64.26%     131.261ms       6.563ms            20  
         aten::mkldnn_convolution        63.56%     129.814ms        63.92%     130.558ms       6.528ms            20  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 204.254ms

In PyTorch Lightning, for a simple profiler over method calls and time spent calling them, use:

Trainer(profiler=True)

For more control, use:

from pytorch_lightning.profiler import AdvancedProfiler

Trainer(profiler=AdvancedProfiler())

Early stopping

Stop training early when a monitored metric has stopped improving.

TensorFlow (Keras)

callback_tf_early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_accuracy",  # quantity to be monitored
    patience=2,  # "no longer improving" also means "for at least 2 epochs"
)

PyTorch (Lightning)

callback_torch_early_stopping = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=2,
)

Learning rate decay

Reduce the learning rate when a metric has stopped improving (i.e., when it has reached a plateau).

TensorFlow (Keras)

callback_tf_learning_rate_decay = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_accuracy",
    patience=5,
    factor=0.2,  # factor by which the learning rate will be reduced: new_lr = lr * factor
    min_lr=0.001,  # lower bound on the learning rate
)
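
PyTorch (Lightning)

In PyTorch Lightning, the usual equivalent (a sketch; the monitored metric name is an assumption based on the model defined earlier) is to return a ReduceLROnPlateau scheduler from configure_optimizers in your LightningModule:

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    # reduce the learning rate by a factor of 0.2 after 5 epochs without improvement
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.2, patience=5
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "monitor": "ptl/val_loss"},
    }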

Example: Callbacks

TensorFlow (Keras)

callbacks_tf = [
    callback_tf_model_checkpoint,
    callback_tf_model_backup,
    callback_tf_tensorboard_with_profiling,
    callback_tf_early_stopping,
    callback_tf_learning_rate_decay,
]
if IN_COLAB:
    epochs = 3
    model.fit(ds_train, epochs=epochs, validation_data=ds_val, callbacks=callbacks_tf)

View the results in TensorBoard:

# from tensorboard import notebook
# notebook.list()
# notebook.display(port=6006, height=1000)
if IN_COLAB:
    %load_ext tensorboard
    %tensorboard --logdir /content/models/logs

PyTorch (Lightning)

# callbacks_torch = [
#     callback_torch_early_stopping,
# ]
# Trainer(callbacks=callbacks_torch)

Questions

Question 1

What are possible hyperparameters that could be tuned?

  • Learning rate and the number of units.

  • Weights and biases.

Question 2

What is transfer learning?

Question 3

Why is transfer learning useful?

Question 4

What is a key step in transfer learning?

Question 5

What are callbacks?

Question 6

Name three examples of callbacks.

Key Points

Important

  • Tune hyperparameters for the best model fit.

  • Use transfer learning to save computation on similar problems.

  • Consider using callbacks to help with model training, such as:

    • Checkpoints.

    • Fault tolerance.

    • Logging.

    • Profiling.

    • Early stopping.

    • Learning rate decay.

Further information

Good practices

  • See if there is a model architecture (and parameters) that already addresses the task.

  • Consider the trade-off between model complexity and size.

    • For high accuracy, you may need a large and complex model.

    • If less precision is acceptable, smaller models use less disk space and memory, and run faster.

  • Best practices for PyTorch Lightning model training.

  • Consider sharing your model for reproducibility if you can.

Other options

There are many other options for hyperparameter tuning, including:

  • Hyperopt

    • A general purpose optimisation library.

  • scikit-optimize

    • A general purpose optimisation library.

  • TPOT

    • Automated optimisation library using genetic programming.