example.py
This page contains the full source code for the example.py file, a simple example of a basic Python script for training machine learning models with PyTorch.
You can download the example.py file here.
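The script reads its hyperparameters from a YAML file (the --hyperparameters flag, default example.yml). The keys used in the code below are lr, scheduler_step_size, scheduler_gamma, epochs, batch_size, and seed; a minimal file could look like this (the values are illustrative, not tuned):

lr: 0.01
scheduler_step_size: 5
scheduler_gamma: 0.5
epochs: 10
batch_size: 64
seed: 42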
import os
import argparse
from datetime import datetime
import yaml
import tqdm
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import wandb
class MNISTDataset(Dataset):
"""MNIST Dataset
This class is already implemented in PyTorch, but we reimplement
it here to show how to create a custom dataset.
We show two modes, one that loads the data in memory (preload=True)
and another that loads the data on the fly (preload=False).
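    The data directory is expected to contain train/ and test/
    subdirectories, each holding a labels.npy file and one
    image_<idx>.npy file per sample.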
"""
def __init__(
self,
path: str,
train: bool,
preload: bool,
transform: callable = None,
):
"""
Args:
path (str): Path to the data
train (bool): If True, load the training data, otherwise load the test data
preload (bool): If True, load the data in memory, otherwise load the data on the fly
transform (callable, optional): Optional transform to be applied on a sample.
"""
self.data = None
self.train = train
self.set_path = os.path.join(path, "train" if train else "test")
self.labels = np.load(os.path.join(self.set_path, "labels.npy"))
self.preload = preload
self.transform = transform
self.paths = [
os.path.join(self.set_path, i)
for i in os.listdir(self.set_path)
if i.startswith("image_")
]
self.paths = sorted(
self.paths, key=lambda x: int(x.split("_")[-1].split(".")[0])
)
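        # Sort numerically by the index in "image_<idx>.npy" so images stay aligned with labels.npy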
if self.preload:
self.data = []
for path in self.paths:
                # Load the image, repeat the grayscale channel 3x (the backbone
                # expects 3-channel input), convert to float32, and scale to [0, 1]
self.data.append(
np.repeat(np.expand_dims(np.load(path), axis=0), 3, axis=0).astype(
np.float32
)
/ 255.0
)
self.data = torch.from_numpy(np.stack(self.data))
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
x = None
y = self.labels[idx]
if self.preload:
x = self.data[idx]
else:
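            # Load one image on the fly with the same 3-channel float conversion as above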
x = torch.from_numpy(
np.repeat(
np.expand_dims(np.load(self.paths[idx]), axis=0), 3, axis=0
).astype(np.float32)
/ 255.0
)
if self.transform is not None:
x = self.transform(x)
return x, y
class Model(nn.Module):
"""Example of a model using PyTorch
This specific model is a pretrained ResNet18 model from torch hub
with a custom head.
"""
def __init__(self):
super().__init__()
self.base = torch.hub.load(
"pytorch/vision:v0.10.0", "mobilenet_v2", pretrained=False
)
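        # Replace the default classifier with a small head; LazyLinear infers
        # its input size on the first forward pass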
self.base.classifier = nn.Sequential(
nn.LazyLinear(512), nn.Dropout(), nn.Linear(512, 10)
)
def forward(self, x):
x = self.base(x)
return x
def train(
model: nn.Module,
train_dataloader: DataLoader,
test_dataloader: DataLoader,
hyperparameters: dict,
args: object,
device: str,
):
"""Train model
Args:
model (nn.Module): Model to train
train_dataloader (DataLoader): Dataloader to use for training
test_dataloader (DataLoader): Dataloader to use for testing
        hyperparameters (dict): Dictionary with hyperparameters
        args (object): Parsed command-line arguments; args.out_dir is where checkpoints are saved
        device (str): Device to use when training the model
"""
# Create optimizer and loss function
optimizer = torch.optim.SGD(model.parameters(), lr=hyperparameters["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=hyperparameters["scheduler_step_size"],
gamma=hyperparameters["scheduler_gamma"],
)
criterion = nn.CrossEntropyLoss()
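    # CrossEntropyLoss expects raw logits, so the model applies no final softmax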
for epoch in range(hyperparameters["epochs"]):
# Set model to train mode
model.train()
# Create tqdm progress bar
with tqdm.tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar:
for (x, y) in train_dataloader:
# Move batch to device
x = x.to(device)
y = y.to(device)
# Forward pass
y_hat = model(x)
# Compute loss
loss = criterion(y_hat, y)
# Zero gradients
optimizer.zero_grad()
# Backward pass
loss.backward()
# Update weights
optimizer.step()
# Update progress bar
pbar.update(1)
pbar.set_postfix(loss=loss.item())
# Log loss to wandb
wandb.log({"loss": loss.item(), "lr": scheduler.get_last_lr()[0]})
# Update learning rate
scheduler.step()
# Evaluate model
test(model, test_dataloader, device)
# Save model, scheduler and optimizer checkpoint
torch.save(
{
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"epoch": epoch,
},
os.path.join(args.out_dir, "last.pt"),
)
def test(
model: nn.Module,
dataloader: DataLoader,
device: torch.device,
final_eval: bool = False,
):
"""Test model by plotting the embeddings in 2D space
Args:
model (nn.Module): Model to test
dataloader (DataLoader): Dataloader to use for testing
"""
# Create Loss function
criterion = nn.CrossEntropyLoss()
# Set model to eval mode
model.eval()
logs_dict = {}
losses = []
# List to store predictions and labels
predictions = []
labels = []
# Disable gradient computation during evaluation
with torch.no_grad():
for (x, y) in dataloader:
# Move batch to device
x = x.to(device)
y = y.to(device)
# Forward pass
output = model(x)
prediction = torch.argmax(output, dim=1)
losses.append(criterion(output, y).item())
predictions.append(prediction.cpu().numpy())
labels.append(y.cpu().numpy())
    # Concatenate per-batch predictions and labels into arrays of shape (num_samples,)
predictions = np.concatenate(predictions)
labels = np.concatenate(labels)
# Compute loss
loss = np.mean(losses)
logs_dict["Test Loss"] = loss
# Compute accuracy
accuracy = (predictions == labels).sum() / labels.shape[0]
logs_dict["Test Accuracy"] = accuracy
if final_eval:
# Compute Confusion Matrix
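        # Note: this relies on the module-level args defined in the __main__ block,
        # and the .plot() call requires matplotlib to be installed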
cm = confusion_matrix(labels, predictions)
cm_display = ConfusionMatrixDisplay(cm).plot()
cm_display.figure_.savefig(
os.path.join(args.out_dir, "confusion_matrix.png")
)
logs_dict["Confusion Matrix"] = wandb.Image(
os.path.join(args.out_dir, "confusion_matrix.png")
)
# Log to wandb
wandb.log(logs_dict)
if __name__ == "__main__":
# Define args
parser = argparse.ArgumentParser(
description="Example of good practices in an ML script"
)
parser.add_argument(
"--hyperparameters",
default="example.yml",
type=str,
help="Path to yaml file with hyperparameters",
)
parser.add_argument(
"--out_dir",
default="./output",
type=str,
help="Path to the output directory where the model will be saved",
)
parser.add_argument(
"--data_path",
default="/media/mnist",
type=str,
help="Path to the data",
)
parser.add_argument(
"--preload",
action="store_true",
help="Whether to preload the data in memory",
)
parser.add_argument(
"--num_workers",
default=8,
type=int,
help="How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--prefetch_factor",
default=2,
type=int,
help="Number of batches loaded in advance by each worker.",
)
parser.add_argument(
"--experiment_name",
default="example",
type=str,
help="Name to identify the experiment",
)
parser.add_argument(
"--device",
default="cuda:0",
type=str,
help="Device to use when training the model",
)
args = parser.parse_args()
with open(args.hyperparameters, "r") as f:
hyperparameters = yaml.safe_load(f)
    # Create the output directory; if a directory for this experiment name
    # already exists, append a timestamp so the new run gets its own folder
    if os.path.isdir(os.path.join(args.out_dir, args.experiment_name)):
        args.experiment_name = (
            args.experiment_name + "_" + datetime.now().strftime("%Y%m%d-%H%M%S")
        )
    args.out_dir = os.path.join(args.out_dir, args.experiment_name)
    os.makedirs(args.out_dir, exist_ok=True)
# Set device
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
# Set seed
torch.manual_seed(hyperparameters["seed"])
np.random.seed(hyperparameters["seed"])
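    # Note: seeding makes runs more reproducible, but fully deterministic GPU
    # training would also require enabling cuDNN determinism flags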
# Create model
model = Model()
# Move model to device
model.to(device)
# Create data augmentations
data_transform = T.Compose(
[
T.Normalize((0.1307,), (0.3081,)),
]
)
# Create datasets
train_dataset = MNISTDataset(args.data_path, True, args.preload, data_transform)
test_dataset = MNISTDataset(args.data_path, False, args.preload, data_transform)
# Create dataloaders
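    # pin_memory speeds up host-to-GPU copies; drop_last avoids a smaller final training batch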
train_dataloader = DataLoader(
train_dataset,
batch_size=hyperparameters["batch_size"],
shuffle=True,
num_workers=args.num_workers,
pin_memory=True,
drop_last=True,
prefetch_factor=args.prefetch_factor,
)
test_dataloader = DataLoader(
test_dataset,
batch_size=hyperparameters["batch_size"],
num_workers=args.num_workers,
pin_memory=True,
prefetch_factor=args.prefetch_factor,
)
# Create Logger
wandb.init(
project="example",
name=args.experiment_name,
config=hyperparameters,
save_code=True,
)
# Train model
train(model, train_dataloader, test_dataloader, hyperparameters, args, device)
# Test model
test(model, test_dataloader, device, final_eval=True)
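To run the example with the defaults defined above, an invocation along the lines of python example.py --hyperparameters example.yml --data_path /media/mnist --preload should work; every flag is optional and falls back to the default declared in the argument parser.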