Example
example.py

example.py

This page contains the full source code of the example.py file, a simple example of a basic Python script for training Machine Learning models with PyTorch.

You can download the example.py file here.

import os
import argparse
from datetime import datetime
 
import yaml
import tqdm
 
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
 
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
 
import wandb
 
 
class MNISTDataset(Dataset):
    """MNIST Dataset

    This class is already implemented in PyTorch, but we reimplement
    it here to show how to create a custom dataset.

    We show two modes, one that loads the data in memory (preload=True)
    and another that loads the data on the fly (preload=False).

    The split directory (``<path>/train`` or ``<path>/test``) must contain a
    ``labels.npy`` file and one ``image_<index>.npy`` file per sample. Samples
    are ordered by the integer ``<index>`` in the file name.
    """

    def __init__(
        self,
        path: str,
        train: bool,
        preload: bool,
        transform: callable = None,
    ):
        """
        Args:
            path (str): Path to the data
            train (bool): If True, load the training data, otherwise load the test data
            preload (bool): If True, load the data in memory, otherwise load the data on the fly
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data = None
        self.train = train
        self.set_path = os.path.join(path, "train" if train else "test")
        self.labels = np.load(os.path.join(self.set_path, "labels.npy"))
        self.preload = preload
        self.transform = transform

        # Sort numerically (not lexicographically) so that image_10 comes
        # after image_9 and sample i lines up with self.labels[i].
        self.paths = sorted(
            (
                os.path.join(self.set_path, name)
                for name in os.listdir(self.set_path)
                if name.startswith("image_")
            ),
            key=self._image_index,
        )

        if self.preload:
            # Decode every image once and keep a single stacked tensor in memory
            self.data = torch.from_numpy(
                np.stack([self._load_image(p) for p in self.paths])
            )

    @staticmethod
    def _image_index(path):
        """Return the integer index encoded in an ``image_<index>.<ext>`` path."""
        return int(path.split("_")[-1].split(".")[0])

    @staticmethod
    def _load_image(path):
        """Load one image file as a 3-channel float32 array scaled by 1/255.

        The stored array gets a new leading axis that is repeated 3 times.
        NOTE(review): assumes each file holds a 2-D (H, W) array with values
        in [0, 255], so the result is (3, H, W) in [0, 1] -- confirm against
        the data export.
        """
        image = np.load(path)
        stacked = np.repeat(np.expand_dims(image, axis=0), 3, axis=0)
        return stacked.astype(np.float32) / 255.0

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        y = self.labels[idx]
        if self.preload:
            x = self.data[idx]
        else:
            # Same decoding as the preload path, performed lazily per sample
            x = torch.from_numpy(self._load_image(self.paths[idx]))

        if self.transform is not None:
            x = self.transform(x)

        return x, y
 
 
class Model(nn.Module):
    """Example of a model using PyTorch

    The backbone is a MobileNetV2 fetched from torch hub
    (``pytorch/vision:v0.10.0``) with ``pretrained=False``. Its ``classifier``
    attribute is replaced by a custom head: a lazily-sized linear layer with
    512 outputs, dropout, and a final linear layer with 10 outputs.
    """

    def __init__(self):
        super().__init__()

        # Fetch the backbone definition from torch hub (no pretrained weights)
        backbone = torch.hub.load(
            "pytorch/vision:v0.10.0", "mobilenet_v2", pretrained=False
        )

        # LazyLinear infers its input size on the first forward pass
        head = nn.Sequential(
            nn.LazyLinear(512),
            nn.Dropout(),
            nn.Linear(512, 10),
        )
        backbone.classifier = head

        self.base = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and the custom head."""
        return self.base(x)
 
 
def train(
    model: nn.Module,
    train_dataloader: DataLoader,
    test_dataloader: DataLoader,
    hyperparameters: dict,
    args: object,
    device: str,
):
    """Train model

    Runs SGD with a StepLR schedule for ``hyperparameters["epochs"]`` epochs.
    After every epoch the model is evaluated with ``test`` and a checkpoint
    (model, optimizer, scheduler, epoch) is written to
    ``<args.out_dir>/last.pt``, overwriting the previous one.

    Args:
        model (nn.Module): Model to train
        train_dataloader (DataLoader): Dataloader to use for training
        test_dataloader (DataLoader): Dataloader to use for testing
        hyperparameters (dict): Dictionary with hyperparameters. The keys
            "lr", "scheduler_step_size", "scheduler_gamma" and "epochs" are read.
        args (object): Object exposing an ``out_dir`` attribute (checkpoint directory)
        device (str): Device to use when training the model
    """
    # Loss function, optimizer and learning-rate schedule
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=hyperparameters["lr"])
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=hyperparameters["scheduler_step_size"],
        gamma=hyperparameters["scheduler_gamma"],
    )

    checkpoint_path = os.path.join(args.out_dir, "last.pt")
    batches_per_epoch = len(train_dataloader)

    for epoch in range(hyperparameters["epochs"]):
        # Evaluation switches the model to eval mode, so re-enable train mode
        model.train()

        with tqdm.tqdm(total=batches_per_epoch, desc=f"Epoch {epoch}") as pbar:
            for inputs, targets in train_dataloader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                # Clear stale gradients, then forward / backward / update
                optimizer.zero_grad()
                batch_loss = criterion(model(inputs), targets)
                batch_loss.backward()
                optimizer.step()

                loss_value = batch_loss.item()

                # Progress bar and per-step logging
                pbar.update(1)
                pbar.set_postfix(loss=loss_value)
                wandb.log({"loss": loss_value, "lr": scheduler.get_last_lr()[0]})

        # The learning rate is stepped once per epoch
        scheduler.step()

        # Per-epoch evaluation
        test(model, test_dataloader, device)

        # Persist everything needed to resume training
        checkpoint = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "epoch": epoch,
        }
        torch.save(checkpoint, checkpoint_path)
 
 
def test(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    final_eval: bool = False,
    out_dir: str = None,
):
    """Evaluate model and log the metrics to wandb

    Computes the cross-entropy loss and the accuracy over ``dataloader`` and
    logs them as "Test Loss" and "Test Accuracy". When ``final_eval`` is True
    a confusion matrix is also plotted, saved as ``confusion_matrix.png`` and
    logged as "Confusion Matrix".

    The model is left in eval mode when this function returns.

    Args:
        model (nn.Module): Model to test
        dataloader (DataLoader): Dataloader to use for testing
        device (torch.device): Device the batches are moved to
        final_eval (bool): If True, also compute and log the confusion matrix
        out_dir (str, optional): Directory where the confusion matrix image is
            saved. If None, falls back to ``args.out_dir`` of the module-level
            ``args`` namespace created when this file runs as a script.

    Raises:
        ValueError: If ``final_eval`` is True and no output directory can be
            determined.
    """
    if final_eval and out_dir is None:
        # Backward compatibility: this function used to read the script-level
        # `args` global directly, which failed with a NameError when the
        # module was imported instead of executed.
        script_args = globals().get("args")
        if script_args is None:
            raise ValueError("out_dir must be provided when final_eval=True")
        out_dir = script_args.out_dir

    # Create Loss function
    criterion = nn.CrossEntropyLoss()

    # Set model to eval mode
    model.eval()

    logs_dict = {}

    # Sum of per-sample losses, accumulated batch by batch
    total_loss = 0.0

    # List to store predictions and labels
    predictions = []
    labels = []

    # Disable gradient computation during evaluation
    with torch.no_grad():
        for x, y in dataloader:
            # Move batch to device
            x = x.to(device)
            y = y.to(device)

            # Forward pass
            output = model(x)
            prediction = torch.argmax(output, dim=1)

            # The criterion returns the batch mean; weight it by the batch
            # size so a smaller last batch does not bias the overall mean.
            total_loss += criterion(output, y).item() * y.shape[0]
            predictions.append(prediction.cpu().numpy())
            labels.append(y.cpu().numpy())

    # Concatenate per-batch results into flat arrays of shape (num_samples,)
    predictions = np.concatenate(predictions)
    labels = np.concatenate(labels)
    num_samples = labels.shape[0]

    # Compute loss (mean over samples)
    logs_dict["Test Loss"] = total_loss / num_samples

    # Compute accuracy
    logs_dict["Test Accuracy"] = (predictions == labels).sum() / num_samples

    if final_eval:
        # Compute Confusion Matrix
        image_path = os.path.join(out_dir, "confusion_matrix.png")
        cm = confusion_matrix(labels, predictions)
        cm_display = ConfusionMatrixDisplay(cm).plot()
        cm_display.figure_.savefig(image_path)
        logs_dict["Confusion Matrix"] = wandb.Image(image_path)

    # Log to wandb
    wandb.log(logs_dict)
 
 
if __name__ == "__main__":
    # Script entry point: parse the CLI, build data/model, train, then run a
    # final evaluation. `args` is intentionally kept at module level because
    # other code in this file reads it as a global.

    # Define args
    parser = argparse.ArgumentParser(
        description="Example of good practices in an ML script"
    )
    parser.add_argument(
        "--hyperparameters",
        default="example.yml",
        type=str,
        help="Path to yaml file with hyperparameters",
    )
    parser.add_argument(
        "--out_dir",
        default="./output",
        type=str,
        help="Path to the output directory where the model will be saved",
    )
    parser.add_argument(
        "--data_path",
        default="/media/mnist",
        type=str,
        help="Path to the data",
    )
    parser.add_argument(
        "--preload",
        action="store_true",
        help="Whether to preload the data in memory",
    )
    parser.add_argument(
        "--num_workers",
        default=8,
        type=int,
        help="How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--prefetch_factor",
        default=2,
        type=int,
        help="Number of batches loaded in advance by each worker.",
    )
    parser.add_argument(
        "--experiment_name",
        default="example",
        type=str,
        help="Name to identify the experiment",
    )
    parser.add_argument(
        "--device",
        default="cuda:0",
        type=str,
        help="Device to use when training the model",
    )
    args = parser.parse_args()

    # Hyperparameters come from a YAML file; the keys "seed" and "batch_size"
    # are read here, the whole dict is passed to train() and to wandb.
    with open(args.hyperparameters, "r") as f:
        hyperparameters = yaml.safe_load(f)

    # Create output directory
    args.out_dir = os.path.join(args.out_dir, args.experiment_name)
    if os.path.isdir(args.out_dir):
        # An experiment with this name already exists: make the run name
        # unique with a timestamp. Note that the new directory is created
        # INSIDE the existing experiment directory.
        args.experiment_name = (
            args.experiment_name + "_" + datetime.now().strftime("%Y%m%d-%H%M%S")
        )
        args.out_dir = os.path.join(
            args.out_dir,
            args.experiment_name,
        )
    os.makedirs(args.out_dir, exist_ok=True)

    # Set device (falls back to CPU when CUDA is not available)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Set seed
    torch.manual_seed(hyperparameters["seed"])
    np.random.seed(hyperparameters["seed"])

    # Create model
    model = Model()

    # Move model to device
    model.to(device)

    # Create data transform (normalization only, no augmentation)
    data_transform = T.Compose(
        [
            T.Normalize((0.1307,), (0.3081,)),
        ]
    )

    # Create datasets
    train_dataset = MNISTDataset(args.data_path, True, args.preload, data_transform)

    test_dataset = MNISTDataset(args.data_path, False, args.preload, data_transform)

    # Create dataloaders
    loader_kwargs = {
        "batch_size": hyperparameters["batch_size"],
        "num_workers": args.num_workers,
        "pin_memory": True,
    }
    if args.num_workers > 0:
        # DataLoader rejects an explicit prefetch_factor when loading happens
        # in the main process (num_workers=0), so only pass it with workers.
        loader_kwargs["prefetch_factor"] = args.prefetch_factor

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        drop_last=True,
        **loader_kwargs,
    )

    test_dataloader = DataLoader(
        test_dataset,
        **loader_kwargs,
    )

    # Create Logger
    wandb.init(
        project="example",
        name=args.experiment_name,
        config=hyperparameters,
        save_code=True,
    )

    # Train model
    train(model, train_dataloader, test_dataloader, hyperparameters, args, device)

    # Test model
    test(model, test_dataloader, device, final_eval=True)