
Multi-node/Multi-GPU fine-tuning with Ray/DeepSpeed

What is Ray?

Ray is an open-source distributed computing framework that makes it easy to scale Python applications — especially those related to machine learning, data processing, and deep learning. It's built to enable high-performance parallel and distributed computing without needing to dive into the complexities of cluster setup and management.

Ray provides a unified API to scale Python code from a single laptop to a large cluster. Under the hood, it manages distributed task execution and memory efficiently.
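
As a quick illustration of the programming model, the sketch below turns an ordinary Python function into a distributed task with the @ray.remote decorator and collects the results with ray.get(). The function name square and the standalone ray.init() call are only illustrative; they are not part of the fine-tuning job described below.

import ray

# Start Ray locally; on a cluster, ray.init() attaches to the running head node instead.
ray.init()

# Any Python function becomes a distributed task with the @ray.remote decorator.
@ray.remote
def square(x: int) -> int:
    return x * x

# .remote() schedules the tasks on the cluster and returns futures;
# ray.get() blocks until the results are available.
futures = [square.remote(i) for i in range(8)]
print(ray.get(futures))  # [0, 1, 4, 9, 16, 25, 36, 49]

ray.shutdown()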

Ray has several high-level libraries built on top of it, such as:

  • Ray Tune – for hyperparameter tuning

  • Ray Train – for distributed model training

  • Ray Serve – for model deployment and serving

  • Ray Data – for scalable data loading and processing

We are interested in Ray Train and its integration with Hugging Face, PyTorch, and DeepSpeed (a minimal sketch follows this list), as it can

  • distribute training across GPUs/nodes.

  • preprocess large datasets in parallel.

  • tune hyperparameters efficiently.

  • handle failures during long-running jobs.
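
Before diving into the full Meluxina example, here is a minimal sketch of the Ray Train API shape: a per-worker training function handed to a TorchTrainer together with a ScalingConfig. The toy linear model, the random batches, and the two-worker configuration are placeholders for illustration only; the actual GPT-J job further below uses 16 GPU workers across 4 nodes.

import torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer, prepare_model


def train_loop_per_worker(config):
    # Each Ray Train worker executes this function.
    # prepare_model() moves the model to the worker's device and wraps it for
    # distributed data-parallel training.
    model = prepare_model(torch.nn.Linear(10, 1))  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    for _ in range(config["epochs"]):
        x = torch.randn(32, 10)                    # placeholder batch
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"lr": 1e-3, "epochs": 1},
    # Two GPU workers here; the real job below scales this to 16 workers on 4 nodes.
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)
result = trainer.fit()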

Running on Meluxina GPU

We are going to fine-tune GPT-J-6B on the tiny_shakespeare dataset, based on the GPT-J-6B Fine-Tuning with Ray Train and DeepSpeed example provided by Ray. Save the following Slurm batch script as start-ray-cluster.sh:

#!/bin/bash -l

#SBATCH -A <ACCOUNT>
#SBATCH -q default
#SBATCH -p gpu
#SBATCH -t 24:0:0
#SBATCH -N 4
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=128
#SBATCH --gpus-per-task=4
#SBATCH --error="ray-%j.err"
#SBATCH --output="ray-%j.out"

module load env/release/2024.1
module load CUDA
module load Python


export HF_HOME="./huggingface"
export TRANSFORMERS_CACHE=${HF_HOME}
export HUGGING_FACE_HUB_CACHE=${HF_HOME}
export HUGGING_FACE_HUB_TOKEN=<your_token>

export GPUS_PER_TASK=4

RAY_ENV="ray-env"

# Create the virtual environment and install the dependencies on first use,
# otherwise simply activate it
if [[ ! -d "${RAY_ENV}" ]]; then
    python -m venv ${RAY_ENV}
    source ${RAY_ENV}/bin/activate
    pip install -U "ray[data,train,tune,serve]"
    # fine-tuning.py additionally requires these packages
    pip install -U torch transformers datasets evaluate deepspeed accelerate
else
    source ${RAY_ENV}/bin/activate
fi

export HEAD_HOSTNAME="$(hostname)"
export HEAD_IPADDRESS="$(hostname --ip-address)"

echo "HEAD NODE: ${HEAD_HOSTNAME}"
echo "IP ADDRESS: ${HEAD_IPADDRESS}"
echo "SSH TUNNEL (Execute on your local machine): ssh -p 8822 ${USER}@login.lxp.lu  -NL 8000:${HEAD_IPADDRESS}:8000"  

# We need to get an available random port
export RANDOM_PORT=$(python3 -c 'import socket; s = socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()')

# Command to start the head node
export RAY_CMD_HEAD="ray start --block --head --port=${RANDOM_PORT} --num-cpus ${SLURM_CPUS_PER_TASK} --num-gpus ${GPUS_PER_TASK} --include-dashboard True --dashboard-host 0.0.0.0 "
# Command to start workers
export RAY_CMD_WORKER="ray start --block --address=${HEAD_IPADDRESS}:${RANDOM_PORT} --num-cpus ${SLURM_CPUS_PER_TASK} --num-gpus ${GPUS_PER_TASK} "


# Start head node
echo "Starting head node"
srun -J "head ray node-step-%J" -N 1 --ntasks-per-node=1  -c ${SLURM_CPUS_PER_TASK} --gpus-per-task=${GPUS_PER_TASK} -w ${HEAD_HOSTNAME} ${RAY_CMD_HEAD} &
sleep 10
echo "Starting worker node"
srun -J "worker ray node-step-%J" -N $(( SLURM_NNODES-1 )) -n $(( SLURM_NNODES-1 )) --ntasks-per-node=1 -c ${SLURM_CPUS_PER_TASK} --gpus-per-task=${GPUS_PER_TASK} -x ${HEAD_HOSTNAME} ${RAY_CMD_WORKER} &
#sleep 30
# Start server on head to serve the model
#echo "Start training"
python "$@"
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import evaluate
import ray
import ray.data
from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    GPTJForCausalLM,
    AutoTokenizer,
    default_data_collator,
)
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
from ray import train
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.huggingface.transformers import prepare_trainer, RayTrainReportCallback

ray.init()

model_name = "EleutherAI/gpt-j-6B"
use_gpu = True
num_workers = 16
cpus_per_worker = 8
block_size = 512

print("Loading tiny_shakespeare dataset")
current_dataset = load_dataset("tiny_shakespeare")


ray_datasets = {
    "train": ray.data.from_huggingface(current_dataset["train"]),
    "validation": ray.data.from_huggingface(current_dataset["validation"]),
}


def split_text(batch: pd.DataFrame) -> pd.DataFrame:
    text = list(batch["text"])
    flat_text = "".join(text)
    split_text = [
        x.strip()
        for x in flat_text.split("\n")
        if x.strip() and not x.strip()[-1] == ":"
    ]
    return pd.DataFrame(split_text, columns=["text"])


def tokenize(batch: pd.DataFrame) -> dict:
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    tokenizer.pad_token = tokenizer.eos_token
    ret = tokenizer(
        list(batch["text"]),
        truncation=True,
        max_length=block_size,
        padding="max_length",
        return_tensors="np",
    )
    ret["labels"] = ret["input_ids"].copy()
    return dict(ret)


processed_datasets = {
    key: (
        ds.map_batches(split_text, batch_format="pandas")
        .map_batches(tokenize, batch_format="pandas")
    )
    for key, ds in ray_datasets.items()
}

def train_func(config):
    # Use the actual number of CPUs assigned by Ray
    os.environ["OMP_NUM_THREADS"] = str(
        train.get_context().get_trial_resources().bundles[-1].get("CPU", 1)
    )
    # DeepSpeed reads the local rank from the OpenMPI variable; default to 0 if unset
    os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] = os.environ.get("LOCAL_RANK", "0")
    # Enable tf32 for better performance
    torch.backends.cuda.matmul.allow_tf32 = True

    batch_size = config.get("batch_size", 4)
    epochs = config.get("epochs", 2)
    warmup_steps = config.get("warmup_steps", 0)
    learning_rate = config.get("learning_rate", 0.00002)
    weight_decay = config.get("weight_decay", 0.01)
    steps_per_epoch = config.get("steps_per_epoch")

    deepspeed = {
        "fp16": {
            "enabled": "auto",
            "initial_scale_power": 8,
            "hysteresis": 4,
            "consecutive_hysteresis": True,
            "fp16_scale_tolerance": 0.25
        },
        "bf16": {"enabled": "auto"},
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": "auto",
                "betas": "auto",
                "eps": "auto",
            },
        },
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True,
            },
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
            "gather_16bit_weights_on_model_save": True,
            "round_robin_gradients": True,
        },
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "steps_per_print": 10,
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "wall_clock_breakdown": False,
    }

    print("Preparing training arguments")
    training_args = TrainingArguments(
        "output",
        logging_steps=1,
        save_strategy="steps",
        save_steps=steps_per_epoch,
        max_steps=steps_per_epoch * epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=1,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        label_names=["input_ids", "attention_mask"],
        push_to_hub=False,
        report_to="none",
        disable_tqdm=True,  # declutter the output a little
        fp16=True,
        gradient_checkpointing=True,
        deepspeed=deepspeed,
    )
    disable_progress_bar()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    print("Loading model")

    model = GPTJForCausalLM.from_pretrained(model_name, use_cache=False)
    model.resize_token_embeddings(len(tokenizer))

    print("Model loaded")

    enable_progress_bar()

    metric = evaluate.load("accuracy")

    train_ds = train.get_dataset_shard("train")
    eval_ds = train.get_dataset_shard("validation")

    train_ds_iterable = train_ds.iter_torch_batches(
        batch_size=batch_size,
        local_shuffle_buffer_size=train.get_context().get_world_size() * batch_size,
    )
    eval_ds_iterable = eval_ds.iter_torch_batches(batch_size=batch_size)

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return metric.compute(predictions=predictions, references=labels)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds_iterable,
        eval_dataset=eval_ds_iterable,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )

    # Add callback to report checkpoints to Ray Train
    trainer.add_callback(RayTrainReportCallback())
    trainer = prepare_trainer(trainer)
    trainer.train()

if __name__ == "__main__":
    batch_size = 16
    train_ds_size = processed_datasets["train"].count()
    steps_per_epoch = train_ds_size // (batch_size * num_workers)
    trainer = TorchTrainer(
        train_loop_per_worker=train_func,
        train_loop_config={
            "epochs": 1,
            "batch_size": batch_size,  # per device
            "steps_per_epoch": steps_per_epoch,
        },
        scaling_config=ScalingConfig(
            num_workers=num_workers,
            use_gpu=use_gpu,
            resources_per_worker={"GPU": 1, "CPU": cpus_per_worker},
        ),
        datasets=processed_datasets,
        run_config=RunConfig(
            storage_path=str(Path("./results").resolve()), name="test_experiment"
        ),
    )
    results = trainer.fit()
  • To start fine-tuning the model, execute the following command on a login node: sbatch start-ray-cluster.sh fine-tuning.py

Accessing the dashboard

The Ray Dashboard is a web-based UI that lets you monitor and manage your Ray cluster and applications in real time. It’s super handy when you're training large models or running distributed workloads.

  • To access the dashboard, you will need to create an SSH tunnel

  • Use the following command to retrieve the SSH tunnel command you need to execute on your laptop: grep "SSH TUNNEL" ray-<jobid>.out

  • Once the SSH tunnel has been created, open the following address in your browser: http://localhost:8265

Test the fine-tuned model

  • Save the following Python script as test_model.py
from transformers import pipeline, AutoTokenizer, GPTJForCausalLM
import torch

local_storage = "./results/test_experiment/<replace with your path>/checkpoint/"
model = GPTJForCausalLM.from_pretrained(local_storage)
tokenizer = AutoTokenizer.from_pretrained(local_storage)

pipe = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Generate from prompts!
for sentence in pipe(
    ["Romeo and Juliet", "Romeo", "Juliet"], do_sample=True, min_length=20
):
    print(sentence)
  • Execute the following command to generate text from the prompts: python test_model.py