PyTorch Profiling with Nvidia Nsight
Overview
This tutorial demonstrates how to profile and optimize PyTorch code with Nvidia Nsight.
Learning Objectives
You will learn how to:
- Generate a .nsys-rep file with Nsight Systems
- Analyze .nsys-rep files using the UI and the CLI
- Apply quick wins for profiling (e.g. NVTX markers)
- Interpret a .nsys-rep file
- Apply common optimizations to PyTorch code
1. Nvidia Profiling Background
1.1. Why Profiling?
Profiling is the process of measuring where your code spends its time.
You can use these measurements to optimize your code. Each fix usually moves the bottleneck somewhere else, so profiling is an iterative process.
If you follow best practices, your code should be reasonably efficient from the start.
However, finding your program's best configuration requires experimentation, because every program has its own tuning needs.
Profiling guides each experiment and gives you detailed, highly customizable feedback, so you can tune your code beyond general guidelines.
1.2. Profiling GPU Code

GPU profiling differs from CPU profiling because of asynchronous execution. When you call a CUDA operation in PyTorch, the CPU queues the work and immediately continues. This means CPU timers can be misleading.
Without GPU-aware profiling tools, your training loop might look fast on the CPU side while the GPU sits idle half the time.
The cause may be data loading that blocks, or inefficient memory allocation.
For example, sending larger chunks of data to the GPU at once can reduce the time the GPU waits for input. If the chunks are too large, however, the GPU runs out of memory and your program stops.
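The following minimal sketch shows why CPU timers can mislead. It assumes PyTorch is installed; CUDA is optional. The function name, matrix size and repeat count are illustrative only. The naive timer only measures how long the CPU needs to queue the matrix multiplications. CUDA events are recorded on the GPU stream, so they measure when the GPU actually ran the kernels. Without a GPU, the sketch falls back to plain CPU timing, where both numbers agree.

```python
import time

import torch


def compare_timers(n: int = 2048, repeats: int = 10) -> dict:
    """Time the same matrix multiplications with a naive CPU timer and with CUDA events."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(n, n, device=device)
    b = torch.randn(n, n, device=device)
    out = a @ b  # warm-up so library initialization is not measured

    if device == "cpu":
        # CPU ops run synchronously, so a wall-clock timer is accurate here.
        t0 = time.perf_counter()
        for _ in range(repeats):
            out = a @ b
        elapsed = time.perf_counter() - t0
        return {"naive_s": elapsed, "gpu_s": elapsed}

    torch.cuda.synchronize()  # start from an idle GPU
    # Naive: only measures how long the CPU needs to *queue* the kernels.
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = a @ b
    naive_s = time.perf_counter() - t0

    # CUDA events are recorded on the stream, so they bracket the actual GPU execution.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeats):
        out = a @ b
    end.record()
    end.synchronize()  # block until the GPU has reached the `end` event
    return {"naive_s": naive_s, "gpu_s": start.elapsed_time(end) / 1000.0}  # ms -> s
```

On a GPU, naive_s can be far smaller than gpu_s, because the loop returns as soon as the kernels are queued.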
Here are some useful metrics to observe while profiling GPU code:
- Kernel execution time: How long each CUDA kernel actually runs on the GPU
- Memory transfers: CPU ↔ GPU data movement (often the hidden bottleneck)
- GPU utilization: Are you keeping the GPU busy, or is it waiting for data?
- Concurrency: Can multiple operations overlap, or are they serialized?
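Here is a minimal sketch that makes the GPU utilization metric concrete. It computes the fraction of a time window during which at least one kernel was running. The input is a list of kernel (start, end) timestamps in ns, for example kernel times exported from a report. The interval format and function names are assumptions for illustration. Overlapping kernels from concurrent streams are merged first, so that overlap is not counted twice. This is a coarse busy-time measure, not SM occupancy.

```python
from typing import List, Tuple

Interval = Tuple[int, int]  # (start_ns, end_ns) of one kernel on one GPU


def merge_intervals(kernels: List[Interval]) -> List[Interval]:
    """Merge overlapping or touching kernel intervals (e.g. from concurrent streams)."""
    merged: List[Interval] = []
    for start, end in sorted(kernels):
        if merged and start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged


def gpu_utilization(kernels: List[Interval], window: Interval) -> float:
    """Fraction of `window` during which at least one kernel was running."""
    w_start, w_end = window
    busy = 0
    for start, end in merge_intervals(kernels):
        # Clip each busy interval to the window before adding it up.
        lo, hi = max(start, w_start), min(end, w_end)
        if hi > lo:
            busy += hi - lo
    return busy / (w_end - w_start)
```

For example, kernels at (0, 30), (20, 50) and (70, 80) in a window of (0, 100) merge into 60 ns of busy time, a utilization of 0.6.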
1.3 Nvidia Nsight Systems & Nvidia Nsight Compute
Nvidia provides two profiling tools:
Nsight Systems is your starting point. It captures a timeline of everything: CPU activity, GPU kernels, memory transfers, and framework operations (e.g. PyTorch, CUDA).
Nsight Compute is the deep-dive tool for individual GPU kernels. Once Nsight Systems tells you which kernel is slow, Nsight Compute tells you why: memory bandwidth, occupancy, warp stalls, cache hit rates. It's slow and detailed. Use it for targeted optimization of specific hotspots.
Rule of thumb: Start with Nsight Systems to find the slow regions, then focus on specific kernels with Nsight Compute only if needed.
2. How to use Nvidia Nsight Systems
2.1 How to generate an .nsys-rep file (Nsight Systems report file)
When you profile your Python code, nsys profile first writes an intermediate .qdstrm file, which is the raw data stream. It then converts this file into the final .nsys-rep report, which is the file you work with.
You can analyze the report on the command line with nsys analyze, which runs Nsight Systems' built-in expert-system rules. You can also use nsys stats, which is an easy way to generate a series of summary or trace reports.
nsys analyze reports problems such as GPU gaps (with precise timestamps and durations), inefficient memory usage, and whatever is blocking the host. For each problem it also suggests a fix.
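Both commands can also be scripted. The sketch below assumes nsys is on your PATH. The helper names are hypothetical. The report and rule names (nvtx_sum, gpu_gaps, cuda_memcpy_async) are the ones that appear in the example outputs in the next sections. The flag spellings (--report, --format, --rule) follow recent nsys releases, so check nsys stats --help and nsys analyze --help for your installed version.

```python
import shutil
import subprocess
from typing import List


def nsys_stats_cmd(report_file: str, report: str = "nvtx_sum") -> List[str]:
    """Command that prints one nsys stats summary report in CSV format."""
    return ["nsys", "stats", "--report", report, "--format", "csv", report_file]


def nsys_analyze_cmd(report_file: str, rule: str = "gpu_gaps") -> List[str]:
    """Command that runs a single nsys analyze rule on a report."""
    return ["nsys", "analyze", "--rule", rule, report_file]


def run(cmd: List[str]) -> str:
    """Run an nsys command and return what it printed to stdout."""
    if shutil.which(cmd[0]) is None:
        raise FileNotFoundError(f"{cmd[0]} is not on PATH")
    return subprocess.run(cmd, check=True, capture_output=True, text=True).stdout
```

For example, run(nsys_analyze_cmd("profile.nsys-rep", "cuda_memcpy_async")) would check only the pageable-memcpy rule. The file name here is hypothetical.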
2.1.1 Example: generating an .nsys-rep file from distributed ResNet-50 PyTorch code
The following concrete example shows how to get an .nsys-rep file for a distributed PyTorch application. It trains ResNet-50 by launching script.py with torchrun and has three parts: a Slurm batch script, the training script script.py, and a small helper, getPort.py. Note that the entire torchrun command is passed to nsys profile. You can also wrap nsys profile in an interactive session (with srun ... nsys profile ...) to customize the resources allocated to that specific command.
This Slurm batch script produces .nsys-rep files, which you can then analyze in the Nsight Systems UI or on the command line.
```bash
#!/bin/bash -l
#SBATCH --job-name="resnetNsysProfileOptimizations"
# One task per node: each task runs one nsys + torchrun launcher, and torchrun
# starts one training process per GPU. --nodes must match torchrun --nnodes.
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --output=%x%j.out
#SBATCH --error=%x%j.err
#SBATCH -p gpuv
#SBATCH --time=02:00:00
#SBATCH --account=u1xxxx
#SBATCH --disable-perfparanoid
#SBATCH --hint=nomultithread

export opt_name=test
export PROFDIR=$PWD
export TOTAL_EPOCHS=4
export BATCH_SIZE_PER_GPU=64
export NGPUS_PER_NODE=4
export NCCL_SOCKET_IFNAME=ib0
export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=8

# The first node of the allocation hosts the torchrun rendezvous.
nodes=($(scontrol show hostnames "$SLURM_JOB_NODELIST"))
head_node=${nodes[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
export random_port=$(python getPort.py)
endpoint="${head_node_ip}:${random_port}"

CUDA_VISIBLE_DEVICES="0,1,2,3" srun --gpus-per-node=${NGPUS_PER_NODE} --kill-on-bad-exit=1 \
    nsys profile \
    --output="${PROFDIR}/profile_${opt_name}_multi_nodes_%h.%p" \
    -t cuda,nvtx,osrt,cublas,cusparse \
    --stats=true \
    --capture-range=cudaProfilerApi \
    --capture-range-end=stop \
    --gpu-metrics-device=all \
    --gpu-metrics-frequency=100 \
    --gpu-metrics-set=ga100 \
    --gpuctxsw=true \
    torchrun \
    --nnodes 2 \
    --nproc_per_node ${NGPUS_PER_NODE} \
    --rdzv_id 10000 \
    --rdzv_backend c10d \
    --rdzv_endpoint $endpoint \
    --log_dir log_torch_${opt_name} \
    script.py ${TOTAL_EPOCHS} ${BATCH_SIZE_PER_GPU}
```
```python
# script.py
import argparse
import os
import random

import torch
import torch.cuda.nvtx as nvtx
import torch.cuda.profiler as profiler
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import models


def ddp_setup():
    # torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for every process it starts
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    return world_size, rank


class Trainer:
    def __init__(self, model: torch.nn.Module, train_data: DataLoader, optimizer: torch.optim.Optimizer) -> None:
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.model = model.to(self.local_rank)
        self.train_data = train_data
        self.optimizer = optimizer
        self.model = DDP(self.model, device_ids=[self.local_rank])

    def _run_batch(self, source, targets):
        # Each NVTX range shows up as a labeled bar in the Nsight Systems timeline
        nvtx.range_push("Optimizer Zero Grad")
        self.optimizer.zero_grad()
        nvtx.range_pop()
        nvtx.range_push("Forward Pass")
        output = self.model(source)
        nvtx.range_pop()
        nvtx.range_push("Loss Computation")
        loss = F.cross_entropy(output, targets)
        nvtx.range_pop()
        nvtx.range_push("Backward Pass")
        loss.backward()
        nvtx.range_pop()
        nvtx.range_push("Optimizer Step")
        self.optimizer.step()
        nvtx.range_pop()

    def _run_epoch(self, epoch):
        print(f"[GPU{self.global_rank}] Epoch {epoch} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)  # reshuffle the shards every epoch
        nvtx.range_push(f"Epoch {epoch} Start")
        for step, (source, targets) in enumerate(self.train_data):
            nvtx.range_push(f"Step {step} Start")
            source = source.to(self.local_rank)
            targets = targets.to(self.local_rank)
            self._run_batch(source, targets)
            nvtx.range_pop()  # Step end
        nvtx.range_pop()  # Epoch end

    def train(self, max_epochs: int):
        nvtx.range_push("Entire training")
        for epoch in range(max_epochs):
            nvtx.range_push(f"Training Epoch {epoch}")
            self._run_epoch(epoch)
            nvtx.range_pop()
        nvtx.range_pop()


class SyntheticDataset(Dataset):
    """Random ImageNet-sized samples, so the example needs no dataset on disk."""

    def __init__(self, num_epochs, batch_size_per_gpu, world_size):
        self.num_epochs = num_epochs
        self.batch_size_per_gpu = batch_size_per_gpu
        self.num_iters = world_size
        self.len = 10000

    def __getitem__(self, idx):
        data = torch.randn(3, 224, 224)
        target = random.randint(0, 999)
        return (data, target)

    def __len__(self):
        return self.len


def load_train_objs(num_epochs, batch_size_per_gpu, world_size):
    train_set = SyntheticDataset(num_epochs, batch_size_per_gpu, world_size)
    model = models.resnet50(weights=None)  # random init; `pretrained=False` is deprecated
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,  # shuffling is done by the DistributedSampler
        sampler=DistributedSampler(dataset),  # each rank gets its own shard
        num_workers=2,
    )


def main(total_epochs: int, batch_size_per_gpu: int):
    world_size, rank = ddp_setup()
    # The batch script passes BATCH_SIZE_PER_GPU, so it is used as-is per rank
    dataset, model, optimizer = load_train_objs(total_epochs, batch_size_per_gpu, world_size)
    train_data = prepare_dataloader(dataset, batch_size_per_gpu)
    trainer = Trainer(model, train_data, optimizer)
    profiler.start()  # nsys starts capturing here (--capture-range=cudaProfilerApi)
    trainer.train(total_epochs)
    dist.barrier()
    profiler.stop()  # ...and stops here (--capture-range-end=stop)
    print("Will now destroy the process group")
    destroy_process_group()
    print("Process group destroyed")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="simple distributed training job")
    parser.add_argument("total_epochs", type=int, help="Total epochs to train the model")
    parser.add_argument("batch_size", nargs="?", default=32, type=int,
                        help="Input batch size on each device (default: 32)")
    args = parser.parse_args()
    main(args.total_epochs, args.batch_size)
```

```python
# getPort.py: prints a free TCP port, used by the batch script for the rendezvous endpoint
import socket


def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # port 0 lets the OS pick any free port
        return s.getsockname()[1]  # return the port number that was assigned


if __name__ == "__main__":
    print(find_free_port())
```
2.1.2 What Nsys Analyze Output looks like
Once you have an .nsys-rep file from profiling your PyTorch code, run nsys analyze yourfile.nsys-rep. The output looks similar to the following:
CUDA Async Memcpy with Pageable Memory (cuda_memcpy_async):
The following APIs use PAGEABLE memory which causes asynchronous CUDA memcpy
operations to block and be executed synchronously. This leads to low GPU
utilization.
Suggestion: If applicable, use PINNED memory instead.
Duration(ns) Start(ns) Src Kind Dst Kind Bytes (MB) PID Device ID Context ID Green Context ID Stream ID API Name
4,415 30,271,927,890 Device Pageable 0.001 1,358 0 1 7 cudaMemcpyAsync_v3020
2,944 30,272,062,545 Device Pageable 0.000 1,358 0 1 7 cudaMemcpyAsync_v3020
1,951 30,270,512,088 Pageable Device 0.001 1,358 0 1 7 cudaMemcpyAsync_v3020
1,344 30,271,992,465 Pageable Device 0.000 1,358 0 1 7 cudaMemcpyAsync_v3020
CUDA Synchronous Memcpy (cuda_memcpy_sync):
There were no problems detected related to synchronous memcpy operations.
Note
As you can see, it also reports areas with no detected problems, which saves you time and effort.
GPU Gaps (gpu_gaps):
The following are ranges where a GPU is idle for more than 500ms. Addressing
these gaps might improve application performance.
Suggestions:
1. Use CPU sampling data, OS Runtime blocked state backtraces, and/or OS
Runtime APIs related to thread synchronization to understand if a sluggish or
blocked CPU is causing the gaps.
2. Add NVTX annotations to CPU code to understand the reason behind the gaps.
Row# Duration (ns) Start (ns) PID Device ID Context ID
1 7,166,688,453 294,751,474,402 1,358 0 1
2 6,850,921,134 428,359,969,308 1,358 0 1
3 6,613,978,952 13,479,072,942 1,358 0 1
Important
Here nsys reports significant GPU gaps and suggests how to approach reducing them.
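The idea behind the gpu_gaps rule fits in a few lines. Sort the kernel intervals, track when the GPU last became idle, and report idle stretches longer than a threshold. This is an illustrative re-implementation, not the code nsys actually uses. The 500 ms default mirrors the threshold in the output above, and the results are sorted longest first, as in that table.

```python
from typing import List, Optional, Tuple

Interval = Tuple[int, int]  # (start_ns, end_ns)


def find_gpu_gaps(kernels: List[Interval], min_gap_ns: int = 500_000_000) -> List[Interval]:
    """Return idle ranges between kernels longer than min_gap_ns, longest first."""
    gaps: List[Interval] = []
    busy_until: Optional[int] = None  # end of the latest kernel seen so far
    for start, end in sorted(kernels):
        if busy_until is not None and start - busy_until > min_gap_ns:
            gaps.append((busy_until, start))
        busy_until = end if busy_until is None else max(busy_until, end)
    return sorted(gaps, key=lambda g: g[1] - g[0], reverse=True)
```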
2.1.3 What Nsys Stats Output looks like
You get this output the same way as for analyze, but with the stats command: nsys stats yourfile.nsys-rep.
The output lists function calls, including the NVTX ranges you added around profiled instructions, OS runtime calls, CUDA API calls, and other traced activity. It is verbose and predictable. It does not make suggestions, it only reports call counts and durations.
Here is the NVTX section of the output from the ResNet-50 example.
NVTX Range Summary (nvtx_sum):
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range
18.5 577,432,614,920 1 577,432,614,920.0 577,432,614,920.0 577,432,614,920 577,432,614,920 0.0 PushPop :Entire training
9.3 289,430,857,837 2 144,715,428,918.5 144,715,428,918.5 140,980,403,844 148,450,453,993 5,282,123,116.2 PushPop :Training Epoch 3
9.2 286,144,540,046 2 143,072,270,023.0 143,072,270,023.0 139,452,922,756 146,691,617,290 5,118,529,991.9 PushPop :Epoch 3 Start
6.2 191,648,059,592 7,815 24,523,104.2 20,739,622.0 19,093,890 3,810,311,154 48,953,140.3 PushPop :Backward Pass
5.5 170,866,748,515 7,815 21,863,947.3 15,310,228.0 14,208,863 9,985,273,976 119,814,808.5 PushPop :Forward Pass
5.1 157,367,001,311 1 157,367,001,311.0 157,367,001,311.0 157,367,001,311 157,367,001,311 0.0 PushPop :Training Epoch 0
5.0 155,742,694,991 1 155,742,694,991.0 155,742,694,991.0 155,742,694,991 155,742,694,991 0.0 PushPop :Epoch 0 Start
4.5 139,941,070,858 1 139,941,070,858.0 139,941,070,858.0 139,941,070,858 139,941,070,858 0.0 PushPop :Training Epoch 2
4.5 139,144,084,287 1 139,144,084,287.0 139,144,084,287.0 139,144,084,287 139,144,084,287 0.0 PushPop :Training Epoch 1
4.4 138,457,001,555 1 138,457,001,555.0 138,457,001,555.0 138,457,001,555 138,457,001,555 0.0 PushPop :Epoch 2 Start
4.4 137,696,322,933 1 137,696,322,933.0 137,696,322,933.0 137,696,322,933 137,696,322,933 0.0 PushPop :Epoch 1 Start
0.5 15,255,037,717 5 3,051,007,543.4 427,632,138.0 89,534,776 14,078,480,421 6,167,859,568.7 PushPop :Step 0 Start
0.3 10,332,244,970 7,815 1,322,104.3 971,966.0 649,517 236,022,474 4,812,928.7 PushPop :Optimizer Step
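To compare runs in a script rather than by eye, you can ask nsys stats for CSV output (--format csv) and parse it. The minimal sketch below assumes that the CSV text contains only the table and uses the same column names as the table above ("Instances", "Avg (ns)", "Range"). Verify these names against your nsys version. The function name is hypothetical.

```python
import csv
import io
from typing import Dict


def parse_nvtx_sum(csv_text: str) -> Dict[str, Dict[str, float]]:
    """Map each NVTX range name to its instance count and average duration in ms."""
    result: Dict[str, Dict[str, float]] = {}
    for row in csv.DictReader(io.StringIO(csv_text)):
        name = row["Range"].lstrip(":")  # PushPop ranges are printed as ":<name>" above
        result[name] = {
            # Strip thousands separators in case the numbers are formatted like the table
            "instances": int(row["Instances"].replace(",", "")),
            "avg_ms": float(row["Avg (ns)"].replace(",", "")) / 1e6,
        }
    return result
```

For example, comparing avg_ms for "Forward Pass" across two reports shows whether a change made the forward pass faster.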
The following is a cutout of the CUDA GPU Memory Operations Summary. It shows the share of time and the amount of data for each kind of memory operation: CPU (host) to GPU (device), GPU to GPU, GPU to CPU, and memset. Although these statistics are easy to read, inferring actionable code optimizations from them takes considerable expertise. What you can clearly see is that Host-to-Device copies take the majority (66.4 percent) of the time spent in GPU memory operations. This makes moving input data from the CPU to the GPU the first place to look. Sorted by size, CUDA memset operations touch the most data, even though they account for only 6.6 percent of the memory-operation time. A memset fills device memory with a constant value (for example, when a tensor is zero-initialized) and is cheap per byte. It is unrelated to pinned host memory, which is allocated on the CPU side.
CUDA GPU MemOps Summary (by Time) (cuda_gpu_mem_time_sum):
Time (%) Total Time (ns) Count Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Operation
66.4 15,066,079,832 16,019 940,513.1 1,440.0 1,055 2,411,873 963,114.6 [CUDA memcpy Host-to-Device]
27.0 6,133,527,640 2,501,122 2,452.3 1,984.0 1,695 15,872 1,646.4 [CUDA memcpy Device-to-Device]
6.6 1,491,493,303 382,980 3,894.4 2,176.0 992 47,328 4,346.3 [CUDA memset]
0.0 13,343 4 3,335.8 3,184.0 2,880 4,095 534.0 [CUDA memcpy Device-to-Host]
CUDA GPU MemOps Summary (by Size) (cuda_gpu_mem_size_sum):
Total (MB) Count Avg (MB) Med (MB) Min (MB) Max (MB) StdDev (MB) Operation
1,444,665.052 382,980 3.772 0.001 0.000 67.109 7.713 [CUDA memset]
800,679.106 2,501,122 0.320 0.001 0.000 9.437 1.271 [CUDA memcpy Device-to-Device]
301,162.455 16,019 18.800 0.001 0.000 38.535 19.254 [CUDA memcpy Host-to-Device]
0.006 4 0.001 0.000 0.000 0.005 0.002 [CUDA memcpy Device-to-Host]
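You can check the percentages and sizes in these tables with a few lines of arithmetic. This sketch assumes the per-GPU batch of 64 images from the batch script (BATCH_SIZE_PER_GPU=64), int64 labels, and that nsys reports MB as 10^6 bytes. Under these assumptions, one image batch is 38,535,168 bytes (38.535 MB), which is consistent with the largest Host-to-Device copy. The label batch rounds to 0.001 MB, which is consistent with the tiny median copy size.

```python
# Total time (ns) per operation, copied from the MemOps "by Time" table above.
time_ns = {
    "Host-to-Device": 15_066_079_832,
    "Device-to-Device": 6_133_527_640,
    "memset": 1_491_493_303,
    "Device-to-Host": 13_343,
}
total_ns = sum(time_ns.values())
share = {op: 100 * t / total_ns for op, t in time_ns.items()}

# Assumption: one batch is 64 float32 images of shape 3x224x224 (4 bytes per value)
batch_bytes = 64 * 3 * 224 * 224 * 4
# Assumption: the 64 labels are int64 class indices (8 bytes each)
label_bytes = 64 * 8

print(f"H2D share of memory-operation time: {share['Host-to-Device']:.1f}%")
print(f"image batch: {batch_bytes / 1e6:.3f} MB")  # compare Max (MB) of H2D copies
print(f"label batch: {label_bytes / 1e6:.3f} MB")  # compare Med (MB) of H2D copies
```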
2.1.4 Tracking NVTX
The batch script tells the profiler to trace NVTX calls (nvtx in the -t list). NVTX is extremely useful for profiling. You can place custom markers around instructions in your PyTorch code, and they reappear in the Nsight Systems UI to show where those instructions spend time. Markers can be nested inside other markers, so you can see the time spent by instructions within other instructions (e.g. how long the optimizer step takes in each step of each epoch).
2.1.5 What else can be tracked?
Besides NVTX, the -t option in the batch script also traces:
- osrt (operating system runtime calls)
- cuda (CUDA API calls)
- cublas (cuBLAS, CUDA Basic Linear Algebra Subroutines calls)
- cusparse (cuSPARSE, CUDA sparse matrix computation calls)
Many more trace options exist, for example mpi for tracing MPI calls, which gives you detailed oversight of inter-process communication. Run nsys profile --help to see the list that your version supports.
2.1.6 Capture Range
You can also decide at which point in the Python script profiling starts and stops. For AI model scripts, the interesting part is usually the training loop. Narrowing the profiling scope keeps the .nsys-rep file smaller. This matters because Nsight Systems can crash or lag when too much data is collected. To narrow the scope, first import the CUDA profiler module:
```python
import torch.cuda.profiler as profiler
```
Then, in the main function, call start() and stop() around the training call:
```python
trainer = Trainer(model, train_data, optimizer)
profiler.start()             # nsys starts capturing here
trainer.train(total_epochs)
dist.barrier()               # wait until every rank has finished training
profiler.stop()              # nsys stops capturing here
```
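Calling profiler.start() right before training still captures warm-up effects such as memory allocator growth and DataLoader worker start-up. These can dominate a short run. The minimal sketch below captures only a window of steps. CaptureWindow is a hypothetical helper, not a PyTorch or nsys API. By default it calls torch.cuda.profiler.start and stop (the same functions used above), and it still relies on --capture-range=cudaProfilerApi. The start and stop functions can be swapped out, so the logic also runs without a GPU.

```python
from typing import Callable, Optional


class CaptureWindow:
    """Start the nsys capture before step `first` and stop it after step `last`."""

    def __init__(self, first: int, last: int,
                 start: Optional[Callable[[], None]] = None,
                 stop: Optional[Callable[[], None]] = None) -> None:
        if start is None or stop is None:
            import torch.cuda.profiler as profiler  # imported only when needed
            start = start or profiler.start
            stop = stop or profiler.stop
        self.first, self.last = first, last
        self._start, self._stop = start, stop
        self.active = False

    def before_step(self, step: int) -> None:
        if step == self.first and not self.active:
            self._start()  # corresponds to cudaProfilerStart
            self.active = True

    def after_step(self, step: int) -> None:
        if step == self.last and self.active:
            self._stop()  # corresponds to cudaProfilerStop
            self.active = False
```

In Trainer._run_epoch, you would call window.before_step(global_step) before each batch and window.after_step(global_step) after it, using the same step numbers on every rank.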
nsys honors these start and stop signals because the batch script sets --capture-range=cudaProfilerApi (together with --capture-range-end=stop).
2.1.7 GPU Metrics
To collect detailed GPU metrics for your exact architecture, set --gpu-metrics-device to all and choose the metric set for the target GPU with --gpu-metrics-set (ga100 in this example). Then pick a sampling frequency that suits your needs. Starting with 100-200 Hz (--gpu-metrics-frequency=100) keeps the report file small. If you need more detail, raise the frequency.
2.2 How to set and view NVTX markers
The following code shows how to set NVTX markers around instructions. Again, markers can be nested, so you can view events at several levels at once (e.g. the forward and backward passes of every iteration within an epoch, and the epochs within one training run).
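```python
import torch.cuda.nvtx as nvtx

def train(self, max_epochs: int):
    nvtx.range_push("Entire training")      # outer range around the whole run
    for epoch in range(max_epochs):
        nvtx.range_push(f"Epoch {epoch}")   # nested range, one per epoch
        self._run_epoch(epoch)
        nvtx.range_pop()                    # closes "Epoch {epoch}"
    nvtx.range_pop()                        # closes "Entire training"
```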
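range_push/range_pop pairs must stay balanced. If an exception is raised and caught between them, every later range ends up nested incorrectly. Recent PyTorch releases provide torch.cuda.nvtx.range as a context manager for this. The sketch below shows the same idea as a hypothetical helper, nvtx_range. Its push and pop functions can be swapped out, so it also runs without a GPU.

```python
from contextlib import contextmanager
from typing import Callable, Iterator, Optional


@contextmanager
def nvtx_range(name: str,
               push: Optional[Callable[[str], object]] = None,
               pop: Optional[Callable[[], object]] = None) -> Iterator[None]:
    """Open an NVTX range and always close it, even if the body raises."""
    if push is None or pop is None:
        import torch.cuda.nvtx as nvtx  # needs a CUDA build of PyTorch
        push = push or nvtx.range_push
        pop = pop or nvtx.range_pop
    push(name)
    try:
        yield
    finally:
        pop()  # runs on normal exit and on exceptions
```

Inside Trainer.train, with nvtx_range(f"Epoch {epoch}"): self._run_epoch(epoch) then replaces the manual push/pop pair.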
2.3 Visually interpreting the Nsys-Rep Example

This is a typical example of what a PyTorch nsys-report looks like in Nsight Systems when you just open it.
For each row you can view a type of activity.
2.3.1. CPU
The first row shows CPU activity, telling us that the CPU is very busy initially before the GPU starts.
2.3.2 GPU
Further below, you can see GPU metrics which give you detailed info such as SM Instructions, Warp Occupancy, mirroring the general GPU activity that can be viewed in a lower row titled CUDA HW. The gaps in CUDA HW indicate when the GPU is idle, something you want to minimize to train efficiently.
2.3.3 NVTX
How do you recognize when different PyTorch instructions are called? This is where NVTX markers become very useful. The NVTX row shows exactly which high-level PyTorch instructions run at each point in time. In this case, the entire training was profiled and nothing else. Note the large difference in duration between the first epoch, which includes overhead, and the later epochs. Nsight Systems shows exactly how much overhead there is and how it affects the durations of the first, middle, and last training epochs. Viewing the NVTX row for Epochs 1, 2, and 3 side by side with the CUDA HW row shows that the GPU gaps occur at the end of every epoch.
2.3.4 Diagnosing Bottlenecks
The GPU is most likely waiting for data from the CPU (host). In this case, focus on the DataLoader's efficiency to keep data flowing smoothly from CPU to GPU and shrink these gaps. Our recommended strategy is to try different DataLoader configurations (e.g. varying the number of workers, the batch size, and the type of host memory) and compare which one leads to the smallest GPU idle gaps. Because the gaps appear at epoch boundaries, persistent_workers=True is also worth testing. By default, the DataLoader shuts down its worker processes at the end of each epoch and starts new ones for the next.
2.3.5 Interacting with the Timeline
You can zoom into the report to view an area of interest more closely, which also reveals lower-level function calls that were previously hidden. To zoom, point the cursor at the region you want to focus on and use the mouse scroll wheel.
The next screenshot shows a zoomed area of interest: the GPU idle gap between two middle training epochs. You can select a time range by dragging horizontally. The selection appears as a green bar across every row, and its duration is shown in ms at the very top of the bar. This precise measurement lets you compare multiple runs effectively to determine which one produces the smallest idle gaps.

2.3.5.1 Events View
You can also view all recorded events from a timeline row as an interactive list. Right-click the row and choose Show in Events View, and Nsight Systems displays all of its events in the bottom panel. In the image below, this was done for the NVTX row.

2.3.6 Initial Overhead
The first epoch can take longer when you add optimizations such as pinned memory for efficient data loading, because of their setup cost. With only a few epochs, the first epoch's share of the total can look quite high.
Is it worth it? Yes. Real training runs have many epochs, and these optimizations scale well.
Once you increase the number of epochs to a realistic amount, this one-time overhead becomes insignificant compared to the whole training.
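The NVTX summary in section 2.1.3 lets you put a number on this. The sketch below assumes that epochs 1 and 2 represent steady-state training. Epoch 3 is left out because the table lists two instances for it. Under this assumption, the first epoch carries roughly 17.8 s of one-time overhead, and its share of the total shrinks as the number of epochs grows.

```python
# "Training Epoch N" durations in ns, copied from the NVTX summary.
epoch0_ns = 157_367_001_311
# Assumption: epochs 1 and 2 represent steady-state training
steady_ns = (139_144_084_287 + 139_941_070_858) / 2

overhead_s = (epoch0_ns - steady_ns) / 1e9  # one-time extra cost of epoch 0


def overhead_fraction(n_epochs: int) -> float:
    """Share of total training time spent on the one-time first-epoch overhead."""
    total_s = overhead_s + n_epochs * steady_ns / 1e9
    return overhead_s / total_s


print(f"first-epoch overhead: {overhead_s:.1f} s")
for n in (4, 100):
    print(f"{n} epochs: overhead is {100 * overhead_fraction(n):.2f}% of training")
```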
Each .nsys-rep file shows the activity of a single compute node, but you can open the .nsys-rep files from several nodes in a single multi-report view.
This combines the timelines into one view, so you can inspect all nodes at the same time.
To give each report a distinct filename, set the --output flag to the following value:
--output="${PROFDIR}/profile_${opt_name}_multi_nodes_%h.%p"
nsys replaces %h with the host name and %p with the process ID, so every node writes its own .nsys-rep file. If you prefer filenames based on the rank, nsys also expands environment variables written as %q{VAR}, e.g. %q{SLURM_PROCID}.
2.4 Common Ways to optimize PyTorch
Let's explore some of the best practices for optimizing PyTorch code.
2.4.1. DataLoader
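The DataLoader should be your starting point for optimization, since it is the most common bottleneck. With its default settings (num_workers=0, pin_memory=False), batches are loaded in the main training process and copied from pageable host memory. This is often too slow to keep the GPU(s) busy.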
2.4.1.1. Distributed Data Sampler
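Without an explicit sampler, the DataLoader uses a SequentialSampler (or a RandomSampler if shuffle=True), so every rank would iterate over the full dataset. In DDP (DistributedDataParallel) training, pass a DistributedSampler to the DataLoader instead. It gives each rank its own shard of the data, which is required for correctness and prevents duplicated work, improving effective throughput in multi-GPU training. Leave shuffle=False in the DataLoader, because sampler and shuffle are mutually exclusive, and call sampler.set_epoch(epoch) at the start of every epoch, as the Trainer above does.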
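You can inspect how PyTorch's DistributedSampler splits the data without launching a distributed job. If you pass num_replicas and rank explicitly, no process group is needed. The 10-sample, 4-rank setup below is illustrative. Each rank receives ceil(10 / 4) = 3 indices. Because drop_last defaults to False, the index list is padded by repeating its first entries, so that every rank runs the same number of steps.

```python
from torch.utils.data.distributed import DistributedSampler


def shard_indices(num_samples: int, world_size: int, rank: int) -> list:
    """Indices that `rank` would load; any object with __len__ works as the dataset."""
    sampler = DistributedSampler(range(num_samples), num_replicas=world_size,
                                 rank=rank, shuffle=False)
    return list(sampler)


for rank in range(4):
    print(rank, shard_indices(10, 4, rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```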
2.4.1.2. Num Workers Heuristics
One of the most important DataLoader settings is the number of worker processes.
By default, num_workers=0, which means batches are loaded in the main process. Set it to a value between 1 and the number of available physical CPU cores minus one. On a multi-GPU node, divide the cores among the ranks, since every rank starts its own workers.
Why not use all of the available cores? At least one core should always stay free for the main training process. Moreover, fewer workers sometimes perform better, so start with 2-8 workers and tune empirically using profiling. Each DataLoader worker is a separate process that loads and preprocesses batches. Factors like resource contention and management overhead often make it inefficient to use every CPU core at once. Finding the right number takes experimentation, but it can result in dramatic improvements.
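A quick way to run this experiment is to time the DataLoader alone for several worker counts before confirming the best setting with nsys. In the minimal sketch below, RandomImages is a hypothetical stand-in dataset, and the batch size and batch count are arbitrary. The measured time includes worker start-up, so short runs favor fewer workers.

```python
import time

import torch
from torch.utils.data import DataLoader, Dataset


class RandomImages(Dataset):
    """Hypothetical stand-in dataset that returns ResNet-sized random samples."""

    def __init__(self, length: int = 256) -> None:
        self.length = length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int):
        return torch.randn(3, 224, 224), idx % 1000


def samples_per_second(dataset: Dataset, num_workers: int,
                       batch_size: int = 32, max_batches: int = 20) -> float:
    """Measure DataLoader throughput on its own, without any model in the loop."""
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),  # pinning only helps with a GPU present
    )
    seen = 0
    start = time.perf_counter()  # includes worker start-up time
    for i, (images, _) in enumerate(loader):
        seen += images.shape[0]
        if i + 1 >= max_batches:
            break
    return seen / (time.perf_counter() - start)
```

For example, looping over worker counts of 0, 2, 4 and 8 and printing samples_per_second(RandomImages(), w) for each gives a first ranking. Run it inside the same Slurm allocation as training, and confirm the winner with nsys, since GPU idle gaps are the metric that ultimately matters.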
2.4.2. Memory Allocation
The second area to optimize is memory allocation and memory transfers.
2.4.2.1. Pinned / Non-pageable Memory
Set pin_memory=True in the DataLoader so that batches are placed in pinned memory. Pinned (page-locked, non-pageable) memory is system RAM that cannot be swapped out, which enables faster DMA transfers to the GPU.
2.4.2.2. Non-Blocking Memory Transfers
Pinned memory pays off when each CPU-to-GPU transfer is also non-blocking, so that it runs asynchronously. Asynchronous transfers avoid blocking the main thread and make it possible for the GPU to compute on one batch while the data for the next batch is transferred over the PCIe bus. True overlap of copy and compute on the GPU additionally requires issuing the copy on a separate CUDA stream.
Therefore, call .to(device, non_blocking=True) for host-to-device copies when the DataLoader has pin_memory=True. With pageable memory, the copy is executed synchronously anyway, as the nsys analyze output in section 2.1.2 shows. For device-to-host copies with non_blocking=True, synchronize before reading the result on the CPU.
Some optimizations only work in combination with others, and applied alone they can even worsen your initial performance, so profile after each change.
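A minimal sketch of the pinned, non-blocking copy described above. The helper name to_device_async is hypothetical. Pinning inside the helper is only a fallback, because DataLoader(pin_memory=True) already returns pinned batches from a background thread.

```python
import torch


def to_device_async(batch: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Copy a CPU tensor to `device`, asynchronously when the target is a GPU."""
    if device.type != "cuda":
        return batch.to(device)  # plain copy; a no-op for CPU -> CPU
    if not batch.is_pinned():
        # Pinning here costs an extra host-side copy, so it is only a fallback.
        batch = batch.pin_memory()
    # With a pinned source, this call returns without waiting for the copy.
    # Kernels queued later on the same stream still see the completed copy.
    return batch.to(device, non_blocking=True)
```

In the Trainer above, this corresponds to replacing source.to(self.local_rank) with source.to(self.local_rank, non_blocking=True), since its DataLoader already pins memory.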
2.4.2.3 Batch Prefetching
Another effective way to reduce the time the GPU waits for input data from the host is the DataLoader's prefetch_factor, the number of batches each worker loads in advance. It is only valid with num_workers > 0, and its default is 2. Prefetching overlaps CPU preprocessing with GPU execution, reducing stalls, but it increases the number of batches held in host memory at the same time. Workers prepare batches ahead of time, which can cause RAM pressure, especially when using pinned memory. Pinned memory exhaustion is one of the most common causes of mysterious crashes when increasing prefetching. This is why prefetching should be introduced only after baseline profiling.
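Before raising prefetch_factor, a back-of-the-envelope estimate shows how much host memory the prefetched batches can occupy. The sketch below is a rough estimate under stated assumptions. It uses the 64-image float32 batch from this tutorial, and the values of 8 workers and 4 ranks per node are illustrative. It ignores the extra pinned copies and other DataLoader overheads.

```python
def prefetched_bytes(num_workers: int, prefetch_factor: int, batch_bytes: int,
                     ranks_per_node: int = 1) -> int:
    """Rough estimate of host memory held by batches loaded in advance on one node.

    Each worker loads up to `prefetch_factor` batches ahead, and every DDP rank
    on the node runs its own DataLoader with its own workers.
    """
    if num_workers == 0:
        return 0  # no workers, no prefetching (prefetch_factor needs num_workers > 0)
    return ranks_per_node * num_workers * prefetch_factor * batch_bytes


batch_bytes = 64 * 3 * 224 * 224 * 4  # one float32 batch of 64 images: 38,535,168 bytes
for factor in (2, 4, 8):
    gb = prefetched_bytes(8, factor, batch_bytes, ranks_per_node=4) / 1e9
    print(f"prefetch_factor={factor}: ~{gb:.1f} GB of batches in flight per node")
```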
3. Conclusion
Nvidia Nsight Systems provides both interactive visuals and targeted textual suggestions that help you understand what is happening behind the scenes of distributed workloads. This guide covers core techniques for optimizing GPU workloads in a PyTorch context using Nsight and teaches basic Nsight usage to help you get started.