Skip to content

Commit 98ea688

Browse files
Migrate extension-cpp to stable API/ABI
1 parent 0ec4969 commit 98ea688

File tree

6 files changed

+238
-108
lines changed

6 files changed

+238
-108
lines changed

.github/scripts/setup-env.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ pip install --progress-bar=off -r requirements.txt
101101
echo '::endgroup::'
102102

103103
echo '::group::Install extension-cpp'
104-
python setup.py develop
104+
pip install -e . --no-build-isolation
105105
echo '::endgroup::'
106106

107107
echo '::group::Collect environment information'

README.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
# C++/CUDA Extensions in PyTorch
1+
# C++/CUDA Extensions in PyTorch with LibTorch Stable ABI
2+
3+
An example of writing a C++/CUDA extension for PyTorch using the **LibTorch Stable ABI**.
4+
See [here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.
25

3-
An example of writing a C++/CUDA extension for PyTorch. See
4-
[here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.
56
This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
6-
custom op that has both custom CPU and CUDA kernels.
7+
custom op that has both custom CPU and CUDA kernels, with cross-version
8+
compatibility using the stable ABI.
9+
10+
The examples in this repo work with PyTorch 2.10+.
711

8-
The examples in this repo work with PyTorch 2.4+.
912

1013
To build:
1114
```

extension_cpp/csrc/cuda/muladd.cu

Lines changed: 118 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
#include <ATen/Operators.h>
2-
#include <torch/all.h>
3-
#include <torch/library.h>
1+
// LibTorch Stable ABI version of CUDA custom operators
2+
// This file uses the stable API for cross-version compatibility.
3+
// See: https://pytorch.org/docs/main/notes/libtorch_stable_abi.html
4+
5+
#include <torch/csrc/stable/library.h>
6+
#include <torch/csrc/stable/ops.h>
7+
#include <torch/csrc/stable/tensor.h>
8+
#include <torch/csrc/stable/accelerator.h>
9+
#include <torch/headeronly/core/ScalarType.h>
10+
#include <torch/headeronly/macros/Macros.h>
11+
12+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
413

514
#include <cuda.h>
615
#include <cuda_runtime.h>
7-
#include <ATen/cuda/CUDAContext.h>
816

917
namespace extension_cpp {
1018

@@ -13,21 +21,39 @@ __global__ void muladd_kernel(int numel, const float* a, const float* b, float c
1321
if (idx < numel) result[idx] = a[idx] * b[idx] + c;
1422
}
1523

16-
at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
17-
TORCH_CHECK(a.sizes() == b.sizes());
18-
TORCH_CHECK(a.dtype() == at::kFloat);
19-
TORCH_CHECK(b.dtype() == at::kFloat);
20-
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
21-
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
22-
at::Tensor a_contig = a.contiguous();
23-
at::Tensor b_contig = b.contiguous();
24-
at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
25-
const float* a_ptr = a_contig.data_ptr<float>();
26-
const float* b_ptr = b_contig.data_ptr<float>();
27-
float* result_ptr = result.data_ptr<float>();
24+
torch::stable::Tensor mymuladd_cuda(
25+
const torch::stable::Tensor& a,
26+
const torch::stable::Tensor& b,
27+
double c) {
28+
STD_TORCH_CHECK(a.sizes().equals(b.sizes()), "Tensor sizes must match");
29+
STD_TORCH_CHECK(
30+
a.scalar_type() == torch::headeronly::ScalarType::Float,
31+
"Input tensor a must be float32");
32+
STD_TORCH_CHECK(
33+
b.scalar_type() == torch::headeronly::ScalarType::Float,
34+
"Input tensor b must be float32");
35+
STD_TORCH_CHECK(
36+
a.device().type() == torch::headeronly::DeviceType::CUDA,
37+
"Input tensor a must be on CUDA");
38+
STD_TORCH_CHECK(
39+
b.device().type() == torch::headeronly::DeviceType::CUDA,
40+
"Input tensor b must be on CUDA");
41+
42+
torch::stable::Tensor a_contig = torch::stable::contiguous(a);
43+
torch::stable::Tensor b_contig = torch::stable::contiguous(b);
44+
torch::stable::Tensor result = torch::stable::empty_like(a_contig);
45+
46+
const float* a_ptr = a_contig.const_data_ptr<float>();
47+
const float* b_ptr = b_contig.const_data_ptr<float>();
48+
float* result_ptr = result.mutable_data_ptr<float>();
2849

2950
int numel = a_contig.numel();
30-
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
51+
52+
void* stream_ptr = nullptr;
53+
TORCH_ERROR_CODE_CHECK(
54+
aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
55+
cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
56+
3157
muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
3258
return result;
3359
}
@@ -37,20 +63,38 @@ __global__ void mul_kernel(int numel, const float* a, const float* b, float* res
3763
if (idx < numel) result[idx] = a[idx] * b[idx];
3864
}
3965

40-
at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
41-
TORCH_CHECK(a.sizes() == b.sizes());
42-
TORCH_CHECK(a.dtype() == at::kFloat);
43-
TORCH_CHECK(b.dtype() == at::kFloat);
44-
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
45-
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
46-
at::Tensor a_contig = a.contiguous();
47-
at::Tensor b_contig = b.contiguous();
48-
at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
49-
const float* a_ptr = a_contig.data_ptr<float>();
50-
const float* b_ptr = b_contig.data_ptr<float>();
51-
float* result_ptr = result.data_ptr<float>();
66+
torch::stable::Tensor mymul_cuda(
67+
const torch::stable::Tensor& a,
68+
const torch::stable::Tensor& b) {
69+
STD_TORCH_CHECK(a.sizes().equals(b.sizes()), "Tensor sizes must match");
70+
STD_TORCH_CHECK(
71+
a.scalar_type() == torch::headeronly::ScalarType::Float,
72+
"Input tensor a must be float32");
73+
STD_TORCH_CHECK(
74+
b.scalar_type() == torch::headeronly::ScalarType::Float,
75+
"Input tensor b must be float32");
76+
STD_TORCH_CHECK(
77+
a.device().type() == torch::headeronly::DeviceType::CUDA,
78+
"Input tensor a must be on CUDA");
79+
STD_TORCH_CHECK(
80+
b.device().type() == torch::headeronly::DeviceType::CUDA,
81+
"Input tensor b must be on CUDA");
82+
83+
torch::stable::Tensor a_contig = torch::stable::contiguous(a);
84+
torch::stable::Tensor b_contig = torch::stable::contiguous(b);
85+
torch::stable::Tensor result = torch::stable::empty_like(a_contig);
86+
87+
const float* a_ptr = a_contig.const_data_ptr<float>();
88+
const float* b_ptr = b_contig.const_data_ptr<float>();
89+
float* result_ptr = result.mutable_data_ptr<float>();
90+
5291
int numel = a_contig.numel();
53-
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
92+
93+
void* stream_ptr = nullptr;
94+
TORCH_ERROR_CODE_CHECK(
95+
aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
96+
cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
97+
5498
mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
5599
return result;
56100
}
@@ -60,32 +104,55 @@ __global__ void add_kernel(int numel, const float* a, const float* b, float* res
60104
if (idx < numel) result[idx] = a[idx] + b[idx];
61105
}
62106

63-
void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
64-
TORCH_CHECK(a.sizes() == b.sizes());
65-
TORCH_CHECK(b.sizes() == out.sizes());
66-
TORCH_CHECK(a.dtype() == at::kFloat);
67-
TORCH_CHECK(b.dtype() == at::kFloat);
68-
TORCH_CHECK(out.dtype() == at::kFloat);
69-
TORCH_CHECK(out.is_contiguous());
70-
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
71-
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
72-
TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
73-
at::Tensor a_contig = a.contiguous();
74-
at::Tensor b_contig = b.contiguous();
75-
const float* a_ptr = a_contig.data_ptr<float>();
76-
const float* b_ptr = b_contig.data_ptr<float>();
77-
float* result_ptr = out.data_ptr<float>();
107+
// An example of an operator that mutates one of its inputs.
108+
void myadd_out_cuda(
109+
const torch::stable::Tensor& a,
110+
const torch::stable::Tensor& b,
111+
torch::stable::Tensor& out) {
112+
STD_TORCH_CHECK(a.sizes().equals(b.sizes()), "Tensor sizes must match");
113+
STD_TORCH_CHECK(b.sizes().equals(out.sizes()), "Output tensor size must match inputs");
114+
STD_TORCH_CHECK(
115+
a.scalar_type() == torch::headeronly::ScalarType::Float,
116+
"Input tensor a must be float32");
117+
STD_TORCH_CHECK(
118+
b.scalar_type() == torch::headeronly::ScalarType::Float,
119+
"Input tensor b must be float32");
120+
STD_TORCH_CHECK(
121+
out.scalar_type() == torch::headeronly::ScalarType::Float,
122+
"Output tensor must be float32");
123+
STD_TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous");
124+
STD_TORCH_CHECK(
125+
a.device().type() == torch::headeronly::DeviceType::CUDA,
126+
"Input tensor a must be on CUDA");
127+
STD_TORCH_CHECK(
128+
b.device().type() == torch::headeronly::DeviceType::CUDA,
129+
"Input tensor b must be on CUDA");
130+
STD_TORCH_CHECK(
131+
out.device().type() == torch::headeronly::DeviceType::CUDA,
132+
"Output tensor must be on CUDA");
133+
134+
torch::stable::Tensor a_contig = torch::stable::contiguous(a);
135+
torch::stable::Tensor b_contig = torch::stable::contiguous(b);
136+
137+
const float* a_ptr = a_contig.const_data_ptr<float>();
138+
const float* b_ptr = b_contig.const_data_ptr<float>();
139+
float* result_ptr = out.mutable_data_ptr<float>();
140+
78141
int numel = a_contig.numel();
79-
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
142+
143+
void* stream_ptr = nullptr;
144+
TORCH_ERROR_CODE_CHECK(
145+
aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
146+
cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
147+
80148
add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
81149
}
82150

83-
84151
// Registers CUDA implementations for mymuladd, mymul, myadd_out
85-
TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
86-
m.impl("mymuladd", &mymuladd_cuda);
87-
m.impl("mymul", &mymul_cuda);
88-
m.impl("myadd_out", &myadd_out_cuda);
152+
STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
153+
m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
154+
m.impl("mymul", TORCH_BOX(&mymul_cuda));
155+
m.impl("myadd_out", TORCH_BOX(&myadd_out_cuda));
89156
}
90157

91158
}

0 commit comments

Comments (0)