|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +## Usage: pip install "cuda-core[cu12]" |
| 6 | +## python python_example.py |
| 7 | +import sys |
| 8 | + |
| 9 | +import torch |
| 10 | + |
| 11 | +from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch |
| 12 | + |
| 13 | +# SAXPY kernel - passing a as a pointer to avoid any type issues |
| 14 | +code = """ |
| 15 | +template<typename T> |
| 16 | +__global__ void saxpy_kernel(const T* a, const T* x, const T* y, T* out, size_t N) { |
| 17 | + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; |
| 18 | + if (tid < N) { |
| 19 | + // Dereference a to get the scalar value |
| 20 | + out[tid] = (*a) * x[tid] + y[tid]; |
| 21 | + } |
| 22 | +} |
| 23 | +""" |
| 24 | + |
| 25 | +dev = Device() |
| 26 | +dev.set_current() |
| 27 | + |
| 28 | +# Get PyTorch's current stream |
| 29 | +pt_stream = torch.cuda.current_stream() |
| 30 | +print(f"PyTorch stream: {pt_stream}") |
| 31 | + |
| 32 | + |
| 33 | +# Create a wrapper class that implements __cuda_stream__ |
| 34 | +class PyTorchStreamWrapper: |
| 35 | + def __init__(self, pt_stream): |
| 36 | + self.pt_stream = pt_stream |
| 37 | + |
| 38 | + def __cuda_stream__(self): |
| 39 | + stream_id = self.pt_stream.cuda_stream |
| 40 | + return (0, stream_id) # Return format required by CUDA Python |
| 41 | + |
| 42 | + |
| 43 | +s = PyTorchStreamWrapper(pt_stream) |
| 44 | + |
| 45 | +# prepare program |
| 46 | +arch = "".join(f"{i}" for i in dev.compute_capability) |
| 47 | +program_options = ProgramOptions(std="c++11", arch=f"sm_{arch}") |
| 48 | +prog = Program(code, code_type="c++", options=program_options) |
| 49 | +mod = prog.compile( |
| 50 | + "cubin", |
| 51 | + logs=sys.stdout, |
| 52 | + name_expressions=("saxpy_kernel<float>", "saxpy_kernel<double>"), |
| 53 | +) |
| 54 | + |
| 55 | +# Run in single precision |
| 56 | +ker = mod.get_kernel("saxpy_kernel<float>") |
| 57 | +dtype = torch.float32 |
| 58 | + |
| 59 | +# prepare input/output |
| 60 | +size = 64 |
| 61 | +# Use a single element tensor for 'a' |
| 62 | +a = torch.tensor([10.0], dtype=dtype, device="cuda") |
| 63 | +x = torch.rand(size, dtype=dtype, device="cuda") |
| 64 | +y = torch.rand(size, dtype=dtype, device="cuda") |
| 65 | +out = torch.empty_like(x) |
| 66 | + |
| 67 | +# prepare launch |
| 68 | +block = 32 |
| 69 | +grid = int((size + block - 1) // block) |
| 70 | +config = LaunchConfig(grid=grid, block=block) |
| 71 | +ker_args = (a.data_ptr(), x.data_ptr(), y.data_ptr(), out.data_ptr(), size) |
| 72 | + |
| 73 | +# launch kernel on our stream |
| 74 | +launch(s, config, ker, *ker_args) |
| 75 | + |
| 76 | +# check result |
| 77 | +assert torch.allclose(out, a.item() * x + y) |
| 78 | +print("Single precision test passed!") |
| 79 | + |
| 80 | +# let's repeat again with double precision |
| 81 | +ker = mod.get_kernel("saxpy_kernel<double>") |
| 82 | +dtype = torch.float64 |
| 83 | + |
| 84 | +# prepare input |
| 85 | +size = 128 |
| 86 | +# Use a single element tensor for 'a' |
| 87 | +a = torch.tensor([42.0], dtype=dtype, device="cuda") |
| 88 | +x = torch.rand(size, dtype=dtype, device="cuda") |
| 89 | +y = torch.rand(size, dtype=dtype, device="cuda") |
| 90 | + |
| 91 | +# prepare output |
| 92 | +out = torch.empty_like(x) |
| 93 | + |
| 94 | +# prepare launch |
| 95 | +block = 64 |
| 96 | +grid = int((size + block - 1) // block) |
| 97 | +config = LaunchConfig(grid=grid, block=block) |
| 98 | +ker_args = (a.data_ptr(), x.data_ptr(), y.data_ptr(), out.data_ptr(), size) |
| 99 | + |
| 100 | +# launch kernel on PyTorch's stream |
| 101 | +launch(s, config, ker, *ker_args) |
| 102 | + |
| 103 | +# check result |
| 104 | +assert torch.allclose(out, a * x + y) |
| 105 | +print("Double precision test passed!") |
| 106 | +print("All tests passed successfully!") |
0 commit comments