@@ -16,7 +16,7 @@ Custom C++ and CUDA Operators
     .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
        :class-card: card-prerequisites

-       * PyTorch 2.4 or later
+       * PyTorch 2.10 or later
        * Basic understanding of C++ and CUDA programming

 .. note::
@@ -62,12 +62,19 @@ Using ``cpp_extension`` is as simple as writing the following ``setup.py``:

     setup(name="extension_cpp",
           ext_modules=[
-               cpp_extension.CppExtension(
+              cpp_extension.CppExtension(
                   "extension_cpp",
                   ["muladd.cpp"],
-                  # define Py_LIMITED_API with min version 3.9 to expose only the stable
-                  # limited API subset from Python.h
-                  extra_compile_args={"cxx": ["-DPy_LIMITED_API=0x03090000"]},
+                  extra_compile_args={
+                      "cxx": [
+                          # define Py_LIMITED_API with min version 3.9 to expose only the stable
+                          # limited API subset from Python.h
+                          "-DPy_LIMITED_API=0x03090000",
+                          # define TORCH_TARGET_VERSION with min version 2.10 to expose only the
+                          # stable API subset from torch
+                          "-DTORCH_TARGET_VERSION=0x020a000000000000",
+                      ]
+                  },
                   py_limited_api=True)],  # Build 1 wheel across multiple Python versions
           cmdclass={'build_ext': cpp_extension.BuildExtension},
           options={"bdist_wheel": {"py_limited_api": "cp39"}}  # 3.9 is minimum supported Python version
@@ -78,6 +85,9 @@ If you need to compile CUDA code (for example, ``.cu`` files), then instead use
 Please see `extension-cpp <https://github.com/pytorch/extension-cpp>`_ for an
 example of how this is set up.

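As a rough sketch only (the source file names below are hypothetical; see extension-cpp
for the real layout), a CUDA-enabled ``setup.py`` swaps ``CppExtension`` for
``cpp_extension.CUDAExtension`` and lists the ``.cu`` files alongside the C++ sources;
the ``Py_LIMITED_API`` and ``TORCH_TARGET_VERSION`` flags from above can be carried over
unchanged:

.. code-block:: python

   from setuptools import setup
   from torch.utils import cpp_extension

   setup(name="extension_cpp",
         ext_modules=[
             cpp_extension.CUDAExtension(
                 "extension_cpp",
                 # hypothetical file names; use your project's actual sources
                 ["muladd.cpp", "muladd_kernel.cu"],
             )],
         cmdclass={'build_ext': cpp_extension.BuildExtension})
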
+CPython Agnosticism
+^^^^^^^^^^^^^^^^^^^
+
 The above example represents what we refer to as a CPython agnostic wheel, meaning we are
 building a single wheel that can be run across multiple CPython versions (similar to pure
 Python packages). CPython agnosticism is desirable because it minimizes the number of wheels your
@@ -148,33 +158,67 @@ like so:
           cmdclass={'build_ext': cpp_extension.BuildExtension},
     )

+LibTorch Stable ABI (PyTorch Agnosticism)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In addition to CPython agnosticism, there is a second axis of wheel compatibility:
+**LibTorch agnosticism**. While CPython agnosticism allows building a single wheel
+that works across multiple Python versions (3.9, 3.10, 3.11, etc.), LibTorch agnosticism
+allows building a single wheel that works across multiple PyTorch versions (2.10, 2.11, 2.12, etc.).
+These two concepts are orthogonal and can be combined.
+
+To achieve LibTorch agnosticism, you must use the **LibTorch Stable ABI**, which provides
+a stable C API for interacting with PyTorch tensors and operators. For example, instead of
+using ``at::Tensor``, you must use ``torch::stable::Tensor``. For comprehensive
+documentation on the stable ABI, including migration guides, supported types, and
+stack-based API conventions, see the
+`LibTorch Stable ABI documentation <https://pytorch.org/docs/main/notes/libtorch_stable_abi.html>`_.
+
+The ``setup.py`` above already includes ``TORCH_TARGET_VERSION=0x020a000000000000``, which indicates that
+the extension targets the LibTorch Stable ABI with a minimum supported PyTorch version of 2.10. The version format is
+``[MAJ 1 byte][MIN 1 byte][PATCH 1 byte][ABI TAG 5 bytes]``, so 2.10.0 = ``0x020a000000000000``.
+
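As a quick sanity check on that encoding (a sketch of the arithmetic, not an official
helper), the hex value can be reconstructed from the version components:

.. code-block:: python

   # [MAJ 1 byte][MIN 1 byte][PATCH 1 byte][ABI TAG 5 bytes], most significant byte first
   major, minor, patch = 2, 10, 0
   target = (major << 56) | (minor << 48) | (patch << 40)
   assert target == 0x020a000000000000
   print(hex(target))
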
+See the section below for examples of code using the LibTorch Stable ABI.
+

 Defining the custom op and adding backend implementations
 ----------------------------------------------------------
-First, let's write a C++ function that computes ``mymuladd``:
+First, let's write a C++ function that computes ``mymuladd`` using the LibTorch Stable ABI:

 .. code-block:: cpp

-   at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) {
-     TORCH_CHECK(a.sizes() == b.sizes());
-     TORCH_CHECK(a.dtype() == at::kFloat);
-     TORCH_CHECK(b.dtype() == at::kFloat);
-     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-     at::Tensor a_contig = a.contiguous();
-     at::Tensor b_contig = b.contiguous();
-     at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-     const float* a_ptr = a_contig.data_ptr<float>();
-     const float* b_ptr = b_contig.data_ptr<float>();
-     float* result_ptr = result.data_ptr<float>();
+   #include <torch/csrc/stable/library.h>
+   #include <torch/csrc/stable/ops.h>
+   #include <torch/csrc/stable/tensor.h>
+   #include <torch/headeronly/core/ScalarType.h>
+   #include <torch/headeronly/macros/Macros.h>
+
+   torch::stable::Tensor mymuladd_cpu(
+       const torch::stable::Tensor& a,
+       const torch::stable::Tensor& b,
+       double c) {
+     STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+     STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+     STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+     torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+     torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+     torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+     const float* a_ptr = a_contig.const_data_ptr<float>();
+     const float* b_ptr = b_contig.const_data_ptr<float>();
+     float* result_ptr = result.mutable_data_ptr<float>();
+
     for (int64_t i = 0; i < result.numel(); i++) {
       result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
     }
     return result;
   }

 In order to use this from PyTorch’s Python frontend, we need to register it
-as a PyTorch operator using the ``TORCH_LIBRARY`` API. This will automatically
+as a PyTorch operator using the ``STABLE_TORCH_LIBRARY`` macro. This will automatically
 bind the operator to Python.

 Operator registration is a two-step process:
@@ -188,7 +232,7 @@ Defining an operator
 To define an operator, follow these steps:

 1. select a namespace for an operator. We recommend the namespace be the name of your top-level
-   project; we’ll use "extension_cpp" in our tutorial.
+   project; we'll use "extension_cpp" in our tutorial.
 2. provide a schema string that specifies the input/output types of the operator and whether any
    input Tensors will be mutated. We support more types in addition to Tensor and float;
    please see `The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
@@ -199,7 +243,7 @@ To define an operator, follow these steps:

 .. code-block:: cpp

-   TORCH_LIBRARY(extension_cpp, m) {
+   STABLE_TORCH_LIBRARY(extension_cpp, m) {
      // Note that "float" in the schema corresponds to the C++ double type
      // and the Python float type.
      m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
@@ -209,46 +253,89 @@ This makes the operator available from Python via ``torch.ops.extension_cpp.mymu

 Registering backend implementations for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-Use ``TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.
+Use ``STABLE_TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.
+Note that we wrap the function pointer with ``TORCH_BOX()``; this is required so that
+the stable ABI can box and unbox the arguments correctly.

 .. code-block:: cpp

-   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-     m.impl("mymuladd", &mymuladd_cpu);
+   STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+     m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
    }

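With the schema defined and a CPU kernel registered, the operator is callable from
Python. The snippet below is a minimal sketch; it assumes the compiled extension is
importable as ``extension_cpp`` (as in the extension-cpp example), which is what loads
the shared library and registers the op:

.. code-block:: python

   import torch
   import extension_cpp  # loads the compiled extension and registers its ops

   a = torch.randn(3)
   b = torch.randn(3)
   out = torch.ops.extension_cpp.mymuladd(a, b, 1.0)  # computes a * b + 1.0
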
 If you also have a CUDA implementation of ``mymuladd``, you can register it
-in a separate ``TORCH_LIBRARY_IMPL`` block:
+in a separate ``STABLE_TORCH_LIBRARY_IMPL`` block:

 .. code-block:: cpp

+   #include <torch/csrc/stable/library.h>
+   #include <torch/csrc/stable/ops.h>
+   #include <torch/csrc/stable/tensor.h>
+   #include <torch/csrc/inductor/aoti_torch/c/shim.h>
+   #include <cuda.h>
+   #include <cuda_runtime.h>
+
    __global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < numel) result[idx] = a[idx] * b[idx] + c;
    }

-   at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
-     TORCH_CHECK(a.sizes() == b.sizes());
-     TORCH_CHECK(a.dtype() == at::kFloat);
-     TORCH_CHECK(b.dtype() == at::kFloat);
-     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-     at::Tensor a_contig = a.contiguous();
-     at::Tensor b_contig = b.contiguous();
-     at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-     const float* a_ptr = a_contig.data_ptr<float>();
-     const float* b_ptr = b_contig.data_ptr<float>();
-     float* result_ptr = result.data_ptr<float>();
+   torch::stable::Tensor mymuladd_cuda(
+       const torch::stable::Tensor& a,
+       const torch::stable::Tensor& b,
+       double c) {
+     STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+     STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+     STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+     torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+     torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+     torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+     const float* a_ptr = a_contig.const_data_ptr<float>();
+     const float* b_ptr = b_contig.const_data_ptr<float>();
+     float* result_ptr = result.mutable_data_ptr<float>();

     int numel = a_contig.numel();
-     muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
+
+     // For now, we rely on the raw shim API to get the current CUDA stream.
+     // This will be improved in a future release.
+     void* stream_ptr = nullptr;
+     TORCH_ERROR_CODE_CHECK(
+         aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+     cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
+     muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
     return result;
   }

-   TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
-     m.impl("mymuladd", &mymuladd_cuda);
+   STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
+     m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
    }

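Once both backends are registered, a quick cross-check of the two implementations can
look like the following sketch (again assuming the built extension is importable as
``extension_cpp`` and a CUDA device is available):

.. code-block:: python

   import torch
   import extension_cpp  # loads the compiled extension and registers its ops

   a = torch.randn(1024)
   b = torch.randn(1024)
   cpu_out = torch.ops.extension_cpp.mymuladd(a, b, 2.0)
   cuda_out = torch.ops.extension_cpp.mymuladd(a.cuda(), b.cuda(), 2.0)
   torch.testing.assert_close(cpu_out, cuda_out.cpu())
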
+Reverting to the Non-Stable LibTorch API
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The LibTorch Stable ABI/API is still under active development, and certain APIs may not
+yet be available in ``torch/csrc/stable``, ``torch/headeronly``, or the C shims
+(``torch/csrc/stable/c/shim.h``).
+
+If you need an API that is not yet available in the stable ABI/API, you can revert to
+the regular ATen API by:
+
+1. Removing ``-DTORCH_TARGET_VERSION`` from your ``extra_compile_args``
+2. Using ``TORCH_LIBRARY`` instead of ``STABLE_TORCH_LIBRARY``
+3. Using ``TORCH_LIBRARY_IMPL`` instead of ``STABLE_TORCH_LIBRARY_IMPL``
+4. Reverting to ATen APIs (for example, using ``at::Tensor`` instead of ``torch::stable::Tensor``), as sketched below
+
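For illustration, here is a minimal sketch of what the non-stable registration looks
like; it mirrors the ``at::Tensor``-based code that the stable version above replaced
(the one-line ATen implementation is for brevity only):

.. code-block:: cpp

   #include <ATen/ATen.h>
   #include <torch/library.h>

   at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) {
     TORCH_CHECK(a.sizes() == b.sizes());
     return a * b + c;  // free to call any ATen op here
   }

   TORCH_LIBRARY(extension_cpp, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
   }

   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
     m.impl("mymuladd", &mymuladd_cpu);  // no TORCH_BOX() wrapper needed
   }
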
+Note that doing so means you will need to build separate wheels for each PyTorch
+version you want to support.
+
+For reference, see the `PyTorch 2.9.1 version of this tutorial <https://github.com/pytorch/tutorials/blob/10eefc3b761a5b5407862b2336493b7ab859640f/advanced_source/cpp_custom_ops.rst>`_,
+which uses the non-stable API, as well as `this commit of the extension-cpp repository <https://github.com/pytorch/extension-cpp/tree/0ec4969c7bc8e15a8456e5eb9d9ca0a7ec15bc95>`_.
+
 Adding ``torch.compile`` support for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -327,7 +414,7 @@ three ways:
    for more details:

 .. code-block:: cpp
-
+
    #include <Python.h>

    extern "C" {
@@ -531,21 +618,27 @@ Let's author a ``myadd_out(a, b, out)`` operator, which writes the contents of `
 .. code-block:: cpp

    // An example of an operator that mutates one of its inputs.
-   void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
-     TORCH_CHECK(a.sizes() == b.sizes());
-     TORCH_CHECK(b.sizes() == out.sizes());
-     TORCH_CHECK(a.dtype() == at::kFloat);
-     TORCH_CHECK(b.dtype() == at::kFloat);
-     TORCH_CHECK(out.dtype() == at::kFloat);
-     TORCH_CHECK(out.is_contiguous());
-     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-     TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
-     at::Tensor a_contig = a.contiguous();
-     at::Tensor b_contig = b.contiguous();
-     const float* a_ptr = a_contig.data_ptr<float>();
-     const float* b_ptr = b_contig.data_ptr<float>();
-     float* result_ptr = out.data_ptr<float>();
+   void myadd_out_cpu(
+       const torch::stable::Tensor& a,
+       const torch::stable::Tensor& b,
+       torch::stable::Tensor& out) {
+     STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+     STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+     STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(out.is_contiguous());
+     STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+     STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+     STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CPU);
+
+     torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+     torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+     const float* a_ptr = a_contig.const_data_ptr<float>();
+     const float* b_ptr = b_contig.const_data_ptr<float>();
+     float* result_ptr = out.mutable_data_ptr<float>();
+
     for (int64_t i = 0; i < out.numel(); i++) {
       result_ptr[i] = a_ptr[i] + b_ptr[i];
     }
@@ -555,18 +648,18 @@ When defining the operator, we must specify that it mutates the out Tensor in th

 .. code-block:: cpp

-   TORCH_LIBRARY(extension_cpp, m) {
+   STABLE_TORCH_LIBRARY(extension_cpp, m) {
      m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
      m.def("mymul(Tensor a, Tensor b) -> Tensor");
      // New!
      m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
    }

-   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-     m.impl("mymuladd", &mymuladd_cpu);
-     m.impl("mymul", &mymul_cpu);
+   STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+     m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
+     m.impl("mymul", TORCH_BOX(&mymul_cpu));
      // New!
-     m.impl("myadd_out", &myadd_out_cpu);
+     m.impl("myadd_out", TORCH_BOX(&myadd_out_cpu));
    }

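As a quick illustration of the mutating op from Python (a sketch, again assuming the
compiled extension is importable as ``extension_cpp``), the caller preallocates the
output tensor:

.. code-block:: python

   import torch
   import extension_cpp  # loads the compiled extension and registers its ops

   a = torch.randn(3)
   b = torch.randn(3)
   out = torch.empty_like(a)
   torch.ops.extension_cpp.myadd_out(a, b, out)  # writes a + b into out
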
 .. note::
@@ -577,6 +670,6 @@ When defining the operator, we must specify that it mutates the out Tensor in th
 Conclusion
 ----------
 In this tutorial, we went over the recommended approach to integrating custom C++
-and CUDA operators with PyTorch. The ``TORCH_LIBRARY``/``torch.library`` APIs are fairly
+and CUDA operators with PyTorch. The ``STABLE_TORCH_LIBRARY``/``torch.library`` APIs are fairly
 low-level. For more information about how to use the API, see
 `The Custom Operators Manual <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual>`_.