A distributed training example
The following presents a minimal working example (MWE) of a multi-node, multi-GPU PyTorch script. It demonstrates how to combine SLURM with PyTorch for distributed training across GPUs on MeluXina, and can serve as a starting point for users who want to move their PyTorch training script from serial to distributed execution.
For the sake of the example, we use a ResNet-50 model from the torchvision package. Feel free to tweak the Slurm launcher and the training script to run your own tests!
Source code
To reproduce this example, create three separate files in the directory you want to work from:
- a Python snippet `getPort.py` that picks a free port on which the head node will listen to orchestrate communication (the rendezvous endpoint) during training,
- a Python script `resnet50_LXP.py` containing the training code,
- a Slurm script `slurm_resnet50_LXP.sh` that we will use to launch the training.
To test this example, submit the job with the following command: `sbatch slurm_resnet50_LXP.sh`.
import socket


def find_free_port():
    """Return the number of a TCP port that was free a moment ago.

    Binding to port 0 lets the operating system choose any unused port.
    The socket is closed before returning, so the port is only *likely*
    to still be free when the caller binds it.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with sock:
        sock.bind(('', 0))
        _host, port = sock.getsockname()
    return port


if __name__ == "__main__":
    # Print only the port so a shell can capture it with $(python getPort.py).
    print(find_free_port())
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# from datautils import MyTrainDataset
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import torch.distributed as dist
import os
from torchvision import models
def ddp_setup():
    """Join the NCCL default process group and bind this process to its GPU.

    Reads ``LOCAL_RANK`` from the environment (set by the torchrun launcher
    in this example) to pick the CUDA device.

    Returns:
        A ``(world_size, rank)`` tuple for the freshly initialised group.
    """
    init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return dist.get_world_size(), dist.get_rank()
class Trainer:
    """Run DDP training of ``model`` in one process, with periodic snapshots.

    ``LOCAL_RANK`` and ``RANK`` are read from the environment. The model is
    moved to the local GPU, restored from ``snapshot_path`` if that file
    exists, and then wrapped in ``DistributedDataParallel``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str,
    ) -> None:
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.model = model.to(self.local_rank)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        # Restore before wrapping in DDP: snapshots are saved from the bare
        # module (``self.model.module``), so the state-dict keys match here.
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)
        self.model = DDP(self.model, device_ids=[self.local_rank])

    def _load_snapshot(self, snapshot_path):
        """Load model weights and the number of completed epochs onto this GPU."""
        loc = f"cuda:{self.local_rank}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets):
        """One optimisation step on a single batch (cross-entropy loss)."""
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        """Train for one full pass over this rank's shard of the data."""
        # Read the configured batch size rather than pulling a batch: the old
        # next(iter(...)) built and discarded a whole extra DataLoader iterator
        # (and, with num_workers > 0, a fresh set of worker processes) per epoch.
        b_sz = self.train_data.batch_size
        print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        # Tell the (distributed) sampler the epoch so a shuffling sampler
        # produces a different order each epoch.
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.local_rank)
            targets = targets.to(self.local_rank)
            self._run_batch(source, targets)

    def _save_snapshot(self, epoch):
        """Atomically write the model weights after finishing ``epoch``."""
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            # Number of *completed* epochs. After epoch `epoch` finishes, a
            # resumed run must start at `epoch + 1` instead of repeating it.
            "EPOCHS_RUN": epoch + 1,
        }
        # Write to a temporary file, then rename it over the target, so an
        # interrupted save never leaves a truncated snapshot for a restart.
        tmp_path = f"{self.snapshot_path}.tmp"
        torch.save(snapshot, tmp_path)
        os.replace(tmp_path, self.snapshot_path)
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")

    def train(self, max_epochs: int):
        """Train from ``self.epochs_run`` up to (excluding) ``max_epochs``.

        A snapshot is saved after every ``save_every``-th epoch. A
        ``save_every`` of 0 or less disables snapshots.
        """
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            # Only global rank 0 writes. With local rank 0, one process per
            # node would write the same path concurrently. This assumes
            # snapshot_path is on a filesystem shared by all nodes.
            if (self.global_rank == 0 and self.save_every > 0
                    and epoch % self.save_every == 0):
                self._save_snapshot(epoch)
def load_train_objs(num_epochs, batch_size_per_gpu, world_size):
    """Create the synthetic training set, an untrained ResNet-50 and its optimizer.

    Returns:
        ``(train_set, model, optimizer)``, where the optimizer is plain SGD
        (lr=1e-3) over all model parameters.
    """
    train_set = SyntheticDataset(num_epochs, batch_size_per_gpu, world_size)
    # `weights=None` is the current torchvision spelling for random
    # initialisation; the `pretrained` keyword is deprecated since 0.13.
    model = models.resnet50(weights=None)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer
def prepare_dataloader(dataset: Dataset, batch_size: int, num_workers: int = 32):
    """Build a DataLoader that shards ``dataset`` across the DDP ranks.

    A default process group must already be initialised, because
    ``DistributedSampler`` infers the rank and world size from it.

    Args:
        dataset: map-style dataset to load from.
        batch_size: number of samples per batch on this rank.
        num_workers: DataLoader worker processes per rank. Defaults to the
            previously hard-coded 32. Lower it when fewer CPUs than that are
            available per rank, to avoid oversubscription.

    Returns:
        A pinned-memory DataLoader driven by a ``DistributedSampler``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        # Must stay False: the sampler, not the DataLoader, decides the order
        # (DataLoader rejects shuffle=True together with a sampler).
        shuffle=False,
        sampler=DistributedSampler(dataset),
        num_workers=num_workers,
    )
def main(save_every: int,
         total_epochs: int,
         batch_size: int,
         snapshot_path: str = os.path.join(os.getcwd(), 'snapshot.pt')):
    """Run the distributed ResNet-50 training in this process, then exit.

    Args:
        save_every: save a snapshot every ``save_every`` epochs.
        total_epochs: number of epochs to train for.
        batch_size: per-device batch size given to the DataLoader.
        snapshot_path: snapshot file to resume from / write to. Defaults to
            ``snapshot.pt`` in the working directory at import time.

    Exits the process with status 0 after destroying the process group.
    """
    world_size, _rank = ddp_setup()
    # Number of processes on this node. torchrun exports LOCAL_WORLD_SIZE;
    # this used to be hard-coded to 4, which is kept as the fallback.
    gpus_per_node = int(os.environ.get("LOCAL_WORLD_SIZE", 4))
    # Only used to size the synthetic dataset. Clamp to 1 so a batch_size
    # smaller than gpus_per_node cannot produce an empty dataset (zero steps).
    batch_size_per_gpu = max(1, batch_size // gpus_per_node)
    dataset, model, optimizer = load_train_objs(total_epochs, batch_size_per_gpu, world_size)
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
    trainer.train(total_epochs)
    print('Will now destroy the process group\n')
    destroy_process_group()
    print('Process group destroyed')
    import sys
    # Explicit exit after teardown, as a workaround for
    # https://github.com/pytorch/pytorch/issues/76287
    sys.exit(0)
import random
class SyntheticDataset(Dataset):
    """Random 3x224x224 images with integer labels in [0, 999].

    The length is ``num_epochs * batch_size_per_gpu * world_size``. Samples
    are freshly drawn on every access, and ``idx`` is ignored.
    """

    def __init__(self, num_epochs, batch_size_per_gpu, world_size):
        self.num_epochs = num_epochs
        self.batch_size_per_gpu = batch_size_per_gpu
        # world_size is kept under the historical attribute name `num_iters`.
        self.num_iters = world_size
        self.len = num_epochs * batch_size_per_gpu * world_size

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        image = torch.randn(3, 224, 224)
        label = random.randint(0, 999)
        return image, label
def build_arg_parser():
    """Return the command-line parser for this training script.

    Positional arguments: ``total_epochs``, ``save_every`` and an optional
    ``batch_size`` (per-device batch size, default 32).
    """
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    # nargs='?' makes the positional optional. Without it argparse ignores
    # `default`, so batch_size was always required despite the help text.
    parser.add_argument('batch_size', nargs='?', default=32, type=int, help='Input batch size on each device (default: 32)')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(args.save_every, args.total_epochs, args.batch_size)
#!/bin/bash -l
# Slurm request: 4 nodes with one launcher task each, 1h10 of wall time.
# Log files are named <job-name><job-id>.out / .err (%x = name, %j = id).
#SBATCH --job-name="restnetDistri"
#SBATCH --account=lxp
#SBATCH --partition=gpu
#SBATCH --qos=default
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --time=01:10:00
#SBATCH --output=%x%j.out
#SBATCH --error=%x%j.err

# Software environment
module load env/staging/2023.1
module load PyTorch/2.1.2-foss-2023a-CUDA-12.1.1
module load torchvision/
# Use the first allocated node as the rendezvous (head) node.
nodes=($(scontrol show hostnames "$SLURM_JOB_NODELIST"))
head_node=${nodes[0]}
# `hostname --ip-address` may print several addresses; keep only the first so
# the rendezvous endpoint is a single host:port.
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address | awk '{print $1}')
echo "The head node is ${head_node}"
# OPTIONAL: keep DEBUG=true for detailed torchrun, NCCL and c10d logging
# (NCCL_DEBUG=TRACE is NCCL's most detailed level); any other value disables it.
DEBUG=true
if [ "$DEBUG" == "true" ]; then
    export LOGLEVEL=INFO
    export NCCL_DEBUG=TRACE
    # PyTorch reads TORCH_CPP_LOG_LEVEL; the former spelling
    # TORCH_CPP_LG_LEVEL was silently ignored.
    export TORCH_CPP_LOG_LEVEL=INFO
else
    echo "Debug mode is off."
fi
# File where the training script saves a snapshot, so an interrupted training
# can be restarted from it.
snapshot_name="snapshot.pt"
snapshot_file="${PWD}/${snapshot_name}"
if [ -f "$snapshot_file" ]; then
    file_exists=true
    echo "snapshot file found"
else
    file_exists=false
    echo "no snapshot file was found"
fi
# Set to false to resume from an existing snapshot instead of starting over.
remove_snapshot=true
if [ "$remove_snapshot" == "true" ] && [ -f "$snapshot_file" ]; then
    # Quoted: an unquoted path containing spaces would be split into
    # several rm arguments.
    rm -- "$snapshot_file"
    echo "snapshot file deleted"
fi
# Network interface NCCL uses for its socket-based traffic.
export NCCL_SOCKET_IFNAME=ib0
# Let PyTorch abort on NCCL errors/timeouts rather than hang. Newer PyTorch
# releases read the TORCH_-prefixed name (the unprefixed one is deprecated),
# so both are set.
export NCCL_ASYNC_ERROR_HANDLING=1
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=8
# Get a free port. It is probed on the node running this batch script,
# presumably the head node (first node of the allocation).
export random_port=$(python getPort.py)
endpoint="${head_node_ip}:${random_port}"
echo "rdvz-endpoint is ${endpoint}"
export NGPUS_PER_NODE=4
# One torchrun launcher per node, each spawning NGPUS_PER_NODE workers. The
# rendezvous id is the Slurm job id, so different jobs never share a group.
CUDA_VISIBLE_DEVICES="0,1,2,3" srun --cpus-per-task=8 --wait=60 --ntasks-per-node=1 --kill-on-bad-exit=1 \
    torchrun --max_restarts 3 --nnodes "${SLURM_NNODES}" --nproc_per_node "${NGPUS_PER_NODE}" \
    --rdzv_id "${SLURM_JOB_ID}" --rdzv_backend c10d --rdzv_endpoint "$endpoint" \
    --log_dir "${PWD}/log_torch" resnet50_LXP.py 4 1 32