
A distributed training example

The following presents a minimal working example (MWE) of a multi-node, multi-GPU PyTorch training script. It demonstrates how to integrate SLURM with PyTorch for distributed training across GPUs on MeluXina, and can serve as a starting point for users who want to move a PyTorch training script from serial to distributed execution.

For the sake of the example, we use a ResNet-50 model from the torchvision package. Feel free to tweak the SLURM launcher and the training script to run your own tests!

Source code

To reproduce this example, create three separate files in the directory you want to work from:

  • a first Python snippet, getPort.py, which picks a free port for the node in charge of orchestrating the communication between workers during training,
  • a second Python script, resnet50_LXP.py, containing the training code,
  • a SLURM batch script, slurm_resnet50_LXP.sh, used to launch the training.

To run the example, submit the job with the command sbatch slurm_resnet50_LXP.sh, as sketched below.
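
For reference, submitting and monitoring the job could look like the following (the output file name is derived from the --job-name and --output=%x%j.out directives of the launcher; <JOBID> stands for the job ID returned by sbatch):

sbatch slurm_resnet50_LXP.sh      # submit the job
squeue -u $USER                   # check its state in the queue
tail -f resnetDistri<JOBID>.out   # follow the standard output once the job is running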

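# getPort.py: print a free TCP port, chosen by the operating system, to be used for the rendezvous endpoint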
import socket

def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))  # Bind to a free port provided by the host.
        return s.getsockname()[1]  # Return the port number assigned.

if __name__ == "__main__":
    print(find_free_port())
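
# resnet50_LXP.py: DDP training of a ResNet-50 on synthetic data, one process per GPU, launched by torchrun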
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import torch.distributed as dist
import os
from torchvision import models

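# torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for every
# worker, so init_process_group() can rely on its default env:// rendezvous; each
# process is then pinned to its local GPU.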
def ddp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    return world_size, rank


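# The Trainer moves the model to the local GPU, wraps it in DistributedDataParallel,
# runs the training loop and periodically saves a snapshot so that an interrupted
# job can resume from the last saved epoch.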
class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str,
    ) -> None:

        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.model = model.to(self.local_rank)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)

        self.model = DDP(self.model, device_ids=[self.local_rank])

    def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.local_rank}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.local_rank)
            targets = targets.to(self.local_rank)
            self._run_batch(source, targets)

    def _save_snapshot(self, epoch):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
        }
        torch.save(snapshot, self.snapshot_path)
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")


    def train(self, max_epochs: int):
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            # only one process (global rank 0) writes the snapshot on the shared filesystem
            if self.global_rank == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)


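# Build the synthetic dataset, an untrained ResNet-50 and a plain SGD optimizer.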
def load_train_objs(num_epochs, batch_size_per_gpu, world_size):
    train_set = SyntheticDataset(num_epochs, batch_size_per_gpu, world_size)
    model = models.resnet50(weights=None)  # train from scratch; 'pretrained=False' is deprecated in recent torchvision
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(dataset),
        num_workers=32  # tune this to the CPU cores actually allocated per task (see --cpus-per-task in the launcher)
    )


def main(save_every: int,
         total_epochs: int,
         batch_size: int,
         snapshot_path: str = os.path.join(os.getcwd(), 'snapshot.pt')):
    world_size, rank = ddp_setup()
    # batch_size_per_gpu is only used to size the synthetic dataset; 4 is the number
    # of GPUs per node (NGPUS_PER_NODE in the launcher)
    batch_size_per_gpu = int(batch_size / 4)
    dataset, model, optimizer = load_train_objs(total_epochs, batch_size_per_gpu, world_size)
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
    trainer.train(total_epochs)
    print('Will now destroy the process group\n')
    destroy_process_group()
    print('Process group destroyed')
    # an explicit exit works around worker processes occasionally hanging at shutdown, see
    # https://github.com/pytorch/pytorch/issues/76287
    import sys
    sys.exit(0)


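# Synthetic dataset: random ImageNet-sized inputs (3x224x224) with random labels in
# [0, 999]. Its length scales with the number of epochs, the per-GPU batch size and
# the world size, so no real data needs to be downloaded.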
import random 
class SyntheticDataset(Dataset):
    def __init__(self, num_epochs, batch_size_per_gpu, world_size):
        self.num_epochs = num_epochs 
        self.batch_size_per_gpu = batch_size_per_gpu 
        self.num_iters = world_size 
        self.len = self.num_epochs * self.batch_size_per_gpu * self.num_iters

    def __getitem__(self, idx):
        data = torch.randn(3, 224, 224)
        target = random.randint(0, 999)
        return (data, target)

    def __len__(self):
        return self.len 


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    parser.add_argument('batch_size', nargs='?', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()

    main(args.save_every, args.total_epochs, args.batch_size)
#!/bin/bash -l
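# slurm_resnet50_LXP.sh: SLURM launcher, one torchrun instance per node on the GPU partition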
#SBATCH --job-name="resnetDistri"
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --output=%x%j.out
#SBATCH --error=%x%j.err
#SBATCH -p gpu
#SBATCH -q default
#SBATCH --time=01:10:00
#SBATCH -A lxp

module load env/staging/2023.1
module load PyTorch/2.1.2-foss-2023a-CUDA-12.1.1
module load torchvision/

# Determine the head node (first node of the allocation) and its IP address
nodes_array=($(scontrol show hostnames "$SLURM_JOB_NODELIST"))
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo "The head node is ${head_node}"

# OPTIONAL: set to true if you want more details on NCCL communications
DEBUG=true
if [ "$DEBUG" == "true" ]; then
    export LOGLEVEL=INFO
    export NCCL_DEBUG=TRACE
    export TORCH_CPP_LOG_LEVEL=INFO
else
    echo "Debug mode is off."
fi

# Define the file where PyTorch will save a snapshot so that an interrupted training can be restarted
snapshot_name="snapshot.pt"
snapshot_file="${PWD}/${snapshot_name}"

if [ -f "$snapshot_file" ]; then
    file_exists=true
    echo "snapshot file found"
else
    file_exists=false
    echo "no snapshot file was found"
fi

remove_snapshot=true
if [ "$remove_snapshot" == "true" ]; then
    if [ -f "$snapshot_file" ]; then
        rm "${snapshot_file}"
        echo "snapshot file deleted"
    fi
fi

export NCCL_SOCKET_IFNAME=ib0
export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=8

# Get a free TCP port for the torchrun rendezvous endpoint
export random_port=$(python getPort.py)

echo "rdvz-endpoint is ${head_node_ip}:${random_port}"
endpoint="${head_node_ip}:${random_port}"

export NGPUS_PER_NODE=4

# One srun task per node; on each node torchrun spawns NGPUS_PER_NODE worker processes.
# Trailing arguments of resnet50_LXP.py: total_epochs=4, save_every=1, batch_size=32
CUDA_VISIBLE_DEVICES="0,1,2,3" srun --cpus-per-task=8 --wait=60 --ntasks-per-node=1 --kill-on-bad-exit=1 \
    torchrun --max_restarts 3 \
             --nnodes ${SLURM_NNODES} \
             --nproc_per_node ${NGPUS_PER_NODE} \
             --rdzv_id 10000 \
             --rdzv_backend c10d \
             --rdzv_endpoint $endpoint \
             --log_dir ${PWD}/log_torch \
             resnet50_LXP.py 4 1 32