
Commit 19df0d9

Merge pull request NVIDIA#579 from msaroufim/pytorch_example
PyTorch example
2 parents 5404764 + 7d3582f commit 19df0d9

2 files changed: +107 -1 lines


cuda_core/examples/pytorch_example.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0

## Usage: pip install "cuda-core[cu12]"
## python pytorch_example.py

import sys

import torch

from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch

# SAXPY kernel - passing 'a' as a pointer to avoid any type issues
code = """
template<typename T>
__global__ void saxpy_kernel(const T* a, const T* x, const T* y, T* out, size_t N) {
    const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid < N) {
        // Dereference a to get the scalar value
        out[tid] = (*a) * x[tid] + y[tid];
    }
}
"""

dev = Device()
dev.set_current()

# Get PyTorch's current stream
pt_stream = torch.cuda.current_stream()
print(f"PyTorch stream: {pt_stream}")


# Create a wrapper class that implements __cuda_stream__
class PyTorchStreamWrapper:
    def __init__(self, pt_stream):
        self.pt_stream = pt_stream

    def __cuda_stream__(self):
        stream_id = self.pt_stream.cuda_stream
        return (0, stream_id)  # Return format required by CUDA Python


s = PyTorchStreamWrapper(pt_stream)

# prepare program
arch = "".join(f"{i}" for i in dev.compute_capability)
program_options = ProgramOptions(std="c++11", arch=f"sm_{arch}")
prog = Program(code, code_type="c++", options=program_options)
mod = prog.compile(
    "cubin",
    logs=sys.stdout,
    name_expressions=("saxpy_kernel<float>", "saxpy_kernel<double>"),
)

# Run in single precision
ker = mod.get_kernel("saxpy_kernel<float>")
dtype = torch.float32

# prepare input/output
size = 64
# Use a single-element tensor for 'a'
a = torch.tensor([10.0], dtype=dtype, device="cuda")
x = torch.rand(size, dtype=dtype, device="cuda")
y = torch.rand(size, dtype=dtype, device="cuda")
out = torch.empty_like(x)

# prepare launch
block = 32
grid = int((size + block - 1) // block)
config = LaunchConfig(grid=grid, block=block)
ker_args = (a.data_ptr(), x.data_ptr(), y.data_ptr(), out.data_ptr(), size)

# launch kernel on our stream
launch(s, config, ker, *ker_args)

# check result
assert torch.allclose(out, a.item() * x + y)
print("Single precision test passed!")

# let's repeat again with double precision
ker = mod.get_kernel("saxpy_kernel<double>")
dtype = torch.float64

# prepare input
size = 128
# Use a single-element tensor for 'a'
a = torch.tensor([42.0], dtype=dtype, device="cuda")
x = torch.rand(size, dtype=dtype, device="cuda")
y = torch.rand(size, dtype=dtype, device="cuda")

# prepare output
out = torch.empty_like(x)

# prepare launch
block = 64
grid = int((size + block - 1) // block)
config = LaunchConfig(grid=grid, block=block)
ker_args = (a.data_ptr(), x.data_ptr(), y.data_ptr(), out.data_ptr(), size)

# launch kernel on PyTorch's stream
launch(s, config, ker, *ker_args)

# check result
assert torch.allclose(out, a * x + y)
print("Double precision test passed!")
print("All tests passed successfully!")

cuda_core/tests/example_tests/utils.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def run_example(samples_path, filename, env=None):
         exec(script, env if env else {})  # nosec B102
     except ImportError as e:
         # for samples requiring any of optional dependencies
-        for m in ("cupy",):
+        for m in ("cupy", "torch"):
             if f"No module named '{m}'" in str(e):
                 pytest.skip(f"{m} not installed, skipping related tests")
                 break
