Multi-Node & Multi-GPU Inference with vLLM
Running Llama 3.1 – 405B – FP8 on MeluXina
Overview
This tutorial demonstrates how to run a very large LLM (Llama 3.1 405B FP8) model on multiple GPU nodes using vLLM on MeluXina.
Learning Objectives
You will learn how to:
- Use tensor parallelism and pipeline parallelism to serve a model that does not fit on a single GPU or on a single node with 4 GPUs.
- Set up a SLURM launcher that creates a distributed Ray cluster for vLLM
- Start an inference server for Llama 3.1 - 405B - FP8
- Query the inference server from a remote machine (e.g., your laptop) using SSH port forwarding with curl.
Estimated time: ~30 minutes
1. vLLM and Hugging Face Access
1.1. What is vLLM?
vLLM is an efficient and highly flexible library for serving large language models (LLMs).
It is optimized for high throughput and low latency, enabling fast and scalable inference across a wide range of models. Built on optimization techniques such as continuous batching and memory-efficient KV-cache management, vLLM can serve even very large models with minimal resource overhead. This makes it well suited for production deployments where speed and efficiency are crucial.
vLLM also supports various model architectures and frameworks, making it versatile for a wide array of applications, from natural language processing to machine translation and beyond.
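As a first taste of the library (independent of the multi-node setup below), vLLM also provides a simple offline Python API. Below is a minimal single-GPU sketch; it assumes vLLM is installed locally and uses a small model (facebook/opt-125m) purely for illustration:
# Minimal offline inference with vLLM's Python API (single GPU, small model for illustration).
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")                      # downloads a small model from Hugging Face
params = SamplingParams(temperature=0.8, max_tokens=32)   # basic sampling settings
outputs = llm.generate(["The capital of Luxembourg is"], params)
print(outputs[0].outputs[0].text)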
1.2. Hugging Face access token
To download the model weights, we need a Hugging Face (HF) access token.
Steps to generate a token:
- If not already done, create a profile on Hugging Face.
- Once your profile is created, go to the "Settings > Access Tokens" page to generate a token.
- Click on “New token” and select Read as the Type. For more information, see the Hugging Face documentation.
- Copy the token and save it in a safe place (for example, in a password manager).
- In your interactive session, set an environment variable:
export MYHFTOKEN=hf_ ... # paste the token content here
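Optionally, you can verify that the token is valid before going further. Here is a small sketch, assuming Python and the huggingface_hub package are available:
# Verify the Hugging Face token stored in MYHFTOKEN.
import os
from huggingface_hub import HfApi

token = os.environ["MYHFTOKEN"]
user = HfApi().whoami(token=token)   # raises an error if the token is invalid
print("Token OK, authenticated as:", user["name"])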
1.3. Request access to the model
⚠️ Before moving on, you must request access to the model used here on Hugging Face and wait until access is granted.
Without approval, you will not be able to download the model. This can take up to a couple of hours.
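Once your request has been approved, you can check programmatically that your token can access the gated repository. A sketch, again assuming the huggingface_hub package:
# Check that the token has been granted access to the gated model repository.
import os
from huggingface_hub import HfApi

# Raises an HTTP 401/403 error if access has not been granted yet.
HfApi().model_info("meta-llama/Llama-3.1-405B-FP8", token=os.environ["MYHFTOKEN"])
print("Access to meta-llama/Llama-3.1-405B-FP8 granted.")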
2. Llama 3.1 – 405B – FP8 and Memory Planning
For this tutorial, we use the FP8 version of the Llama 3.1 405B model.
- Our A100 GPU cards do not natively support FP8 computation, but FP8 quantization is still used through weight-only FP8 compression, leveraging the Marlin kernel.
- This may slightly degrade performance for compute-heavy workloads, but it reduces the number of GPUs needed to run the model.
Beginner's note
Neural networks use a lot of numbers (weights, activations).
Those numbers are usually stored in different floating-point formats, e.g.:
- FP32 → 32 bits (4 bytes) per number (high precision, large memory)
- FP16 → 16 bits (2 bytes) per number
- FP8 → 8 bits (1 byte) per number
The Benefit: Moving from FP16 to FP8 halves the memory needed for the weights, which is critical for fitting very large models (like Llama 3.1 405B) onto fewer GPUs.
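To make this concrete, here is a quick back-of-the-envelope calculation of the weight memory for a 405B-parameter model (weights only, ignoring activations and the KV cache):
# Approximate weight memory for Llama 3.1 405B in different floating-point formats.
params = 405e9
for fmt, bytes_per_param in [("FP32", 4), ("FP16", 2), ("FP8", 1)]:
    print(f"{fmt}: {params * bytes_per_param / 1e9:.0f} GB")
# FP32: 1620 GB, FP16: 810 GB, FP8: 405 GB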
Native support means: the GPU hardware has built-in instructions designed to do FP8 math directly (like H100 does for FP8).
On A100, there is no native FP8 compute:
- The A100's primary high-speed compute formats are FP16/BF16/FP32.
- The A100 doesn’t have special hardware instructions optimized specifically for FP8. This implies that running FP8 requires a software/kernel workaround on A100.
- Quantization = storing and/or computing with fewer bits.
- The Steps: Model weights (originally FP16/FP32) are approximated using the smaller 8-bit FP8 format to take up less memory.
- The Trade-off: Saves huge amounts of memory, but may cause a slight, manageable reduction in accuracy or speed.
- Weight-Only Scheme: Only the weights (the stored knowledge) are compressed to FP8 for maximum memory savings.
- Higher Precision Activations: The activations (intermediate calculations) are kept in a higher format like FP16 or BF16.
- The Result: You get big memory savings from compressed weights while maintaining good numerical stability and accuracy from higher-precision activations.
The Marlin kernel is a highly optimized, low-level program designed to run on the GPU.
- On A100, it bridges the gap for non-native FP8:
- It quickly reads the compressed FP8 weights.
- It efficiently dequantizes (converts) them back to FP16.
- It performs the matrix math using the A100's native, fast FP16 instructions.
Summary: Marlin lets you store weights in tiny FP8 while still achieving fast computation on A100 hardware.
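The following toy sketch illustrates the store-in-FP8, compute-in-higher-precision idea (it is not the Marlin kernel itself). It assumes a recent PyTorch build that exposes the torch.float8_e4m3fn dtype:
# Weight-only FP8: store the weights in 1 byte each, dequantize before the matmul.
import torch

w = torch.randn(4096, 4096)          # original weights (FP32 here for simplicity)
w_fp8 = w.to(torch.float8_e4m3fn)    # compressed storage: 1 byte per weight
w_back = w_fp8.to(w.dtype)           # dequantized copy used for the actual compute

x = torch.randn(1, 4096)             # activations stay in higher precision
y = x @ w_back                       # the matmul itself uses the dequantized higher-precision copy

print("bytes per weight:", w_fp8.element_size(), "vs", w.element_size())
print("max quantization error:", (w - w_back).abs().max().item())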
2.1. Estimating required number of nodes
We assume the following:
- FP8 represents 1 byte of memory per parameter.
- Llama 3.1 has 405 billion parameters.
- A node on MeluXina has 4 × A100 40GB → 160GB of GPU memory.
- vLLM defines a gpu_memory_utilization parameter, which defaults to 0.9.
- With this parameter at 0.9, we need 378 GB / (160 GB × 0.9) ≈ 2.6 nodes, i.e. 3 GPU nodes in total (see the sketch below).
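Here is the same estimate as a small Python sketch (numbers taken from the assumptions above):
# Naive node-count estimate for the FP8 weights on MeluXina GPU nodes.
import math

weight_bytes = 405e9                  # 405B parameters x 1 byte (FP8)
weights_gib = weight_bytes / 2**30    # ≈ 377 GiB, i.e. the ~378 GB figure used above
usable_per_node = 4 * 40 * 0.9        # 4 x A100 40GB per node, 90% usable by vLLM

print(math.ceil(weights_gib / usable_per_node))   # -> 3 nodes on paper; see the caveat below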
However:
- If you actually run with 3 GPU nodes, you will hit a CUDA out-of-memory (OOM) error.
- Memory utilization is not exactly balanced between all GPUs because we mix tensor parallelism and pipeline parallelism.
Empirical tests on MeluXina have shown that 5 nodes are enough to run this model safely. In this tutorial, we will therefore request 5 GPU nodes.
3. Server-side Setup on MeluXina
3.1. Create a working directory
Connect to MeluXina.
ssh YOURUSER@login.lxp.lu -p 8822
Once connected to MeluXina, start from an empty directory:
mkdir tutorial-vLLM && cd tutorial-vLLM
3.2. Getting an interactive job and pulling the container
We request an interactive job on the gpu partition. Replace YOURPROJECT with your project ID (e.g., p200xxx).
salloc -A YOURPROJECT -t 01:00:00 -q dev -p gpu --reservation=gpudev -N1
To avoid installing all dependencies in a Python virtual environment for the vLLM inference server, we pull a prebuilt container image using the Apptainer tool.
module load Apptainer
apptainer pull docker://vllm/vllm-openai:v0.6.3
Notes:
⚠️ Pulling the container takes some time.
Once apptainer has finished, you should see the file vllm-openai_v0.6.3.sif in your current directory.
4. Preparing the SLURM Launcher Script
We now prepare a multi-node launcher that:
- Spawns a Ray head node and Ray workers using Apptainer.
- Starts a vLLM server on the head node.
- Uses environment variables for HF token, cache, and model name.
Create the file named launcher_vllm_multinode.sh and paste the following script.
In the script, update <Choose a path> with your own path and <your Hugging Face token> with the access token generated earlier. Also adjust the #SBATCH -A account to your own project if needed.
#!/bin/bash -l
#SBATCH -A lxp
#SBATCH -q default
#SBATCH -p gpu
#SBATCH -t 2:0:0
#SBATCH -N 5
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=128
#SBATCH --gpus-per-task=4
#SBATCH --error="vllm-%j.err"
#SBATCH --output="vllm-%j.out"
module --force purge
module load env/release/2023.1
module load Apptainer/1.3.1-GCCcore-12.3.0
# Fix pmix error (munge)
export PMIX_MCA_psec=native
# Choose a directory for the cache
export LOCAL_HF_CACHE="<Choose a path>/HF_cache"
mkdir -p ${LOCAL_HF_CACHE}
export HF_TOKEN="<your Hugging Face token>"
# Make sure the path to the SIF image is correct
# Here, the SIF image is in the same directory as this script
export SIF_IMAGE="vllm-openai_v0.6.3.sif"
export APPTAINER_ARGS=" --nvccli -B ${LOCAL_HF_CACHE}:/root/.cache/huggingface --env HF_HOME=/root/.cache/huggingface --env HUGGING_FACE_HUB_TOKEN=${HF_TOKEN}"
# Make sure you have been granted access to the model
export HF_MODEL="meta-llama/Llama-3.1-405B-FP8"
export HEAD_HOSTNAME="$(hostname)"
export HEAD_IPADDRESS="$(hostname --ip-address)"
echo "HEAD NODE: ${HEAD_HOSTNAME}"
echo "IP ADDRESS: ${HEAD_IPADDRESS}"
echo "SSH TUNNEL (Execute on your local machine): ssh -p 8822 ${USER}@login.lxp.lu -NL 8000:${HEAD_IPADDRESS}:8000"
# We need to get an available random port
export RANDOM_PORT=$(python3 -c 'import socket; s = socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()')
# Command to start the head node
export RAY_CMD_HEAD="ray start --block --head --port=${RANDOM_PORT}"
# Command to start workers
export RAY_CMD_WORKER="ray start --block --address=${HEAD_IPADDRESS}:${RANDOM_PORT}"
export TENSOR_PARALLEL_SIZE=4 # Set it to the number of GPUs per node
export PIPELINE_PARALLEL_SIZE=${SLURM_NNODES} # Set it to the number of allocated GPU nodes
# Start head node
echo "Starting head node"
srun -J "head ray node-step-%J" -N 1 --ntasks-per-node=1 -c $(( SLURM_CPUS_PER_TASK/2 )) -w ${HEAD_HOSTNAME} apptainer exec ${APPTAINER_ARGS} ${SIF_IMAGE} ${RAY_CMD_HEAD} &
sleep 10
echo "Starting worker node"
srun -J "worker ray node-step-%J" -N $(( SLURM_NNODES-1 )) --ntasks-per-node=1 -c ${SLURM_CPUS_PER_TASK} -x ${HEAD_HOSTNAME} apptainer exec ${APPTAINER_ARGS} ${SIF_IMAGE} ${RAY_CMD_WORKER} &
sleep 30
# Start server on head to serve the model
echo "Starting server"
apptainer exec ${APPTAINER_ARGS} ${SIF_IMAGE} vllm serve ${HF_MODEL} --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} --pipeline-parallel-size ${PIPELINE_PARALLEL_SIZE}
Then, submit the job:
sbatch launcher_vllm_multinode.sh
⏳ The setup of all workers and the inference server can take some time (Ray cluster startup + model loading).
5. Checking That the Server Is Running
Once everything is running, open the generated output file: vllm-<JOB ID>.out. Replace <JOB ID> with the actual job ID.
tail -f vllm-<JOB ID>.out # Check last lines of the file
You should see repeated monitoring lines like:
INFO 09-26 15:30:02 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-26 15:30:12 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-26 15:30:22 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-26 15:30:32 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-26 15:30:42 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-26 15:30:52 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 09-26 15:31:02 metrics.py:351] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
This indicates the server is idle and ready for requests.
6. Setting Up SSH Port Forwarding
To send requests from your local machine, you need an SSH tunnel from your laptop to the head node.
6.1. Retrieve the SSH command from the output
On MeluXina, you can retrieve the SSH command from the output.
grep -oE 'ssh -p 8822 .*:8000' vllm-<JOB ID>.out
<JOB ID> with the JOB ID of the vLLM inference server job.
Then, it will return the SSH command as follows.
ssh -p 8822 YOURUSER@login.lxp.lu 8000:<ip address>:8000
6.2. What SSH port forwarding does
SSH port forwarding (or SSH tunneling) uses SSH to create a secure tunnel between your local machine and the remote server.
- Local port (on your machine): 8000
- Remote endpoint (on MeluXina): <ip address>:8000
6.3. Running the SSH tunnel
On your local machine:
- Open a new terminal.
- Paste and run the SSH command printed by the grep command above.
It is normal that the command doesn’t produce any output: it just establishes the tunnel and then stays open.
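To confirm that the tunnel and the server are both working, you can query the OpenAI-compatible model list endpoint from your laptop. A minimal sketch, assuming Python and the requests package are installed on your local machine:
# Readiness check through the SSH tunnel: list the models served by vLLM.
import requests

resp = requests.get("http://localhost:8000/v1/models", timeout=10)
resp.raise_for_status()
for model in resp.json()["data"]:
    print(model["id"])   # should print meta-llama/Llama-3.1-405B-FP8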
7. Querying the Inference Server with curl
With the SSH tunnel active, you can now send requests to http://localhost:8000 on your local machine.
- Use the SSH forwarding command from above.
- Then, in another local terminal, run the following command as an example:
curl -X POST -H "Content-Type: application/json" http://localhost:8000/v1/completions -d '{
"model": "meta-llama/Llama-3.1-405B-FP8",
"prompt": "San Francisco is a"
}'
You should see output similar to the following (the generated text will differ, but the structure will be the same):
{
"id":"cmpl-38c658fe804541eab7907a40234a61ae",
"object":"text_completion","created":1727358365,
"model":"meta-llama/Llama-3.1-405B-FP8",
"choices":[{"index":0,
"text":" top holiday destination featuring scenic beauty and great ethnic and cultural diversity. San Francisco is",
"logprobs":null,
"finish_reason":"length","stop_reason":null,
"prompt_logprobs":null}],
"usage":{"prompt_tokens":5,"total_tokens":21,"completion_tokens":16}
}
You can try other prompts in the same way, for example:
curl -X POST -H "Content-Type: application/json" http://localhost:8000/v1/completions -d '{
"model": "meta-llama/Llama-3.1-405B-FP8",
"prompt": "Luxembourg is a"
}'
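If you prefer Python over curl, the same completion request can be sent with the requests package (a sketch assuming the tunnel is active and requests is installed locally; here max_tokens is set explicitly to get a longer answer than in the examples above):
# Send a completion request to the vLLM server through the SSH tunnel.
import requests

payload = {
    "model": "meta-llama/Llama-3.1-405B-FP8",
    "prompt": "Luxembourg is a",
    "max_tokens": 64,
    "temperature": 0.7,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])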
8. Making a mock chatbot (Advanced)
vLLM has much more to offer; do not hesitate to check its rich documentation. Here, we just want to highlight that other pre-trained models can easily be tested with the provided script. If you wish to use another large model, just replace the environment variable HF_MODEL in the script launcher_vllm_multinode.sh.
8.1 Example: Mock chatbot for a Mixtral model
In this example, we run the server again, but this time with mistralai/Mixtral-8x7B-Instruct-v0.1.
The list of models you can run with vLLM can be found in the vLLM documentation.
To this end, just change the HF_MODEL variable in the bash script above:
export HF_MODEL="mistralai/Mixtral-8x22B-v0.1"
Once this is set up and the SSH port forwarding is running, you can easily build a small chatbot by running the following Python code in another local terminal (it requires the gradio and requests packages):
import gradio as gr
import requests

def chat_with_model(user_input, chat_history):
    headers = {"Content-Type": "application/json"}
    data = {
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_input}
        ]
    }
    # Send the chat request through the SSH tunnel to the vLLM server
    response = requests.post("http://localhost:8000/v1/chat/completions", headers=headers, json=data)
    response_json = response.json()
    assistant_message = response_json['choices'][0]['message']['content']
    # Append the new (user, assistant) exchange to the conversation history
    chat_history.append((user_input, assistant_message))
    return chat_history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    with gr.Row():
        txt = gr.Textbox(show_label=False, container=False, placeholder="Type your prompt here...")
    txt.submit(chat_with_model, [txt, chatbot], chatbot)

demo.launch()
Here we called this script chatbot.py.
$ python3 chatbot.py
...
* Running on local URL: http://127.0.0.1:7860
When you open the provided URL in a web browser, you will see an interface to converse with the model, as shown below.

Note that the tasks supported vary depending on the model in use. For detailed information, please visit http://127.0.0.1:8000/docs to review the expected syntax for interacting with the server API.