My First CUDA Kernel: Learning GPU Programming from Scratch

Author

The Fire Hacker

Published

January 16, 2025

The Beginning: Why Learn CUDA?

Today I ran my first custom GPU code! Alongside distributed pre-training of AI models, I wanted to understand what’s happening at the low level, so I went looking for a good resource on kernel development for both inference and training. My real interest is in MoE routing kernels, but I decided to start simple: compile and run kernels on local GPUs. I have a few gaming laptops, so that’s where the kernels run.

For the first program I picked the simplest thing possible: adding two vectors together. GPU Mode’s reference kernels are perfect for building a working end-to-end workflow.

The Program: Overall Architecture

Before diving into the details, let me explain the complete program flow. We’ll be working with the vectoradd_py problem from the reference-kernels repository.

File Structure & Purpose

problems/pmpp/vectoradd_py/
├── run_local.py        # Main benchmark script
├── submission.py       # Our custom CUDA kernel implementation
├── reference.py        # Correctness checking
├── task.py            # Data structure definitions
└── README.md          # Problem description

The Execution Flow:

  1. run_local.py - The orchestrator that:
    • Generates test data of various sizes
    • Calls our custom kernel
    • Measures performance with CUDA events
    • Verifies correctness against reference implementation
  2. submission.py - Contains our CUDA kernel using PyTorch’s inline compilation:
    • CUDA C++ code written as Python strings
    • Compiled on-the-fly using load_inline
    • Creates a Python module we can call

The Magic: JIT Compilation Process

When we use torch.utils.cpp_extension.load_inline, here’s what happens behind the scenes:

Your Python Code
        ↓
torch.utils.cpp_extension.load_inline
        ↓
Generates .cpp and .cu source files
        ↓
Writes build.ninja file
        ↓
ninja → nvcc/g++ compile → add_cuda.so
        ↓
dlopen() loads .so into Python process

The compilation is “just in time” only in the sense that it happens when the module is first loaded: once built, your kernel is fixed machine code running directly on the GPU! Later runs simply reuse the cached .so.
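
To see that pipeline in isolation, here’s a minimal C++-only sketch, adapted from the example in PyTorch’s load_inline documentation (it’s not part of the vectoradd problem, just a toy to watch the build happen):

import torch
from torch.utils.cpp_extension import load_inline

cpp_source = """
torch::Tensor sin_add(torch::Tensor x, torch::Tensor y) {
    return x.sin() + y.sin();
}
"""

# This one call runs the whole pipeline above: generate sources,
# write build.ninja, compile, dlopen the resulting .so
demo = load_inline(
    name="inline_demo",
    cpp_sources=[cpp_source],
    functions=["sin_add"],
    verbose=True,
)

print(demo.sin_add(torch.ones(3), torch.ones(3)))  # ≈1.6829 per element: sin(1) + sin(1)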

The Challenge: What Was I Trying to Achieve?

My goal was simple but specific:

  1. Write actual CUDA code that runs on my RTX 2050 laptop GPU
  2. Understand how thousands of threads work together
  3. Measure real performance and understand the numbers
  4. Learn why GPUs are so powerful for AI workloads

I found the perfect learning resource and modified it for my own use: a fork of GPU Mode’s reference-kernels repository. It’s a collection of progressively harder GPU programming challenges, starting with vector addition.

The Big Picture: How GPUs Think Differently

Before diving into code, here’s the mental shift that changed everything for me:

CPU Thinking: “Do step 1, then step 2, then step 3…”
GPU Thinking: “Do ALL the steps at once, everywhere!”

Imagine you need to paint 1000 fence posts. A CPU is like one very fast painter who paints each post perfectly, one after another. A GPU is like hiring 1000 amateur painters who each paint one post simultaneously. Even if each painter is slower, getting all posts done at once is way faster!

For vector addition (C = A + B), instead of:

for i in range(1_000_000):
    C[i] = A[i] + B[i]  # One element at a time

The GPU does:

Thread 0: C[0] = A[0] + B[0]
Thread 1: C[1] = A[1] + B[1]
Thread 2: C[2] = A[2] + B[2]
... (all at the same time!)
Thread 999999: C[999999] = A[999999] + B[999999]

Setting Up: The Journey to “Hello GPU”

Getting CUDA working on Windows with WSL2 was an adventure. Here’s what actually worked:

Step 1: Install CUDA Toolkit

First, I needed the CUDA compiler (nvcc) to turn my code into GPU instructions:

# Get NVIDIA's official CUDA repository
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y cuda-toolkit-12-1

# Tell the system where CUDA lives
export CUDA_HOME=/usr/local/cuda-12.1
export PATH=$CUDA_HOME/bin:$PATH

Step 2: PyTorch with CUDA Support

PyTorch makes it easy to compile CUDA code on-the-fly:

pip install --index-url https://download.pytorch.org/whl/cu121 torch
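
Before compiling anything, it’s worth a quick sanity check that PyTorch can actually see the GPU and the CUDA version you expect:

import torch

print(torch.cuda.is_available())      # should print True inside WSL2
print(torch.version.cuda)             # the CUDA version PyTorch was built with, e.g. 12.1
print(torch.cuda.get_device_name(0))  # should mention the RTX 2050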

The Code: Understanding Every Line

Now for the exciting part - the actual GPU code! Let me explain what each piece does and why it matters.

The GPU Kernel: Where the Magic Happens

template <typename scalar_t>
__global__ void add_kernel(const scalar_t* __restrict__ A,
                           const scalar_t* __restrict__ B,
                           scalar_t* __restrict__ C,
                           int N) {
    // Who am I? Calculate my unique ID among thousands of threads
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Am I responsible for a valid element?
    if (idx < N) {
        C[idx] = A[idx] + B[idx];  // Do my one simple job
    }
}

What’s happening here:

  • __global__: This function runs on the GPU. It’s called from the CPU but executes on thousands of GPU cores simultaneously.

  • Thread Identity Crisis (solved!): Each thread needs to know which element to process. Think of it like a massive factory where each worker needs to know which item on the conveyor belt is theirs:

    • threadIdx.x: “I’m worker #5 in my team”
    • blockIdx.x: “My team is team #3”
    • blockDim.x: “Each team has 256 workers”
    • So my global position is: 3 * 256 + 5 = 773 - I handle element 773!
  • if (idx < N): Safety first! We might launch more threads than we have data (for efficiency reasons), so each thread checks if it has real work to do.
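
Both pieces of that story (the global-index formula and the idx < N guard) can be checked with a tiny CPU-side simulation of the same arithmetic; this is plain Python with made-up sizes, just to convince yourself the mapping covers every element exactly once:

N = 1000          # made-up problem size
threads = 256     # threads per block
blocks = (N + threads - 1) // threads   # 4 blocks

covered = []
for block in range(blocks):        # the GPU runs all of these in parallel;
    for thread in range(threads):  # we simulate them one by one here
        idx = block * threads + thread
        if idx < N:                # the same bounds check as the kernel
            covered.append(idx)

assert covered == list(range(N))   # every element handled exactly once, none missed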

Launching the Kernel: Mission Control

torch::Tensor add_cuda(torch::Tensor A, torch::Tensor B) {
    int N = A.numel();              // How many elements total?
    auto C = torch::empty_like(A);  // Prepare output space

    // Configure the thread army
    const int threads = 256;                         // Threads per block (team size)
    const int blocks = (N + threads - 1) / threads;  // How many teams needed?

    // LAUNCH! Pick the right scalar_t for the dtype, then send thousands of threads to work
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(A.scalar_type(), "add_cuda", ([&] {
        add_kernel<scalar_t><<<blocks, threads>>>(
            A.data_ptr<scalar_t>(),
            B.data_ptr<scalar_t>(),
            C.data_ptr<scalar_t>(),
            N
        );
    }));

    return C;  // Hand the result back to PyTorch
}

The Strategy:

  • We organize threads into blocks (teams) of 256 threads each
  • Why 256? It’s a multiple of 32 (the warp size, the GPU’s natural execution unit)
  • The <<<blocks, threads>>> syntax is CUDA’s special way to say “launch this many blocks with this many threads each”
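
To make the team math concrete, here’s how the ceiling division plays out for, say, a million-element vector (just the arithmetic, nothing GPU-specific):

N = 1_000_000
threads = 256
blocks = (N + threads - 1) // threads  # ceiling division

print(blocks, blocks * threads)        # 3907 blocks -> 1,000,192 threads launched
# the last 192 threads fail the idx < N check and simply do nothing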

The Python Bridge: Making it Usable

PyTorch’s load_inline is brilliant - it compiles CUDA code on-the-fly:

from torch.utils.cpp_extension import load_inline

# add_cpp_source and add_cuda_source are the Python strings holding the
# C++ declaration of add_cuda and the CUDA kernel + launcher shown above
add_module = load_inline(
    name='add_cuda',
    cpp_sources=add_cpp_source,
    cuda_sources=add_cuda_source,
    functions=['add_cuda'],
    verbose=True,  # Show me what's happening!
)

Run it and you’ll see output like this (the log below is from a cached run; the very first build also shows the nvcc and g++ compile steps):

Detected CUDA files, patching ldflags
Building extension module add_cuda...
ninja: no work to do.
Loading extension module add_cuda...

That’s ninja deciding whether anything needs rebuilding, nvcc compiling your GPU code into a loadable module when it does, and the result being imported straight into Python!

The Benchmarking: Measuring Reality

The run_local.py script does something clever - it automatically picks test sizes based on available GPU memory:

# How much GPU memory is free?
free_bytes, _ = torch.cuda.mem_get_info()
budget = int(free_bytes * 0.8)  # Use 80% to be safe

# For 2D matrices: need space for A, B, and C
s_max = int(math.sqrt(budget / (3 * bytes_per_elem)))

This prevents the dreaded “CUDA out of memory” error!
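
Plugging rough numbers into that formula shows why even the 16384 test case fits comfortably on a 4 GB card (the free-memory figure below is a made-up illustration, not a measurement):

import math

free_bytes = 3.2e9              # pretend ~3.2 GB of the 4 GB card is free
bytes_per_elem = 2              # float16
budget = int(free_bytes * 0.8)  # keep a 20% safety margin

# A, B and C are all size × size, so 3 * size^2 * bytes_per_elem must fit the budget
s_max = int(math.sqrt(budget / (3 * bytes_per_elem)))
print(s_max)                    # ~20655, well above the 16384 case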

Timing GPU Code: It’s Tricky!

You can’t use regular Python timing for GPU code because GPU operations are asynchronous. The solution? CUDA Events:

t0 = torch.cuda.Event(enable_timing=True)
t1 = torch.cuda.Event(enable_timing=True)

t0.record()              # Start timer ON THE GPU
custom_kernel(case)      # Run kernel
t1.record()              # Stop timer ON THE GPU
torch.cuda.synchronize() # Wait for GPU to finish
elapsed = t0.elapsed_time(t1)  # Get time in milliseconds
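
In practice the very first call also pays for the extension build, and a single measurement jitters from run to run, so it helps to warm up and average over several iterations. A small helper sketch (the function name and defaults are my own, not from run_local.py):

import torch

def time_kernel_ms(fn, *args, iters=10, warmup=3):
    # Warm-up runs absorb one-time costs (extension load, caches, clock ramp-up)
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()  # wait for the GPU before reading the timers

    return start.elapsed_time(end) / iters  # mean milliseconds per call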

The Results: What I Learned from the Numbers

Running on my RTX 2050 (4GB VRAM):

size=4096:  mean=0.9877 ms    (17 million elements)
size=8192:  mean=3.8027 ms    (67 million elements)
size=12288: mean=8.5460 ms    (151 million elements)
size=16384: mean=150.3063 ms  (268 million elements) ← WHAT?!

The Mystery of the Slow 16384

Why did 16384×16384 suddenly become 17x slower? This taught me a crucial lesson about GPU architecture:

The Problem: With 256 threads per block, processing 268,435,456 elements needs 1,048,576 blocks!

The GPU scheduler choked trying to manage over a million tiny work units. It’s like trying to manage a million separate construction crews for a project - the coordination overhead kills you!

The Solution: Grid-stride loops - have each thread process multiple elements:

for (int idx = blockIdx.x * blockDim.x + threadIdx.x;
     idx < N;
     idx += blockDim.x * gridDim.x) {
    C[idx] = A[idx] + B[idx];
}

Now you can cap blocks at a reasonable number (like 10,000) and each thread handles multiple elements.
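
Here’s a minimal end-to-end sketch of that idea, written the same way as the rest of this post (CUDA source in a Python string, compiled with load_inline). It assumes contiguous float32 tensors, hard-codes the 10,000-block cap, and the names (add_cuda_strided and friends) are mine, not from the repo:

import torch
from torch.utils.cpp_extension import load_inline

strided_cuda_source = r"""
__global__ void add_kernel_strided(const float* __restrict__ A,
                                   const float* __restrict__ B,
                                   float* __restrict__ C,
                                   int N) {
    // Start at my global index, then hop by the total number of threads
    // in the grid, so a capped grid still touches every element
    for (int idx = blockIdx.x * blockDim.x + threadIdx.x;
         idx < N;
         idx += blockDim.x * gridDim.x) {
        C[idx] = A[idx] + B[idx];
    }
}

torch::Tensor add_cuda_strided(torch::Tensor A, torch::Tensor B) {
    int N = A.numel();
    auto C = torch::empty_like(A);

    const int threads = 256;
    int blocks = (N + threads - 1) / threads;
    if (blocks > 10000) blocks = 10000;  // cap the grid; the loop covers the rest

    add_kernel_strided<<<blocks, threads>>>(
        A.data_ptr<float>(), B.data_ptr<float>(), C.data_ptr<float>(), N);
    return C;
}
"""

strided_cpp_source = "torch::Tensor add_cuda_strided(torch::Tensor A, torch::Tensor B);"

add_strided = load_inline(
    name="add_cuda_strided",
    cpp_sources=strided_cpp_source,
    cuda_sources=strided_cuda_source,
    functions=["add_cuda_strided"],
    verbose=True,
)

A = torch.rand(8192, 8192, device="cuda")
B = torch.rand(8192, 8192, device="cuda")
assert torch.allclose(add_strided.add_cuda_strided(A, B), A + B)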

Memory Bandwidth: The Real Bottleneck

For the 8192×8192 case:

  • Data moved: 67M elements × 2 bytes × 3 arrays ≈ 0.4 GB
  • Time: 3.8 ms
  • Effective bandwidth: ≈106 GB/s

That’s a healthy fraction of my RTX 2050’s theoretical memory bandwidth. For an operation this simple, the GPU spends almost all of its time streaming A and B in from VRAM and writing C back out; the additions themselves are free by comparison.
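
If you want to redo that arithmetic for other sizes, it’s a three-line calculation (the 2 bytes per element assumes float16 tensors):

elements = 8192 * 8192              # 67,108,864 elements per tensor
bytes_moved = elements * 2 * 3      # read A, read B, write C, 2 bytes each
seconds = 3.8e-3                    # the measured kernel time

print(bytes_moved / seconds / 1e9)  # ≈106 GB/s effective bandwidth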

The Revelations: What Changed My Understanding

  1. GPUs are not fast CPUs - They’re a completely different beast. They’re terrible at complex branching logic but amazing at doing the same simple thing everywhere.

  2. Memory movement dominates - For simple operations like addition, you spend more time moving data than computing. This is why AI models use operations like matrix multiplication that do lots of compute per memory access.

  3. Launch configuration matters hugely - Too many blocks? Scheduling overhead. Too few? Underutilization. It’s an art.

  4. The power of parallel thinking - Once you start thinking “what can happen simultaneously?” instead of “what comes next?”, you see opportunities everywhere.

What’s Next?

Now that I’ve got basic kernels working, I’m excited to explore:

  • Shared memory: Using the 48KB of ultra-fast memory shared within each block
  • Warp-level operations: Leveraging the fact that 32 threads execute in lockstep
  • Reduction operations: How do you sum a billion numbers in parallel?
  • Matrix multiplication: The operation that powers all of deep learning

The Journey Continues

Starting with vector addition might seem trivial, but it opened the door to understanding how modern AI actually works at the hardware level. Every transformer model, every diffusion model, every neural network - they’re all built on these fundamental parallel operations.

The moment it clicked that my GPU was running 65,536 threads simultaneously, each doing their tiny part of the work, was magical. It’s not just faster computing - it’s a fundamentally different way of solving problems.

Next week: I’m going to tackle matrix multiplication and see if I can beat PyTorch’s built-in implementation (spoiler: probably not, but I’ll learn tons trying!).


Want to try this yourself? Clone the reference-kernels repo and start with problems/pmpp/vectoradd_py/. The journey from CPU thinking to GPU thinking is worth it!