PyTorch

The following presents a minimal working example (MWE) of a multi-node, multi-GPU PyTorch script. It demonstrates how to combine SLURM and PyTorch for distributed training across multiple nodes and GPUs on Meluxina, and can serve as a starting point for users who want to move their PyTorch training script from serial to distributed execution.

Source code

Slurm script

#!/bin/bash -l
#SBATCH --time=01:30:00
#SBATCH --account=lxp
#SBATCH --partition=gpu
#SBATCH --qos=default
#SBATCH --nodes=2
#SBATCH --ntasks=2
#SBATCH --ntasks-per-node=1
#SBATCH --job-name=multinode_training

module load env/release/2022.1
module load PyTorch/1.12.0-foss-2022a-CUDA-11.7.0

nodes_array=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo "The head node is ${head_node}"

# OPTIONAL: set to true if you want more details on NCCL communications 
DEBUG=true
if [ "$DEBUG" == "true" ]; then
    export LOGLEVEL=INFO
    export NCCL_DEBUG=TRACE
    export TORCH_CPP_LOG_LEVEL=INFO
else
    echo "Debug mode is off."
fi

# Define the file where PyTorch will save a snapshot so that an interrupted training can be resumed
snapshot_name="snapshot.pt"
snapshot_file="${PWD}/${snapshot_name}"

if [ -f "$snapshot_file" ]; then
    file_exists=true
    echo "snapshot file found"
else
    file_exists=false
    echo "no snapshot file was found"
fi

remove_snapshot=true
if [ "$remove_snapshot" == "true" ]; then
    if [ -f "$snapshot_file" ]; then
        rm ${snapshot_file}
        echo "snapshot file deleted"
    fi
fi

export NCCL_SOCKET_IFNAME=ib0
export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=8

# get free port
export random_port=$(python getPort.py)

echo "rdvz-endpoint is ${head_node_ip}:${random_port}"
endpoint="${head_node_ip}:${random_port}"

export NGPUS_PER_NODE=4

CUDA_VISIBLE_DEVICES="0,1,2,3" srun --cpus-per-task=8 --wait=60 --ntasks-per-node=1 --kill-on-bad-exit=1 torchrun \
--max_restarts 3 \
--nnodes ${SLURM_NNODES} \
--nproc_per_node ${NGPUS_PER_NODE}  \
--rdzv_id 10000 \
--rdzv_backend c10d \
--rdzv_endpoint $endpoint \
--log_dir ${PWD}/log_torch \
multinode_multiGPU.py 10 5 3000 

Python code

This code is based on https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multinode.py, which belongs to the PyTorch DDP tutorial series (https://pytorch.org/tutorials/beginner/ddp_series_intro.html). In what follows, this file is saved as multinode_multiGPU.py.

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datautils import MyTrainDataset

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

def ddp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str,
    ) -> None:
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.model = model.to(self.local_rank)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)

        self.model = DDP(self.model, device_ids=[self.local_rank])

    def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.local_rank}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.local_rank)
            targets = targets.to(self.local_rank)
            self._run_batch(source, targets)

    def _save_snapshot(self, epoch):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
        }
        torch.save(snapshot, self.snapshot_path)
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")

    def train(self, max_epochs: int):
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            if self.local_rank == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)


def load_train_objs(dataset_size:int):
    train_set = MyTrainDataset(dataset_size)  # load your dataset
    model = torch.nn.Linear(20, 1)  # load your model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(dataset)
    )


def main(save_every: int, total_epochs: int, batch_size: int, dataset_size: int,
         snapshot_path: str = os.path.join(os.getcwd(), 'snapshot.pt')):
    ddp_setup()
    dataset, model, optimizer = load_train_objs(dataset_size)
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
    trainer.train(total_epochs)
    print('Will now destroy the process group\n')
    destroy_process_group()
    print('Process group destroyed')
    import sys
    # https://github.com/pytorch/pytorch/issues/76287
    sys.exit(0)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    parser.add_argument('dataset_size', type=int, help='Size of the training dataset')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()

    main(args.save_every, args.total_epochs, args.batch_size, args.dataset_size)

You will also need datautils.py in the same directory:

import torch
from torch.utils.data import Dataset

class MyTrainDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return self.data[index]

as well as getPort.py:

import socket

def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))  # Bind to a free port provided by the host.
        return s.getsockname()[1]  # Return the port number assigned.

if __name__ == "__main__":
    print(find_free_port())

Explanations

SLURM Directives

#!/bin/bash -l
#SBATCH --time=01:30:00
#SBATCH --account=YOURACCOUNT
#SBATCH --partition=gpu
#SBATCH --qos=default
#SBATCH --nodes=2
#SBATCH --ntasks=2
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-task=4
#SBATCH --cpus-per-task=8
#SBATCH --job-name=multinode_training
  • #!/bin/bash -l: Runs the script with Bash as a login shell; the -l flag ensures the module environment is initialised.
  • --time=01:30:00: Allocates 1 hour and 30 minutes for the job.
  • --account=YOURACCOUNT: Specifies the project account to be charged. Replace it with your own account.
  • --partition=gpu: Assigns the job to the GPU partition (required to benefit from the GPU acceleration of the CUDA-enabled PyTorch build).
  • --qos=default: Sets the Quality of Service to default.
  • --nodes=2: Requests 2 nodes.
  • --ntasks=2: Total number of tasks across all nodes. Be extra careful here: there is only one task per node, corresponding to one torchrun instance per node. If you set --ntasks=8 because 8 is the total number of GPUs involved in the training, communication issues will occur that are not straightforward to diagnose! The sketch after this list shows how to check the task layout SLURM actually provides.
  • --gpus-per-task=4: Assigns the four GPUs of a node to the single torchrun task running on it.
  • --cpus-per-task=8: Allocates 8 CPU cores per task.
  • --job-name=multinode_training: Names the job.
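
To convince yourself that SLURM really starts one task per node (see the warning on --ntasks above), a small helper along the lines of the hypothetical check_slurm_layout.py below can be launched with srun inside the allocation; it only prints environment variables that SLURM sets for every task.

import os

# Hypothetical check_slurm_layout.py: print the task layout SLURM provides.
# Launch with: srun python check_slurm_layout.py
keys = ["SLURM_JOB_ID", "SLURM_NNODES", "SLURM_NTASKS",
        "SLURM_NTASKS_PER_NODE", "SLURM_NODEID", "SLURM_PROCID",
        "SLURMD_NODENAME"]
print({k: os.environ.get(k, "<unset>") for k in keys})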

Module Loading

With the 2022 stack

module load env/release/2022.1
module load PyTorch/1.12.0-foss-2022a-CUDA-11.7.0

Loads the 2022 software environment and a PyTorch build with CUDA 11.7.0 support. If you do module load PyTorch/1.12.0-foss-2022a instead, you will not have CUDA support!

Alternative: using the 2023 stack

module load env/release/2023.1
module load PyTorch/2.1.2-foss-2023a-CUDA-12.1.1

The same remark applies: make sure to load the CUDA-enabled version of the module.
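
Whichever stack you choose, a quick check such as the one below (standard PyTorch calls only) confirms before submitting a long job that the loaded module really ships CUDA support; run it on a GPU node, for example inside an interactive allocation.

import torch

# Verify that the loaded PyTorch module was built with CUDA support.
print("PyTorch version:", torch.__version__)
print("CUDA available :", torch.cuda.is_available())
print("CUDA version   :", torch.version.cuda)       # None for a CPU-only build
print("GPUs visible   :", torch.cuda.device_count())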

Port and Endpoint Setup

nodes_array=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
export random_port=$(python getPort.py)
echo "rdvz-endpoint is ${head_node_ip}:${random_port}"
endpoint="${head_node_ip}:${random_port}"
  • Extracts the list of nodes allocated to the job and identifies the head node and its IP address.
  • Obtains a free port through the getPort.py helper and defines the rendezvous endpoint from the head node's IP address and this port; the sketch below shows how to inspect what each worker actually receives.
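
If the rendezvous misbehaves, it can help to print what each worker actually sees. The hypothetical print_rdzv_env.py below can be launched through the same srun/torchrun command as the training script; it assumes torchrun exports the usual RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT variables to its workers.

import os

# Hypothetical print_rdzv_env.py: show the distributed environment that
# torchrun passes to each worker process (one line per variable and rank).
for key in ("RANK", "LOCAL_RANK", "WORLD_SIZE", "GROUP_RANK",
            "MASTER_ADDR", "MASTER_PORT"):
    print(f"[rank {os.environ.get('RANK', '?')}] {key}={os.environ.get(key, '<unset>')}")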

Running the PyTorch Script

CUDA_VISIBLE_DEVICES="0,1,2,3" srun  --cpus-per-task=8 --wait=60 --ntasks-per-node=1 --kill-on-bad-exit=1 torchrun \
--max_restarts 3 \
--nnodes ${SLURM_NNODES} \
--nproc_per_node ${NGPUS_PER_NODE}  \
--rdzv_id 10000 \
--rdzv_backend c10d \
--rdzv_endpoint $endpoint \
--log_dir ${PWD}/log_torch \
multinode_multiGPU.py 10 5 3000 
  • CUDA_VISIBLE_DEVICES specifies which CUDA devices (GPUs) are visible to each task.
  • We use torchrun, wrapped in srun, to launch the PyTorch script multinode_multiGPU.py. The command specifies the number of nodes, the number of processes per node, the rendezvous parameters, and the logging directory. The script arguments (10 5 3000) correspond to the total number of epochs, the number of epochs between two updates of the checkpoint saved in the snapshot file, and the size of the dataset passed to the load_train_objs() function, respectively. A minimal communication sanity check is sketched below.
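
Before launching a long training, a minimal all-reduce across every GPU of every node is a convenient way to verify that NCCL communication works end to end. The hypothetical ddp_sanity_check.py below can be substituted for multinode_multiGPU.py in the srun/torchrun command above; every rank should print the same sum 0 + 1 + ... + (world_size - 1).

import os
import torch
import torch.distributed as dist

# Hypothetical ddp_sanity_check.py: each rank contributes its global rank to an
# all-reduce; every rank should end up printing the same sum.
def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    t = torch.tensor([float(rank)], device=f"cuda:{local_rank}")
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    print(f"rank {rank}/{world_size} sees sum = {t.item()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()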