# Models

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ARCTraining/swd8_intro_ml/blob/main/docs/04_models.ipynb)

```{note}
If you're in Colab or have a local CUDA GPU, you can follow along with the more computationally intensive training in this lesson.

For those in Colab, ensure the session is using a GPU by going to: Runtime > Change runtime type > Hardware accelerator = GPU.
```

In [3]:
# if you're using colab, then install the required modules
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install --quiet --upgrade keras-tuner pytorch-lightning "ray[tune]" lightning-bolts lightning-flash 'lightning-flash[image]' Pillow==9.0.0

## Hyperparameter tuning

Tune the [hyperparameters](hyperparameters) to find the best model.

### TensorFlow (Keras)

[KerasTuner](https://keras.io/guides/keras_tuner/getting_started/) is a library that helps you pick the best hyperparameters for your model.

#### Other options

- [TensorFlow Model Optimisation Toolkit](https://www.tensorflow.org/model_optimization)
    - A suite of tools for optimising machine learning models.
- [AutoKeras](https://autokeras.com/)
    - An AutoML system to automate the building of the machine learning model.

In [None]:
import os

import keras_tuner
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

In [None]:
print("Num GPUs Available: ", len(tf.config.list_physical_devices("GPU")))

Define a model-building function that takes a `hyperparameters` argument.

This argument is used to define the search space, for example:

- The number of units.
- The activation function to use.
- Whether to use dropout.
- The optimiser learning rate.

In [None]:
def build_model(hyperparameters):
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Flatten())
    model.add(
        tf.keras.layers.Dense(
            units=hyperparameters.Int("units", min_value=32, max_value=512, step=32),
            activation=hyperparameters.Choice("activation", ["relu", "tanh"]),
        )
    )
    if hyperparameters.Boolean("dropout"):
        model.add(tf.keras.layers.Dropout(rate=0.25))
    model.add(tf.keras.layers.Dense(10, activation="softmax"))
    learning_rate = hyperparameters.Float(
        "lr", min_value=1e-4, max_value=1e-2, sampling="log"
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

First, prepare a directory to store all the hyperparameters:

In [None]:
hyperparameters_path = f"{os.getcwd()}/models/hyperparameters"
if not os.path.exists(hyperparameters_path):
    os.makedirs(hyperparameters_path)

Now, initialise the tuner.

This uses Random Search, though you could also use [other methods](https://www.tensorflow.org/tutorials/keras/keras_tuner#instantiate_the_tuner_and_perform_hypertuning).

In [None]:
tuner = keras_tuner.RandomSearch(
    hypermodel=build_model,  # the model building function
    objective="val_accuracy",  # the objective to optimise
    max_trials=3,  # the number of trials for the search
    executions_per_trial=2,  # the number of models to build and fit per trial
    overwrite=True,  # whether to overwrite previous results
    directory=hyperparameters_path,  # the path for the hyperparameter results
    project_name="hp_example",
)
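
To make the idea concrete, random search can be sketched without any library: sample each hyperparameter independently, score the configuration, and keep the best one. The sketch below is only an illustration, not KerasTuner's implementation. The names `sample_config`, `toy_score`, and `random_search` are hypothetical, and the toy score stands in for the validation accuracy that a real trial would measure.

```python
import math
import random


def sample_config(rng):
    """Draw one configuration from a search space like the one in `build_model`."""
    return {
        "units": rng.randrange(32, 513, 32),  # 32, 64, ..., 512
        "activation": rng.choice(["relu", "tanh"]),
        "dropout": rng.choice([True, False]),
        # log-uniform sampling: draw the exponent uniformly, then exponentiate
        "lr": 10 ** rng.uniform(-4, -2),
    }


def toy_score(config):
    """Stand-in for validation accuracy (assumption: peaks at lr=1e-3, units=256)."""
    lr_penalty = abs(math.log10(config["lr"]) + 3)
    units_penalty = abs(config["units"] - 256) / 512
    return 1.0 - 0.1 * lr_penalty - 0.1 * units_penalty


def random_search(max_trials, seed=0):
    """Sample `max_trials` configurations and return the best-scoring one."""
    rng = random.Random(seed)
    trials = [sample_config(rng) for _ in range(max_trials)]
    return max(trials, key=toy_score)


best_config = random_search(max_trials=20)
print(best_config)
```

Sampling the learning rate on a log scale spreads the trials evenly across orders of magnitude, which is what `sampling="log"` requests in `build_model`.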

View the search space:

In [None]:
tuner.search_space_summary()

Load in the MNIST dataset for this example:

In [None]:
(x, y), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x[:-10000]
x_val = x[-10000:]
y_train = y[:-10000]
y_val = y[-10000:]

x_train = np.expand_dims(x_train, -1).astype("float32") / 255.0
x_val = np.expand_dims(x_val, -1).astype("float32") / 255.0
x_test = np.expand_dims(x_test, -1).astype("float32") / 255.0

num_classes = 10
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_val = tf.keras.utils.to_categorical(y_val, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)
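
The two preprocessing steps above (scaling pixels to [0, 1] and one-hot encoding the labels) can be illustrated in plain Python. This is a sketch of what the steps produce, not how NumPy or `to_categorical` implement them. The helper names `scale_pixel` and `one_hot` are hypothetical.

```python
def scale_pixel(value):
    """Map an 8-bit pixel value in [0, 255] to a float in [0.0, 1.0]."""
    return value / 255.0


def one_hot(label, num_classes):
    """Return a list with 1.0 at index `label` and 0.0 elsewhere."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} is outside [0, {num_classes})")
    return [1.0 if index == label else 0.0 for index in range(num_classes)]


print(scale_pixel(255))
print(one_hot(3, 10))
```

One-hot labels are the format that the `categorical_crossentropy` loss used in `build_model` expects.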

Start the search:

In [None]:
if IN_COLAB:
    tuner.search(
        x_train, y_train, epochs=2, validation_data=(x_val, y_val), verbose=False
    )

View the results.

You can also visualise them using [TensorBoard](https://keras.io/guides/keras_tuner/visualize_tuning/).

In [None]:
if IN_COLAB:
    tuner.results_summary()

Select the best model:

In [None]:
if IN_COLAB:
    best_model = tuner.get_best_models()[0]

Show the best model's summary.

Note that the `Sequential` model needs to be built first with an input shape.

In [None]:
if IN_COLAB:
    best_model.build(input_shape=(None, 28, 28))
    best_model.summary()

### [PyTorch (Lightning)](https://docs.ray.io/en/latest/tune/tutorials/tune-pytorch-lightning.html)

[Ray Tune](https://docs.ray.io/en/latest/tune.html) is a tool for distributed hyperparameter tuning.

In [None]:
import math
import os

import pytorch_lightning as pl
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from ray import tune
from ray.tune import CLIReporter
from ray.tune.integration.pytorch_lightning import (
    TuneReportCallback,
    TuneReportCheckpointCallback,
)
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

In [None]:
PATH_DATASET = f"{os.getcwd()}/data"
AVAIL_GPUS = min(1, torch.cuda.device_count())
NUM_EPOCHS = 1

Create the model:

In [None]:
class MNISTModel(LightningModule):
    def __init__(self, config, data_dir=PATH_DATASET):
        super(MNISTModel, self).__init__()

        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),  # specific to MNIST
            ]
        )

        # setup the hyperparameters within the config dictionary
        self.layer_1_size = config["layer_1_size"]
        self.layer_2_size = config["layer_2_size"]
        self.learning_rate = config["learning_rate"]
        self.batch_size = config["batch_size"]

        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)
        self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)
        self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = torch.relu(x)

        x = self.layer_2(x)
        x = torch.relu(x)

        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)

        return x

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def accuracy(self, logits, labels):
        _, predicted = torch.max(logits.data, 1)
        correct = (predicted == labels).sum().item()
        accuracy = correct / len(labels)
        return torch.tensor(accuracy)

    def training_step(self, train_batch, batch_index):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", accuracy)
        return loss

    def validation_step(self, val_batch, batch_index):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": accuracy}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)

    def prepare_data(self):  # download the data, once if distributed
        MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            ds_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.ds_train, self.ds_val = random_split(ds_full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.ds_train, batch_size=int(self.batch_size))

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=int(self.batch_size))

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
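
In the model above, `forward` returns log-probabilities (via `log_softmax`), and `F.nll_loss` then averages the negative log-probability of the correct class. Together these amount to the cross-entropy loss on the raw scores. A minimal pure-Python sketch of that arithmetic, for illustration only (the helper names are hypothetical and this is not PyTorch's implementation):

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of a list of raw scores."""
    highest = max(scores)
    log_sum = highest + math.log(sum(math.exp(s - highest) for s in scores))
    return [s - log_sum for s in scores]


def nll_loss(log_probs_batch, labels):
    """Mean negative log-probability assigned to the correct class."""
    losses = [-log_probs[label] for log_probs, label in zip(log_probs_batch, labels)]
    return sum(losses) / len(losses)


batch_scores = [[2.0, 0.5, 0.1], [0.2, 0.2, 0.2]]  # made-up raw scores
labels = [0, 2]
log_probs_batch = [log_softmax(scores) for scores in batch_scores]
loss = nll_loss(log_probs_batch, labels)
print(round(loss, 4))
```

For the second example every class gets the same score, so its loss is `log(3)`: the model is no better than guessing among three classes.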

First, set up and train the model _without_ tuning:

In [None]:
def mnist_without_tuning(config):
    model = MNISTModel(config)

    trainer = Trainer(max_epochs=NUM_EPOCHS, callbacks=TQDMProgressBar(refresh_rate=20))

    trainer.fit(model)

    return trainer

In [None]:
def run_mnist_without_tuning():
    # here, the config has one value for each hyperparameter only
    config = {
        "layer_1_size": 128,
        "layer_2_size": 256,
        "learning_rate": 1e-3,
        "batch_size": 64,
    }
    trainer = mnist_without_tuning(config)
    return trainer

In [None]:
if IN_COLAB:
    trainer = run_mnist_without_tuning()

In [None]:
if IN_COLAB:
    trainer.validate()

Now, set up and train the model _with_ tuning.

This includes a Tune reporting callback (more on callbacks later), which reports the validation metrics back to Ray Tune.

In [None]:
tune_reporting_callback = TuneReportCallback(
    {"loss": "ptl/val_loss", "mean_accuracy": "ptl/val_accuracy"}, on="validation_end"
)

In [None]:
def mnist_tuning(
    config, num_epochs=NUM_EPOCHS, num_gpus=AVAIL_GPUS, data_dir=PATH_DATASET
):
    model = MNISTModel(config, data_dir)

    trainer = Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        logger=TensorBoardLogger(save_dir=tune.get_trial_dir(), name="", version="."),
        progress_bar_refresh_rate=0,
        callbacks=[tune_reporting_callback],
    )

    trainer.fit(model)

In [None]:
def run_mnist_tuning(
    num_samples=10,
    num_epochs=NUM_EPOCHS,
    gpus_per_trial=AVAIL_GPUS,
    data_dir=PATH_DATASET,
):
    # setup the hyperparameter space to explore
    config = {
        "layer_1_size": tune.choice([64, 128]),
        "layer_2_size": tune.choice([128, 256]),
        "learning_rate": tune.loguniform(1e-3, 1e-2),
        "batch_size": tune.choice([32, 64]),
    }

    # define the scheduler
    # the Asynchronous Hyperband scheduler stops poor trials
    # https://blog.ml.cmu.edu/2018/12/12/massively-parallel-hyperparameter-optimization/
    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

    # setup the reporting metrics
    reporter = CLIReporter(
        parameter_columns=[
            "layer_1_size",
            "layer_2_size",
            "learning_rate",
            "batch_size",
        ],
        metric_columns=["loss", "mean_accuracy", "training_iteration"],
    )

    # pass the constants to the train function
    train_function_with_parameters = tune.with_parameters(
        mnist_tuning,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
        data_dir=data_dir,
    )
    resources_per_trial = {"cpu": 1, "gpu": gpus_per_trial}

    # run the tuning
    tuning_results = tune.run(
        train_function_with_parameters,
        resources_per_trial=resources_per_trial,
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        name="run_mnist_tuning",
    )

    return tuning_results
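
The early-stopping idea behind the scheduler can be sketched as synchronous successive halving: train every configuration for a small budget, keep the best fraction, and repeat with a larger budget. This is a simplified illustration under stated assumptions, not Ray Tune's asynchronous implementation. The names `successive_halving` and `toy_loss` are hypothetical, and the toy loss replaces real training.

```python
def toy_loss(learning_rate, epochs):
    """Stand-in for validation loss after `epochs` (assumption: best near 3e-3)."""
    return abs(learning_rate - 3e-3) * 100 + 1.0 / epochs


def successive_halving(configs, max_epochs, grace_period=1, reduction_factor=2):
    """Keep the top 1/reduction_factor of the configurations at each rung."""
    survivors = list(configs)
    epochs = grace_period
    while len(survivors) > 1 and epochs <= max_epochs:
        ranked = sorted(survivors, key=lambda lr: toy_loss(lr, epochs))
        keep = max(1, len(ranked) // reduction_factor)
        survivors = ranked[:keep]  # poor trials are stopped here
        epochs *= reduction_factor  # survivors get a larger budget
    return survivors


candidates = [1e-3, 2e-3, 3e-3, 5e-3, 8e-3, 1e-2]
print(successive_halving(candidates, max_epochs=8))
```

Stopping poor trials early means most of the compute budget is spent on the promising configurations.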

In [None]:
if IN_COLAB:
    tuning_results = run_mnist_tuning()

And now we can see the best hyperparameters:

In [None]:
if IN_COLAB:
    print(tuning_results.best_config)

## [Transfer learning](https://youtu.be/yofjFQddwHE)

Transfer learning is where a model that has been pre-trained on a problem is transferred to another similar problem.  

[For example](https://keras.io/examples/vision/image_classification_efficientnet_fine_tuning/), if a model learnt to classify images it can be transferred to other image classification problems.

This works because the pre-trained model has _extracted features_ from the similar task, for example:

- Lower-level: lines, curves, etc.
- Higher-level: eyes, ears, etc.

Commonly, the steps of transfer learning are:

- Instantiate a pre-trained model and load pre-trained weights into it.
- Extract layers and freeze them to retain their information (e.g., for lower-level features).
- Add new trainable layers on top of these frozen layers.
- Train the new layers on your dataset.
- Optional: Unfreeze the frozen layers and train these on the dataset with a very low learning rate (i.e., [fine tuning](https://www.tensorflow.org/tutorials/images/transfer_learning#fine_tuning)).
    - This is done when the training dataset is large and very similar to the dataset the model was pre-trained on.
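
The freeze-then-train steps above can be illustrated with a toy model in plain Python: a "pre-trained" feature extractor whose weight is frozen, and a new trainable head fitted by gradient descent. Everything here (the weight value, the data, and the function names) is a hypothetical stand-in for a real network, intended only to show that frozen parameters receive no updates.

```python
def extract_features(x, frozen_weight):
    """'Pre-trained' base: its weight is frozen, so it is never updated."""
    return frozen_weight * x


def train_head(inputs, targets, frozen_weight, learning_rate=0.05, epochs=500):
    """Fit only the new head (scale, bias) on top of the frozen features."""
    scale, bias = 0.0, 0.0
    count = len(inputs)
    for _ in range(epochs):
        grad_scale, grad_bias = 0.0, 0.0
        for x, target in zip(inputs, targets):
            feature = extract_features(x, frozen_weight)
            error = (scale * feature + bias) - target
            grad_scale += 2 * error * feature / count
            grad_bias += 2 * error / count
        # gradient step on the head only; frozen_weight is left untouched
        scale -= learning_rate * grad_scale
        bias -= learning_rate * grad_bias
    return scale, bias


frozen_weight = 2.0  # assumption: learnt on an earlier, similar task
inputs = [0.0, 0.5, 1.0, 1.5]
targets = [1.0, 2.5, 4.0, 5.5]  # new task: target = 1.5 * feature + 1
scale, bias = train_head(inputs, targets, frozen_weight)
print(round(scale, 2), round(bias, 2))
```

Only two parameters are trained here, which mirrors why transfer learning is cheap: the expensive part of the model is reused as-is.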

Transfer learning is useful when:

- You have a small dataset.
- You want to take advantage of huge models without the costs of training them.

### Example: Transfer learning

#### [TensorFlow (Keras)](https://www.tensorflow.org/tutorials/images/transfer_learning)

There is a range of pre-trained models available in [Keras Applications](https://keras.io/api/applications/).

Download and split the data:

In [None]:
if IN_COLAB:
    # tfds.disable_progress_bar()

    ds_train, ds_val, ds_test = tfds.load(
        "cats_vs_dogs",
        # Reserve 10% for validation and 10% for test
        split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
        as_supervised=True,  # Include labels
    )

    print("Number of training samples: %d" % tf.data.experimental.cardinality(ds_train))
    print("Number of validation samples: %d" % tf.data.experimental.cardinality(ds_val))
    print("Number of test samples: %d" % tf.data.experimental.cardinality(ds_test))

View a few of the training samples:

In [None]:
if IN_COLAB:
    labels = {0: "Cat", 1: "Dog"}

    plt.figure(figsize=(10, 10))
    for i, (image, label) in enumerate(ds_train.take(9)):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(image)
        plt.title(labels[int(label)])
        plt.axis("off")

First, preprocess the data.

This standardises the image sizes; the pixel values are normalised later by a `Rescaling` layer inside the model.

In [None]:
if IN_COLAB:
    IMAGE_SIZE = (150, 150)

    ds_train = ds_train.map(lambda x, y: (tf.image.resize(x, IMAGE_SIZE), y))
    ds_val = ds_val.map(lambda x, y: (tf.image.resize(x, IMAGE_SIZE), y))
    ds_test = ds_test.map(lambda x, y: (tf.image.resize(x, IMAGE_SIZE), y))

Now, we can [cache](cache_tf), [batch](batch_tf), and [prefetch](prefetch_tf) the data:

In [None]:
if IN_COLAB:
    BATCH_SIZE = 32
    AUTOTUNE = tf.data.AUTOTUNE

    ds_train = ds_train.cache().batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
    ds_val = ds_val.cache().batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
    ds_test = ds_test.cache().batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)

Apply some [data augmentation](data_augmentation):

In [None]:
data_augmentation = tf.keras.Sequential(
    [
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(0.1),
    ]
)

Now, build the model.

Steps:

- First, load the pre-trained model as the base.
    - This example uses the [Xception](https://keras.io/api/applications/xception/) model (other options are [here](https://keras.io/api/applications/)).
    - The weights were pre-trained on ImageNet (a large image dataset).
    - The top layer from the Xception classifier is not included (i.e., the one that predicts the ImageNet categories).
- Freeze the layers from the base model.
- Add new layers on top.

```{note}

[Batch Normalisation](https://youtu.be/em6dfRxYkYU) layers should be kept in inference mode (i.e., the base model should keep `training = False`) when the base model is unfrozen for fine-tuning.

More information [here](https://www.tensorflow.org/tutorials/images/transfer_learning#important_note_about_batchnormalization_layers).

```

In [None]:
base_model = tf.keras.applications.Xception(
    weights="imagenet",  # use weights from ImageNet
    input_shape=(150, 150, 3),
    include_top=False,  # don't include Xception's classifier for ImageNet
)
base_model.trainable = False  # freeze the pre-trained model

inputs = tf.keras.Input(shape=(150, 150, 3))  # create new model on top
x = data_augmentation(inputs)  # apply data augmentation

scale_layer = tf.keras.layers.Rescaling(
    scale=1.0 / 127.5, offset=-1
)  # rescale from [0, 255] to [-1.0, 1.0]
x = scale_layer(x)

x = base_model(
    x, training=False
)  # keep BatchNorm layers in inference mode (see note above)
x = tf.keras.layers.GlobalAveragePooling2D()(x)  # convert 2D locations to vector
x = tf.keras.layers.Dropout(0.2)(x)  # for regularisation

outputs = tf.keras.layers.Dense(1)(x)  # prediction layer

model = tf.keras.Model(inputs, outputs)

model.summary()
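
As a quick check of the `Rescaling` arithmetic above, each pixel is transformed as `pixel * scale + offset`. A small plain-Python sketch (the helper name `rescale` is hypothetical):

```python
SCALE = 1.0 / 127.5
OFFSET = -1.0


def rescale(pixel):
    """Apply the same affine map as Rescaling(scale=1/127.5, offset=-1)."""
    return pixel * SCALE + OFFSET


for pixel in (0, 127.5, 255):
    print(pixel, "->", rescale(pixel))
```

The darkest pixel maps to -1, the mid-point to 0, and the brightest to 1, which is the input range stated in the comment on the `Rescaling` layer.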

Only about 0.01% (approx. 2,000) of the parameters are trainable; the other 99.99% (approx. 20,000,000) are frozen in the pre-trained model.

That will save a lot of computation.

Now, train the top layers only:

In [None]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.BinaryAccuracy(name="accuracy")],
)
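
Because the final `Dense(1)` layer has no activation, the model outputs a raw score (a logit), which is why the loss is created with `from_logits=True`: the sigmoid is applied inside the loss. A minimal sketch of that computation for a single example, written in a numerically stable form (an illustration with hypothetical helper names, not the Keras implementation):

```python
import math


def sigmoid(logit):
    """Convert a raw score into a probability."""
    return 1.0 / (1.0 + math.exp(-logit))


def binary_cross_entropy_from_logit(logit, label):
    """Stable form of -[y*log(p) + (1-y)*log(1-p)] with p = sigmoid(logit)."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


logit, label = 2.0, 1  # made-up score for an image labelled 1
probability = sigmoid(logit)
naive = -(label * math.log(probability) + (1 - label) * math.log(1 - probability))
stable = binary_cross_entropy_from_logit(logit, label)
print(round(naive, 6), round(stable, 6))
```

Both forms agree for moderate scores, but the stable form also stays finite for very large positive or negative logits, where the naive version would take the log of zero.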

In [None]:
if IN_COLAB:
    # just 1 epoch for the demonstration, would increase this to say 20 normally
    NUM_EPOCHS = 1
    model.fit(ds_train, epochs=NUM_EPOCHS, validation_data=ds_val)

Now you could perform [fine tuning](https://www.tensorflow.org/tutorials/images/transfer_learning#fine_tuning).

This is where you unfreeze the frozen layers and train these on the dataset with a very low learning rate.

(tf_transfer_learning)=
#### [TensorFlow Hub](https://www.tensorflow.org/tutorials/images/transfer_learning_with_hub)

There are many [pre-trained models](https://tfhub.dev/) on [TensorFlow Hub](https://www.tensorflow.org/hub/overview) for use in transfer learning.

Download the data:

In [None]:
data_root = tf.keras.utils.get_file(
    "flower_photos",
    "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
    untar=True,
)

In [None]:
BATCH_SIZE = 32
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224

ds_train = tf.keras.utils.image_dataset_from_directory(
    str(data_root),
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
    batch_size=BATCH_SIZE,
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    str(data_root),
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
    batch_size=BATCH_SIZE,
)

In [None]:
class_names = np.array(ds_train.class_names)
num_classes = len(class_names)
class_names

Pre-process the data:

In [None]:
normalization_layer = tf.keras.layers.Rescaling(1.0 / 255)
ds_train = ds_train.map(lambda x, y: (normalization_layer(x), y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))

Cache and prefetch the data:

In [None]:
AUTOTUNE = tf.data.AUTOTUNE
ds_train = ds_train.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

Download the headless pre-trained model (i.e., without the top layer):

In [None]:
mobilenet_v2 = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
inception_v3 = "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4"

feature_extractor_model = (
    mobilenet_v2  # could change to inception_v3, or choose a different one
)

The download location defaults to a [local temporary directory](https://www.tensorflow.org/hub/caching).

To change this, set the following:

```python
import os
os.environ["TFHUB_CACHE_DIR"] = "/nobackup/username/tf_hub_modules"
```

Create the headless pre-trained _layer_, setting `trainable` to `False` to freeze it:

In [None]:
import tensorflow_hub as hub

In [None]:
feature_extractor_layer = hub.KerasLayer(
    feature_extractor_model,
    input_shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 3),
    trainable=False,  # remember to set this to False
)

Create the new model with the pre-trained layer as the base and a new classification layer on top:

In [None]:
model = tf.keras.Sequential(
    [feature_extractor_layer, tf.keras.layers.Dense(num_classes)]
)

model.summary()

Compile the new model:

In [None]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

Train the new model:

In [None]:
if IN_COLAB:
    NUM_EPOCHS = 3
    history = model.fit(ds_train, validation_data=val_ds, epochs=NUM_EPOCHS)

#### [PyTorch (Lightning)](https://pytorch-lightning.readthedocs.io/en/stable/advanced/transfer_learning.html)

PyTorch Lightning can use a pre-trained model as a backbone in the Model definition.

In [1]:
import torch
import torchvision.models as models
from pytorch_lightning import LightningModule
from torch import nn

In [None]:
class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()

        # init a pretrained resnet
        backbone = models.resnet50(pretrained=True)

        # extract some setup details
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        return x

    # ... the rest of your LightningModule ...

#### [PyTorch Lightning Bolts](https://lightning-bolts.readthedocs.io/en/latest/introduction_guide.html#for-pretrained-models)

PyTorch Lightning Bolts has many pre-trained models ready for you to import.

In [None]:
from pl_bolts.models.autoencoders import VAE

In [None]:
model = VAE(input_height=32, pretrained="imagenet2012")

In [None]:
encoder = model.encoder
encoder.eval()

(torch_transfer_learning)=
#### [PyTorch Lightning-Flash](https://lightning-flash.readthedocs.io/en/latest/general/finetuning.html)

If you [remember](pytorch), PyTorch Lightning-Flash is the even higher-level API for PyTorch that provides abstractions above PyTorch Lightning for fast prototyping.

In [None]:
import os

import flash
import torch
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
from pytorch_lightning import seed_everything

In [None]:
seed_everything(42)

Download and organize the data:

In [None]:
data_path = f"{os.getcwd()}/data"

In [None]:
download_data(
    "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", f"{data_path}/"
)

datamodule = ImageClassificationData.from_folders(
    train_folder=f"{data_path}/hymenoptera_data/train/",
    val_folder=f"{data_path}/hymenoptera_data/val/",
    test_folder=f"{data_path}/hymenoptera_data/test/",
    batch_size=4,
    transform_kwargs={
        "image_size": (196, 196),
        "mean": (0.485, 0.456, 0.406),
        "std": (0.229, 0.224, 0.225),
    },
)

Build the model using the desired pre-trained model (e.g., [ResNet18](https://arxiv.org/pdf/1512.03385.pdf)):

In [None]:
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

Create the trainer (run one epoch for demo):

In [None]:
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

Finetune the model.

`strategy="freeze"` keeps the pre-trained model (backbone) frozen throughout.

In [None]:
if IN_COLAB:
    trainer.finetune(
        model,
        datamodule=datamodule,
        strategy="freeze",  # remember to have the strategy set to "freeze"
    )

Save the model:

In [None]:
model_path = f"{os.getcwd()}/models"

In [None]:
if IN_COLAB:
    trainer.save_checkpoint(f"{model_path}/image_classification_model.pt")

## [Callbacks](https://keras.io/api/callbacks/)

Callbacks are objects that get called by the model at different points during training, often after each batch or epoch.

For example, they could be used to:

- Save a model version at regular intervals or once a metric threshold is attained (i.e., checkpointing).
- Monitor and profile the training progress (i.e., TensorBoard).
- Change the learning rate when the training plateaus.
- Start fine-tuning when the training plateaus.
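
Conceptually, a training loop calls each callback's hook methods at fixed points. The following plain-Python sketch shows the pattern only. The class and function names are hypothetical, and real Keras and Lightning callbacks expose many more hooks.

```python
class Callback:
    """Base class: subclasses override the hooks they need."""

    def on_epoch_end(self, epoch, logs):
        pass


class HistoryRecorder(Callback):
    """Example callback that records the metrics reported after each epoch."""

    def __init__(self):
        self.history = []

    def on_epoch_end(self, epoch, logs):
        self.history.append((epoch, dict(logs)))


def fit(num_epochs, callbacks):
    """Toy training loop: the 'loss' simply halves each epoch."""
    loss = 1.0
    for epoch in range(num_epochs):
        loss /= 2  # stand-in for one epoch of training
        for callback in callbacks:
            callback.on_epoch_end(epoch, {"loss": loss})


recorder = HistoryRecorder()
fit(num_epochs=3, callbacks=[recorder])
print(recorder.history)
```

Because the loop only knows about the hook interface, new behaviour can be added by passing another callback rather than editing the training code.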

### Checkpoints

You can save a model (or just its weights) at custom checkpoints, e.g., the best model so far in terms of validation accuracy at the end of an epoch.

The `save_best_only` option is useful because a checkpoint is only written when the monitored metric improves, so the most recent checkpoint holds the best model.

#### [TensorFlow (Keras)](https://www.tensorflow.org/guide/checkpoint)

In [None]:
path_models = f"{os.getcwd()}/models"

In [None]:
callback_tf_model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=path_models + "/model_{epoch}",
    save_freq="epoch",  # save a model version at the end of each epoch
    save_best_only=True,  # only save a model if val_accuracy improved
    monitor="val_accuracy",
    # save_weights_only=True
)
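
The effect of `save_best_only=True` can be sketched as follows: a checkpoint is written only for epochs where the monitored metric beats the best value seen so far. The validation accuracies below are made-up numbers for illustration, and `epochs_that_save` is a hypothetical helper.

```python
def epochs_that_save(val_accuracies):
    """Return the (1-based) epochs at which a checkpoint would be written."""
    best = float("-inf")
    saved = []
    for epoch, accuracy in enumerate(val_accuracies, start=1):
        if accuracy > best:  # improvement: save and update the best value
            best = accuracy
            saved.append(epoch)
    return saved


val_accuracies = [0.81, 0.86, 0.85, 0.88, 0.87]  # hypothetical values
print(epochs_that_save(val_accuracies))
```

With the `filepath` pattern above, the saved epochs would correspond to separate `model_{epoch}` checkpoints, and the last one written is the best.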

#### PyTorch (Lightning)

Checkpointing is enabled by default in PyTorch Lightning.

Model checkpoints are saved per epoch to `lightning_logs/version_X/checkpoints`.

### Fault tolerance

For longer or distributed training, it's helpful to save the model at regular intervals in case it crashes during training.

#### [TensorFlow (Keras)](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#fault_tolerance)

This is done via [BackupAndRestore](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#checkpoint_saving_and_restoring) (previously, this was done using `ModelCheckpoint`).

In [None]:
callback_tf_model_backup = tf.keras.callbacks.BackupAndRestore(
    backup_dir=f"{path_models}/backup"
)
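
The idea behind backup-and-restore can be sketched with the standard library: persist the training state after every epoch and, on restart, resume from the last saved epoch. This is only an illustration of the pattern with hypothetical names and a simulated crash, not how `BackupAndRestore` is implemented.

```python
import json
import os
import tempfile


def train(num_epochs, backup_path, crash_at=None):
    """Toy loop that resumes from `backup_path` if a backup exists."""
    state = {"epoch": 0, "loss": 1.0}
    if os.path.exists(backup_path):
        with open(backup_path) as file:
            state = json.load(file)  # restore the last completed epoch
    while state["epoch"] < num_epochs:
        if crash_at is not None and state["epoch"] == crash_at:
            raise RuntimeError("simulated crash")
        state = {"epoch": state["epoch"] + 1, "loss": state["loss"] / 2}
        with open(backup_path, "w") as file:
            json.dump(state, file)  # back up after every epoch
    os.remove(backup_path)  # training finished, so the backup is not needed
    return state


backup_path = os.path.join(tempfile.mkdtemp(), "backup.json")
try:
    train(num_epochs=4, backup_path=backup_path, crash_at=2)
except RuntimeError:
    pass  # the job died after completing 2 epochs

final_state = train(num_epochs=4, backup_path=backup_path)
print(final_state)
```

The second call picks up at epoch 2 rather than epoch 0, so the work done before the crash is not repeated.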

#### [PyTorch (Lightning)](https://pytorch-lightning.readthedocs.io/en/stable/advanced/fault_tolerant_training.html)

By default, if `Trainer.fit()` fails, then it can be restarted automatically from the _beginning_ of the epoch it failed on.

With Fault Tolerant Training, when `Trainer.fit()` fails in the middle of an epoch during training or validation, Lightning will restart exactly where it failed, and everything will be restored.

This is enabled via:

```bash
PL_FAULT_TOLERANT_TRAINING=1 python script.py
```

### Logging and Profiling

[TensorBoard](https://www.tensorflow.org/tensorboard) is a browser-based application that provides live plots of loss and metrics for training and evaluation.

#### [TensorFlow (Keras)](https://www.tensorflow.org/guide/profiler)

In [None]:
callback_tf_tensorboard_with_profiling = tf.keras.callbacks.TensorBoard(
    log_dir=f"{path_models}/logs",
    profile_batch=(1, 5),  # profile batches 1 to 5
    update_freq="epoch",
)

In [None]:
# may need to remove previous logs
# import shutil
# shutil.rmtree(f"{path_models}/logs")

View them with (we'll see an example in a bit):

```bash
tensorboard --logdir=/full_path_to_your_logs
```

Also, in-line in [Jupyter Notebooks / Google Colab](https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks):

```python
from tensorboard import notebook
notebook.list() # list open tensorboard instances
notebook.display(port=6006, height=1000) # open tensorboard

# or

%load_ext tensorboard
%tensorboard --logdir logs
```

#### PyTorch (Lightning)

##### [Logging](https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html)

PyTorch Lightning uses TensorBoard for logging by default.  

As before, you can view the logs using:

```python
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
```

##### Profiling

You can profile (time and memory) with [PyTorch](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) using the `torch.autograd.profiler` context manager:

In [None]:
import torch
import torch.autograd.profiler as profiler
import torchvision.models as models

In [None]:
model = models.resnet18()
inputs = torch.randn(5, 3, 244, 244)

In [None]:
with profiler.profile() as prof:
    with profiler.record_function("model_inference"):
        model(inputs)

This can then help you find bottlenecks in the code:

In [None]:
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))

In [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/advanced/profiler.html), for a simple profiler over method calls and time spent calling them, use:

```python
Trainer(profiler="simple")
```

For more control, use:

```python
from pytorch_lightning.profiler import AdvancedProfiler

Trainer(profiler=AdvancedProfiler())
```

### Early stopping

Stop training early when a monitored metric has stopped improving.

#### [TensorFlow (Keras)](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping)

In [None]:
callback_tf_early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_accuracy",  # quantity to be monitored
    patience=2,  # stop after 2 consecutive epochs without improvement
)
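
The role of `patience` can be sketched in plain Python: count the consecutive epochs without improvement and stop once that count reaches `patience`. This is a simplified illustration (it ignores options such as a minimum improvement threshold), with made-up accuracies and a hypothetical helper name.

```python
def stopping_epoch(val_accuracies, patience):
    """Return the (1-based) epoch at which training stops, or None if it never does."""
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, accuracy in enumerate(val_accuracies, start=1):
        if accuracy > best:
            best = accuracy
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


val_accuracies = [0.60, 0.70, 0.69, 0.68, 0.75]  # hypothetical values
print(stopping_epoch(val_accuracies, patience=2))
```

In this example training stops before the later improvement in the fifth epoch, which is the trade-off a small `patience` makes.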

#### [PyTorch (Lightning)](https://pytorch-lightning.readthedocs.io/en/stable/common/early_stopping.html#)

In [None]:
callback_torch_early_stopping = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=2,
)

### Learning rate decay

Reduce the learning rate when a metric has stopped improving (i.e., when it has reached a plateau).

#### [TensorFlow (Keras)](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ReduceLROnPlateau)

In [None]:
callback_tf_learning_rate_decay = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_accuracy",
    patience=5,
    factor=0.2,  # factor by which the learning rate will be reduced: new_lr = lr * factor
    min_lr=1e-5,  # lower bound on the learning rate (must be below the initial rate to have any effect)
)
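
The update rule can be sketched in plain Python: after `patience` epochs without improvement, multiply the learning rate by `factor`, but never go below `min_lr`. This is a simplified illustration (it ignores options such as cooldown periods), with made-up accuracies, example parameter values, and a hypothetical helper name.

```python
def learning_rate_schedule(val_accuracies, initial_lr, patience, factor, min_lr):
    """Return the learning rate in effect after each epoch."""
    lr = initial_lr
    best = float("-inf")
    wait = 0
    rates = []
    for accuracy in val_accuracies:
        if accuracy > best:
            best = accuracy
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr = max(lr * factor, min_lr)  # new_lr = lr * factor, bounded below
                wait = 0
        rates.append(lr)
    return rates


val_accuracies = [0.70, 0.72, 0.72, 0.71, 0.72, 0.72]  # hypothetical plateau
rates = learning_rate_schedule(
    val_accuracies, initial_lr=0.01, patience=2, factor=0.2, min_lr=0.001
)
print(rates)
```

Note that the lower bound only matters if the initial learning rate is above it; if the two are equal, the rate can never be reduced.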

### Example: Callbacks

#### TensorFlow (Keras)

In [None]:
callbacks_tf = [
    callback_tf_model_checkpoint,
    callback_tf_model_backup,
    callback_tf_tensorboard_with_profiling,
    callback_tf_early_stopping,
    callback_tf_learning_rate_decay,
]

In [None]:
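if IN_COLAB:
    # assumes `model`, `ds_train`, and `val_ds` are the Keras model and flower datasets
    # from the TensorFlow Hub example (re-run those cells if they have been overwritten)
    epochs = 3
    model.fit(ds_train, epochs=epochs, validation_data=val_ds, callbacks=callbacks_tf)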

View the results in TensorBoard:

In [None]:
# from tensorboard import notebook
# notebook.list()
# notebook.display(port=6006, height=1000)

In [None]:
if IN_COLAB:
    %load_ext tensorboard
    %tensorboard --logdir /content/models/logs

#### PyTorch (Lightning)

In [None]:
# callbacks_torch = [
#     callback_torch_early_stopping,
# ]

In [None]:
# Trainer(callbacks=callbacks_torch)

## Questions

```{admonition} Question 1

What are possible hyperparameters that could be tuned?

- Learning rate and the number of units.
- Weights and biases.

```

```{admonition} Question 2

What is transfer learning?

```

```{admonition} Question 3

Why is transfer learning useful?

```

```{admonition} Question 4

What is a key step in transfer learning?

```

```{admonition} Question 5

What are callbacks?

```

```{admonition} Question 6

Name three examples of callbacks.

```

## {ref}`Solutions <models>`

## Key Points

```{important}

- [x] _Tune hyperparameters for the best model fit._
- [x] _Use transfer learning to save computation on similar problems._
- [x] _Consider using callbacks to help with model training, such as:_
    - [x] _Checkpoints._
    - [x] _Fault tolerance._
    - [x] _Logging._
    - [x] _Profiling._
    - [x] _Early stopping._
    - [x] _Learning rate decay._

```

## Further information

### Good practices

- See if there is a model architecture (and parameters) that already addresses the task.
- Consider the tradeoff between model complexity and size.
    - High accuracy may need a large and complex model.
    - If lower accuracy is acceptable, smaller models use less disk space and memory, and are faster.
- Best practices for [PyTorch Lightning model training](https://pytorch-lightning.readthedocs.io/en/stable/guides/speed.html).
- Consider sharing your model for reproducibility if you can.

### Other options

There are many other options for hyperparameter tuning, including:

- [Hyperopt](http://hyperopt.github.io/hyperopt/)
    - A general purpose optimisation library.
- [scikit-optimize](https://scikit-optimize.github.io/stable/)
    - A general purpose optimisation library.
- [TPOT](http://epistasislab.github.io/tpot/)
    - Automated optimisation library using genetic programming.
 
### Resources

#### General

- [Model Zoo](https://modelzoo.co/)
- [Papers with code - Models](https://paperswithcode.com/methods)
- [HuggingFace - Models](https://huggingface.co/models)

#### TensorFlow

- [TensorFlow Model Garden](https://github.com/tensorflow/models/tree/master/official) for model source code.

#### PyTorch

- [PyTorch Hub](https://pytorch.org/docs/stable/hub.html) for pre-trained models.
- [Torch Vision Models](https://pytorch.org/vision/stable/models.html)
- [Torch Text Models](https://pytorch.org/text/stable/models.html)
- [Torch Audio Models](https://pytorch.org/audio/stable/models.html)
- [TIMM (pyTorch IMage Models)](https://rwightman.github.io/pytorch-image-models/)