Commit 479f1c2

Update Custom C++ and CUDA Operators to use PyTorch stable API/ABI
Parent commit: 10eefc3

1 file changed: advanced_source/cpp_custom_ops.rst (+155 lines, -62 lines)

@@ -16,7 +16,7 @@ Custom C++ and CUDA Operators
 .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
    :class-card: card-prerequisites

-   * PyTorch 2.4 or later
+   * PyTorch 2.10 or later
    * Basic understanding of C++ and CUDA programming

 .. note::
@@ -62,12 +62,19 @@ Using ``cpp_extension`` is as simple as writing the following ``setup.py``:

     setup(name="extension_cpp",
           ext_modules=[
-            cpp_extension.CppExtension(
+              cpp_extension.CppExtension(
                 "extension_cpp",
                 ["muladd.cpp"],
-                # define Py_LIMITED_API with min version 3.9 to expose only the stable
-                # limited API subset from Python.h
-                extra_compile_args={"cxx": ["-DPy_LIMITED_API=0x03090000"]},
+                extra_compile_args={
+                    "cxx": [
+                        # define Py_LIMITED_API with min version 3.9 to expose only the stable
+                        # limited API subset from Python.h
+                        "-DPy_LIMITED_API=0x03090000",
+                        # define TORCH_TARGET_VERSION with min version 2.10 to expose only the
+                        # stable API subset from torch
+                        "-DTORCH_TARGET_VERSION=0x020a000000000000",
+                    ]
+                },
                 py_limited_api=True)],  # Build 1 wheel across multiple Python versions
           cmdclass={'build_ext': cpp_extension.BuildExtension},
           options={"bdist_wheel": {"py_limited_api": "cp39"}}  # 3.9 is minimum supported Python version
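
The two ``-D`` values above are just packed version numbers. ``Py_LIMITED_API`` follows the
``PY_VERSION_HEX`` layout (one byte each for major and minor in the two most significant bytes),
so Python 3.9 becomes ``0x03090000``. A minimal sketch of that decoding (illustrative only; the
constant name below is ours, not part of CPython or PyTorch):

.. code-block:: cpp

   #include <cstdint>

   // Py_LIMITED_API-style value for CPython 3.9: 0x03 (major), 0x09 (minor), rest zero.
   constexpr std::uint32_t kMinCPython = 0x03090000;
   static_assert(((kMinCPython >> 24) & 0xFF) == 3, "major version");
   static_assert(((kMinCPython >> 16) & 0xFF) == 9, "minor version");

``TORCH_TARGET_VERSION`` uses a wider 64-bit layout, described in the LibTorch Stable ABI
section that this commit adds below.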
@@ -78,6 +85,9 @@ If you need to compile CUDA code (for example, ``.cu`` files), then instead use
 Please see `extension-cpp <https://github.com/pytorch/extension-cpp>`_ for an
 example of how this is set up.

+CPython Agnosticism
+^^^^^^^^^^^^^^^^^^^
+
 The above example represents what we refer to as a CPython agnostic wheel, meaning we are
 building a single wheel that can be run across multiple CPython versions (similar to pure
 Python packages). CPython agnosticism is desirable in minimizing the number of wheels your
@@ -148,33 +158,67 @@ like so:
        cmdclass={'build_ext': cpp_extension.BuildExtension},
    )

+LibTorch Stable ABI (PyTorch Agnosticism)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In addition to CPython agnosticism, there is a second axis of wheel compatibility:
+**LibTorch agnosticism**. While CPython agnosticism allows building a single wheel
+that works across multiple Python versions (3.9, 3.10, 3.11, etc.), LibTorch agnosticism
+allows building a single wheel that works across multiple PyTorch versions (2.10, 2.11, 2.12, etc.).
+These two concepts are orthogonal and can be combined.
+
+To achieve LibTorch agnosticism, you must use the **LibTorch Stable ABI**, which provides
+a stable C API for interacting with PyTorch tensors and operators. For example, instead of
+using ``at::Tensor``, you must use ``torch::stable::Tensor``. For comprehensive
+documentation on the stable ABI, including migration guides, supported types, and
+stack-based API conventions, see the
+`LibTorch Stable ABI documentation <https://pytorch.org/docs/main/notes/libtorch_stable_abi.html>`_.
+
+The setup.py above already includes ``TORCH_TARGET_VERSION=0x020a000000000000``, which indicates that
+the extension targets the LibTorch Stable ABI with a minimum supported PyTorch version of 2.10. The version format is:
+``[MAJ 1 byte][MIN 1 byte][PATCH 1 byte][ABI TAG 5 bytes]``, so 2.10.0 = ``0x020a000000000000``.
+
+See the section below for examples of code using the LibTorch Stable ABI.
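
As a worked example of that layout, the 2.10.0 value can be reproduced by shifting each version
component into its byte position. This is a minimal illustrative sketch; the helper function below
is ours, not a PyTorch API:

.. code-block:: cpp

   #include <cstdint>

   // Pack [MAJ 1 byte][MIN 1 byte][PATCH 1 byte][ABI TAG 5 bytes] into a 64-bit value.
   constexpr std::uint64_t pack_torch_target_version(
       std::uint64_t major, std::uint64_t minor, std::uint64_t patch) {
     return (major << 56) | (minor << 48) | (patch << 40);  // ABI tag bytes left at zero
   }

   // 2.10.0 -> 0x02, 0x0a, 0x00, followed by five zero bytes.
   static_assert(pack_torch_target_version(2, 10, 0) == 0x020a000000000000ULL,
                 "matches the -DTORCH_TARGET_VERSION value used in setup.py");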

 Defining the custom op and adding backend implementations
 ---------------------------------------------------------
-First, let's write a C++ function that computes ``mymuladd``:
+First, let's write a C++ function that computes ``mymuladd`` using the LibTorch Stable ABI:

 .. code-block:: cpp

-   at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) {
-     TORCH_CHECK(a.sizes() == b.sizes());
-     TORCH_CHECK(a.dtype() == at::kFloat);
-     TORCH_CHECK(b.dtype() == at::kFloat);
-     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-     at::Tensor a_contig = a.contiguous();
-     at::Tensor b_contig = b.contiguous();
-     at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-     const float* a_ptr = a_contig.data_ptr<float>();
-     const float* b_ptr = b_contig.data_ptr<float>();
-     float* result_ptr = result.data_ptr<float>();
+   #include <torch/csrc/stable/library.h>
+   #include <torch/csrc/stable/ops.h>
+   #include <torch/csrc/stable/tensor.h>
+   #include <torch/headeronly/core/ScalarType.h>
+   #include <torch/headeronly/macros/Macros.h>
+
+   torch::stable::Tensor mymuladd_cpu(
+       const torch::stable::Tensor& a,
+       const torch::stable::Tensor& b,
+       double c) {
+     STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+     STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+     STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+     torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+     torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+     torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+     const float* a_ptr = a_contig.const_data_ptr<float>();
+     const float* b_ptr = b_contig.const_data_ptr<float>();
+     float* result_ptr = result.mutable_data_ptr<float>();
+
      for (int64_t i = 0; i < result.numel(); i++) {
        result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
      }
      return result;
    }

 In order to use this from PyTorch’s Python frontend, we need to register it
-as a PyTorch operator using the ``TORCH_LIBRARY`` API. This will automatically
+as a PyTorch operator using the ``STABLE_TORCH_LIBRARY`` macro. This will automatically
 bind the operator to Python.

 Operator registration is a two-step process:
@@ -188,7 +232,7 @@ Defining an operator
 To define an operator, follow these steps:

 1. select a namespace for an operator. We recommend the namespace be the name of your top-level
-   project; well use "extension_cpp" in our tutorial.
+   project; we'll use "extension_cpp" in our tutorial.
 2. provide a schema string that specifies the input/output types of the operator and if an
    input Tensor will be mutated. We support more types in addition to Tensor and float;
    please see `The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_

@@ -199,7 +243,7 @@ To define an operator, follow these steps:

 .. code-block:: cpp

-   TORCH_LIBRARY(extension_cpp, m) {
+   STABLE_TORCH_LIBRARY(extension_cpp, m) {
      // Note that "float" in the schema corresponds to the C++ double type
      // and the Python float type.
      m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
@@ -209,46 +253,89 @@ This makes the operator available from Python via ``torch.ops.extension_cpp.mymu

 Registering backend implementations for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-Use ``TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.
+Use ``STABLE_TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.
+Note that we wrap the function pointer with ``TORCH_BOX()`` - this is required for
+stable ABI functions to handle argument boxing/unboxing correctly.

 .. code-block:: cpp

-   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-     m.impl("mymuladd", &mymuladd_cpu);
+   STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+     m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
    }

 If you also have a CUDA implementation of ``mymuladd``, you can register it
-in a separate ``TORCH_LIBRARY_IMPL`` block:
+in a separate ``STABLE_TORCH_LIBRARY_IMPL`` block:

 .. code-block:: cpp

+   #include <torch/csrc/stable/library.h>
+   #include <torch/csrc/stable/ops.h>
+   #include <torch/csrc/stable/tensor.h>
+   #include <torch/csrc/inductor/aoti_torch/c/shim.h>
+   #include <cuda.h>
+   #include <cuda_runtime.h>
+
    __global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < numel) result[idx] = a[idx] * b[idx] + c;
    }

-   at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
-     TORCH_CHECK(a.sizes() == b.sizes());
-     TORCH_CHECK(a.dtype() == at::kFloat);
-     TORCH_CHECK(b.dtype() == at::kFloat);
-     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-     at::Tensor a_contig = a.contiguous();
-     at::Tensor b_contig = b.contiguous();
-     at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-     const float* a_ptr = a_contig.data_ptr<float>();
-     const float* b_ptr = b_contig.data_ptr<float>();
-     float* result_ptr = result.data_ptr<float>();
+   torch::stable::Tensor mymuladd_cuda(
+       const torch::stable::Tensor& a,
+       const torch::stable::Tensor& b,
+       double c) {
+     STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+     STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+     STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+     torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+     torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+     torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+     const float* a_ptr = a_contig.const_data_ptr<float>();
+     const float* b_ptr = b_contig.const_data_ptr<float>();
+     float* result_ptr = result.mutable_data_ptr<float>();

      int numel = a_contig.numel();
-     muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
+
+     // For now, we rely on the raw shim API to get the current CUDA stream.
+     // This will be improved in a future release.
+     void* stream_ptr = nullptr;
+     TORCH_ERROR_CODE_CHECK(
+         aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+     cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
+     muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
      return result;
    }

-   TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
-     m.impl("mymuladd", &mymuladd_cuda);
+   STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
+     m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
    }
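
To make the ``TORCH_BOX()`` requirement above more concrete: the stable ABI invokes operators
through a boxed convention, in which arguments arrive in a uniform value representation and a
wrapper unpacks them, calls your typed function, and boxes the result. The sketch below only
illustrates that idea with made-up types; it is not the actual PyTorch machinery:

.. code-block:: cpp

   #include <cstdint>
   #include <variant>
   #include <vector>

   // Hypothetical uniform value type; the real stable ABI has its own representation.
   using BoxedValue = std::variant<std::int64_t, double>;

   // A typed "kernel", analogous in spirit to mymuladd_cpu above but on plain scalars.
   double scaled_add(double a, double b, std::int64_t scale) {
     return (a + b) * static_cast<double>(scale);
   }

   // What a TORCH_BOX-style adapter does conceptually: pop typed arguments off a
   // uniform stack, call the typed function, and push the boxed result back.
   void scaled_add_boxed(std::vector<BoxedValue>& stack) {
     double a = std::get<double>(stack[0]);
     double b = std::get<double>(stack[1]);
     std::int64_t scale = std::get<std::int64_t>(stack[2]);
     stack.clear();
     stack.push_back(BoxedValue{scaled_add(a, b, scale)});
   }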

+Reverting to the Non-Stable LibTorch API
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The LibTorch Stable ABI/API is still under active development, and certain APIs may not
+yet be available in ``torch/csrc/stable``, ``torch/headeronly``, or the C shims
+(``torch/csrc/stable/c/shim.h``).
+
+If you need an API that is not yet available in the stable ABI/API, you can revert to
+the regular ATen API by:
+
+1. Removing ``-DTORCH_TARGET_VERSION`` from your ``extra_compile_args``
+2. Using ``TORCH_LIBRARY`` instead of ``STABLE_TORCH_LIBRARY``
+3. Using ``TORCH_LIBRARY_IMPL`` instead of ``STABLE_TORCH_LIBRARY_IMPL``
+4. Reverting to ATen APIs (e.g. using ``at::Tensor`` instead of ``torch::stable::Tensor`` etc.)
+
+Note that doing so means you will need to build separate wheels for each PyTorch
+version you want to support.
+
+For reference, see the `PyTorch 2.9.1 version of this tutorial <https://github.com/pytorch/tutorials/blob/10eefc3b761a5b5407862b2336493b7ab859640f/advanced_source/cpp_custom_ops.rst>`_
+which uses the non-stable API, as well as `this commit of the extension-cpp repository <https://github.com/pytorch/extension-cpp/tree/0ec4969c7bc8e15a8456e5eb9d9ca0a7ec15bc95>`_.
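
For concreteness, reverting ``mymuladd`` in this way mirrors the lines removed elsewhere in this
diff: the ``at::Tensor``-based function is registered directly, without ``TORCH_BOX``. A brief
sketch:

.. code-block:: cpp

   #include <ATen/ATen.h>
   #include <torch/library.h>

   // Non-stable (regular ATen) registration; wheels must be rebuilt per PyTorch version.
   at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c);  // as defined before this change

   TORCH_LIBRARY(extension_cpp, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
   }

   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
     m.impl("mymuladd", &mymuladd_cpu);
   }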

 Adding ``torch.compile`` support for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -327,7 +414,7 @@ three ways:
 for more details:

 .. code-block:: cpp
-
+
    #include <Python.h>

    extern "C" {
@@ -531,21 +618,27 @@ Let's author a ``myadd_out(a, b, out)`` operator, which writes the contents of `
 .. code-block:: cpp

    // An example of an operator that mutates one of its inputs.
-   void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
-     TORCH_CHECK(a.sizes() == b.sizes());
-     TORCH_CHECK(b.sizes() == out.sizes());
-     TORCH_CHECK(a.dtype() == at::kFloat);
-     TORCH_CHECK(b.dtype() == at::kFloat);
-     TORCH_CHECK(out.dtype() == at::kFloat);
-     TORCH_CHECK(out.is_contiguous());
-     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-     TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
-     at::Tensor a_contig = a.contiguous();
-     at::Tensor b_contig = b.contiguous();
-     const float* a_ptr = a_contig.data_ptr<float>();
-     const float* b_ptr = b_contig.data_ptr<float>();
-     float* result_ptr = out.data_ptr<float>();
+   void myadd_out_cpu(
+       const torch::stable::Tensor& a,
+       const torch::stable::Tensor& b,
+       torch::stable::Tensor& out) {
+     STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+     STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+     STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(out.is_contiguous());
+     STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+     STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+     STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CPU);
+
+     torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+     torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+     const float* a_ptr = a_contig.const_data_ptr<float>();
+     const float* b_ptr = b_contig.const_data_ptr<float>();
+     float* result_ptr = out.mutable_data_ptr<float>();

      for (int64_t i = 0; i < out.numel(); i++) {
        result_ptr[i] = a_ptr[i] + b_ptr[i];
      }
@@ -555,18 +648,18 @@ When defining the operator, we must specify that it mutates the out Tensor in th

 .. code-block:: cpp

-   TORCH_LIBRARY(extension_cpp, m) {
+   STABLE_TORCH_LIBRARY(extension_cpp, m) {
      m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
      m.def("mymul(Tensor a, Tensor b) -> Tensor");
      // New!
      m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
    }

-   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-     m.impl("mymuladd", &mymuladd_cpu);
-     m.impl("mymul", &mymul_cpu);
+   STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+     m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
+     m.impl("mymul", TORCH_BOX(&mymul_cpu));
      // New!
-     m.impl("myadd_out", &myadd_out_cpu);
+     m.impl("myadd_out", TORCH_BOX(&myadd_out_cpu));
    }

 .. note::
@@ -577,6 +670,6 @@ When defining the operator, we must specify that it mutates the out Tensor in th
 Conclusion
 ----------
 In this tutorial, we went over the recommended approach to integrating Custom C++
-and CUDA operators with PyTorch. The ``TORCH_LIBRARY/torch.library`` APIs are fairly
+and CUDA operators with PyTorch. The ``STABLE_TORCH_LIBRARY/torch.library`` APIs are fairly
 low-level. For more information about how to use the API, see
 `The Custom Operators Manual <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual>`_.
