Multi-node/Multi-GPU fine-tuning with Ray/Deepspeed
What is Ray?
Ray is an open-source distributed computing framework that makes it easy to scale Python applications — especially those related to machine learning, data processing, and deep learning. It's built to enable high-performance parallel and distributed computing without needing to dive into the complexities of cluster setup and management.
Ray provides a unified API to scale Python code from a single laptop to a large cluster. Under the hood, it manages distributed task execution and memory efficiently.
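For example, turning an ordinary Python function into a distributed task only takes a decorator. The snippet below is a minimal sketch (square is a toy function used purely for illustration):
import ray

ray.init()  # start Ray locally, or connect to an existing cluster

@ray.remote
def square(x):
    return x * x

# Tasks run in parallel across the available workers;
# ray.get() blocks until the results are ready
futures = [square.remote(i) for i in range(4)]
print(ray.get(futures))  # [0, 1, 4, 9]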
Ray has several high-level libraries built on top of it, like:

- Ray Tune – for hyperparameter tuning
- Ray Train – for distributed model training
- Ray Serve – for model deployment and serving
- Ray Data – for scalable data loading and processing
We are interested in Ray Train and its integration with Hugging Face, PyTorch, and DeepSpeed (see the sketch after this list), as it can:

- distribute training across GPUs/nodes,
- preprocess large datasets in parallel,
- tune hyperparameters efficiently,
- handle failures during long-running jobs.
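The snippet below is a minimal sketch of the core Ray Train pattern, not the full Meluxina setup: train_func and the worker count are placeholders, and the complete GPT-J-6B version is developed in the rest of this page.
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_func(config):
    # Per-worker training loop: build the model, data loader, and optimizer
    # here; Ray Train sets up the distributed process group for you.
    ...

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()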
Running on Meluxina GPU
We are going to fine-tune GPT-J-6B on the tiny_shakespeare dataset, based on the GPT-J-6B Fine-Tuning with Ray Train and DeepSpeed example provided by Ray. Save the following launcher script as start-ray-cluster.sh:
#!/bin/bash -l
#SBATCH -A <ACCOUNT>
#SBATCH -q default
#SBATCH -p gpu
#SBATCH -t 24:0:0
#SBATCH -N 4
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=128
#SBATCH --gpus-per-task=4
#SBATCH --error="ray-%j.err"
#SBATCH --output="ray-%j.out"
module load env/release/2024.1
module load CUDA
module load Python
# Cache Hugging Face downloads inside the job directory
export HF_HOME="./huggingface"
export TRANSFORMERS_CACHE=${HF_HOME}
export HUGGING_FACE_HUB_CACHE=${HF_HOME}
# Token used to authenticate against the Hugging Face Hub
export HUGGING_FACE_HUB_TOKEN=<your_token>
export GPUS_PER_TASK=4
# Create the virtual environment on first use and install the dependencies
RAY_ENV="ray-env"
if [[ ! -d "${RAY_ENV}" ]]; then
    python -m venv ${RAY_ENV}
    source ${RAY_ENV}/bin/activate
    pip install -U "ray[data,train,tune,serve]"
    # Packages required by the fine-tuning script below
    pip install torch transformers datasets evaluate deepspeed pandas numpy
else
    source ${RAY_ENV}/bin/activate
fi
export HEAD_HOSTNAME="$(hostname)"
export HEAD_IPADDRESS="$(hostname --ip-address)"
echo "HEAD NODE: ${HEAD_HOSTNAME}"
echo "IP ADDRESS: ${HEAD_IPADDRESS}"
echo "SSH TUNNEL (Execute on your local machine): ssh -p 8822 ${USER}@login.lxp.lu -NL 8000:${HEAD_IPADDRESS}:8000"
# We need to get an available random port
export RANDOM_PORT=$(python3 -c 'import socket; s = socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()')
# Command to start the head node
export RAY_CMD_HEAD="ray start --block --head --port=${RANDOM_PORT} --num-cpus ${SLURM_CPUS_PER_TASK} --num-gpus ${GPUS_PER_TASK} --include-dashboard True --dashboard-host 0.0.0.0 "
# Command to start workers
export RAY_CMD_WORKER="ray start --block --address=${HEAD_IPADDRESS}:${RANDOM_PORT} --num-cpus ${SLURM_CPUS_PER_TASK} --num-gpus ${GPUS_PER_TASK} "
# Start head node
echo "Starting head node"
srun -J "head ray node-step-%J" -N 1 --ntasks-per-node=1 -c ${SLURM_CPUS_PER_TASK} --gpus-per-task=${GPUS_PER_TASK} -w ${HEAD_HOSTNAME} ${RAY_CMD_HEAD} &
sleep 10
echo "Starting worker node"
srun -J "worker ray node-step-%J" -N $(( SLURM_NNODES-1 )) -n $(( SLURM_NNODES-1 )) --ntasks-per-node=1 -c ${SLURM_CPUS_PER_TASK} --gpus-per-task=${GPUS_PER_TASK} -x ${HEAD_HOSTNAME} ${RAY_CMD_WORKER} &
# Give the workers time to register with the head node
sleep 30
echo "Start training"
python "$@"
import os
from pathlib import Path

import numpy as np
import pandas as pd
import evaluate
import torch
import ray
import ray.data
from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    GPTJForCausalLM,
    AutoTokenizer,
    default_data_collator,
)
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
from ray import train
from ray.train.huggingface.transformers import prepare_trainer, RayTrainReportCallback
from ray.train.torch import TorchTrainer
from ray.train import RunConfig, ScalingConfig
ray.init()
model_name = "EleutherAI/gpt-j-6B"
use_gpu = True
num_workers = 16  # 4 nodes x 4 GPUs, matching the Slurm allocation above
cpus_per_worker = 8
block_size = 512
print("Loading tiny_shakespeare dataset")
current_dataset = load_dataset("tiny_shakespeare")
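# Note: recent versions of the datasets library may require passing
# trust_remote_code=True to load_dataset() for script-based datasets
# such as tiny_shakespeare.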
ray_datasets = {
    "train": ray.data.from_huggingface(current_dataset["train"]),
    "validation": ray.data.from_huggingface(current_dataset["validation"]),
}
# Split the flattened text into individual non-empty lines, dropping
# speaker labels (lines ending with ":")
def split_text(batch: pd.DataFrame) -> pd.DataFrame:
    text = list(batch["text"])
    flat_text = "".join(text)
    split_text = [
        x.strip()
        for x in flat_text.split("\n")
        if x.strip() and not x.strip()[-1] == ":"
    ]
    return pd.DataFrame(split_text, columns=["text"])
# Tokenize each line into fixed-length blocks; for causal language modeling
# the labels are simply a copy of input_ids
def tokenize(batch: pd.DataFrame) -> dict:
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    tokenizer.pad_token = tokenizer.eos_token
    ret = tokenizer(
        list(batch["text"]),
        truncation=True,
        max_length=block_size,
        padding="max_length",
        return_tensors="np",
    )
    ret["labels"] = ret["input_ids"].copy()
    return dict(ret)
processed_datasets = {
    key: (
        ds.map_batches(split_text, batch_format="pandas")
        .map_batches(tokenize, batch_format="pandas")
    )
    for key, ds in ray_datasets.items()
}
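# Note: Ray Datasets are lazy; the map_batches() transformations above only
# run once the trainer starts consuming the data shards during training.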
def train_func(config):
    # Use the actual number of CPUs assigned by Ray
    os.environ["OMP_NUM_THREADS"] = str(
        train.get_context().get_trial_resources().bundles[-1].get("CPU", 1)
    )
    os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] = os.environ.get("LOCAL_RANK", "0")
    # Enable TF32 for better performance
    torch.backends.cuda.matmul.allow_tf32 = True

    batch_size = config.get("batch_size", 4)
    epochs = config.get("epochs", 2)
    warmup_steps = config.get("warmup_steps", 0)
    learning_rate = config.get("learning_rate", 0.00002)
    weight_decay = config.get("weight_decay", 0.01)
    steps_per_epoch = config.get("steps_per_epoch")
    deepspeed = {
        "fp16": {
            "enabled": "auto",
            "initial_scale_power": 8,
            "hysteresis": 4,
            "consecutive_hysteresis": True,
            "fp16_scale_tolerance": 0.25,
        },
        "bf16": {"enabled": "auto"},
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": "auto",
                "betas": "auto",
                "eps": "auto",
            },
        },
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True,
            },
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
            "gather_16bit_weights_on_model_save": True,
            "round_robin_gradients": True,
        },
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "steps_per_print": 10,
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "wall_clock_breakdown": False,
    }
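    # The "auto" values above are filled in by the transformers Trainer from
    # TrainingArguments. ZeRO stage 3 shards parameters, gradients, and
    # optimizer states across workers, and offload_optimizer keeps optimizer
    # state in CPU memory so the 6B-parameter model fits on each GPU.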
print("Preparing training arguments")
training_args = TrainingArguments(
"output",
logging_steps=1,
save_strategy="steps",
save_steps=steps_per_epoch,
max_steps=steps_per_epoch * epochs,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=1,
learning_rate=learning_rate,
weight_decay=weight_decay,
warmup_steps=warmup_steps,
label_names=["input_ids", "attention_mask"],
push_to_hub=False,
report_to="none",
disable_tqdm=True, # declutter the output a little
fp16=True,
gradient_checkpointing=True,
deepspeed=deepspeed,
)
    disable_progress_bar()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    print("Loading model")
    model = GPTJForCausalLM.from_pretrained(model_name, use_cache=False)
    model.resize_token_embeddings(len(tokenizer))
    print("Model loaded")

    enable_progress_bar()

    metric = evaluate.load("accuracy")

    train_ds = train.get_dataset_shard("train")
    eval_ds = train.get_dataset_shard("validation")

    train_ds_iterable = train_ds.iter_torch_batches(
        batch_size=batch_size,
        local_shuffle_buffer_size=train.get_context().get_world_size() * batch_size,
    )
    eval_ds_iterable = eval_ds.iter_torch_batches(batch_size=batch_size)

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return metric.compute(predictions=predictions, references=labels)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds_iterable,
        eval_dataset=eval_ds_iterable,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )

    # Add callback to report checkpoints to Ray Train
    trainer.add_callback(RayTrainReportCallback())
    trainer = prepare_trainer(trainer)
    trainer.train()
if __name__ == "__main__":
    batch_size = 16
    train_ds_size = processed_datasets["train"].count()
    steps_per_epoch = train_ds_size // (batch_size * num_workers)

    trainer = TorchTrainer(
        train_loop_per_worker=train_func,
        train_loop_config={
            "epochs": 1,
            "batch_size": batch_size,  # per device
            "steps_per_epoch": steps_per_epoch,
        },
        scaling_config=ScalingConfig(
            num_workers=num_workers,
            use_gpu=use_gpu,
            resources_per_worker={"GPU": 1, "CPU": cpus_per_worker},
        ),
        datasets=processed_datasets,
        run_config=RunConfig(
            storage_path=str(Path("./results").resolve()), name="test_experiment"
        ),
    )
    results = trainer.fit()
To start fine-tuning the model, execute the following command on a login node:
sbatch start-ray-cluster.sh fine-tuning.py
Accessing the dashboard
The Ray Dashboard is a web-based UI that lets you monitor and manage your Ray cluster and applications in real time. It’s super handy when you're training large models or running distributed workloads.
To access the dashboard, you first need to create an SSH tunnel. Retrieve the exact command to execute on your local machine with:
grep "SSH TUNNEL" ray-<jobid>.out
Once the SSH tunnel has been created, open the following address in your browser: http://localhost:8265
Test the fine-tuned model
Save the following Python script as test_model.py:
from transformers import pipeline, AutoTokenizer, GPTJForCausalLM
import torch
# Path to the checkpoint written by RayTrainReportCallback during training
local_storage = "./results/test_experiment/<replace with your path>/checkpoint/"
model = GPTJForCausalLM.from_pretrained(local_storage)
tokenizer = AutoTokenizer.from_pretrained(local_storage)
pipe = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Generate from prompts!
for sentence in pipe(
    ["Romeo and Juliet", "Romeo", "Juliet"], do_sample=True, min_length=20
):
    print(sentence)
Execute the following command to generate text from the prompts:
python test_model.py