diff --git a/advanced_source/cpp_custom_ops.rst b/advanced_source/cpp_custom_ops.rst
index 512c39b2a68..604f0f00efb 100644
--- a/advanced_source/cpp_custom_ops.rst
+++ b/advanced_source/cpp_custom_ops.rst
@@ -16,7 +16,7 @@ Custom C++ and CUDA Operators
    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites
 
-      * PyTorch 2.4 or later
+      * PyTorch 2.10 or later
       * Basic understanding of C++ and CUDA programming
 
 .. note::
@@ -62,12 +62,19 @@ Using ``cpp_extension`` is as simple as writing the following ``setup.py``:
 
    setup(name="extension_cpp",
          ext_modules=[
-           cpp_extension.CppExtension(
+           cpp_extension.CppExtension(
              "extension_cpp",
              ["muladd.cpp"],
-             # define Py_LIMITED_API with min version 3.9 to expose only the stable
-             # limited API subset from Python.h
-             extra_compile_args={"cxx": ["-DPy_LIMITED_API=0x03090000"]},
+             extra_compile_args={
+               "cxx": [
+                 # define Py_LIMITED_API with min version 3.9 to expose only the stable
+                 # limited API subset from Python.h
+                 "-DPy_LIMITED_API=0x03090000",
+                 # define TORCH_TARGET_VERSION with min version 2.10 to expose only the
+                 # stable API subset from torch
+                 "-DTORCH_TARGET_VERSION=0x020a000000000000",
+               ]
+             },
              py_limited_api=True)],  # Build 1 wheel across multiple Python versions
          cmdclass={'build_ext': cpp_extension.BuildExtension},
          options={"bdist_wheel": {"py_limited_api": "cp39"}}  # 3.9 is minimum supported Python version
@@ -78,6 +85,9 @@ If you need to compile CUDA code (for example, ``.cu`` files), then instead use
 Please see `extension-cpp `_ for an example for how this is set up.
 
+CPython Agnosticism
+^^^^^^^^^^^^^^^^^^^
+
 The above example represents what we refer to as a CPython agnostic wheel, meaning we are
 building a single wheel that can be run across multiple CPython versions (similar to pure
 Python packages). CPython agnosticism is desirable in minimizing the number of wheels your
@@ -148,25 +158,59 @@ like so:
        cmdclass={'build_ext': cpp_extension.BuildExtension},
    )
 
+LibTorch Stable ABI (PyTorch Agnosticism)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In addition to CPython agnosticism, there is a second axis of wheel compatibility:
+**LibTorch agnosticism**. While CPython agnosticism allows building a single wheel
+that works across multiple Python versions (3.9, 3.10, 3.11, etc.), LibTorch agnosticism
+allows building a single wheel that works across multiple PyTorch versions (2.10, 2.11, 2.12, etc.).
+These two concepts are orthogonal and can be combined.
+
+To achieve LibTorch agnosticism, you must use the **LibTorch Stable ABI**, which provides
+a stable C API for interacting with PyTorch tensors and operators. For example, instead of
+using ``at::Tensor``, you must use ``torch::stable::Tensor``. For comprehensive
+documentation on the stable ABI, including migration guides, supported types, and
+stack-based API conventions, see the
+`LibTorch Stable ABI documentation `_.
+
+The ``setup.py`` above already includes ``TORCH_TARGET_VERSION=0x020a000000000000``, which indicates
+that the extension targets the LibTorch Stable ABI with a minimum supported PyTorch version of 2.10.
+The version format is ``[MAJ 1 byte][MIN 1 byte][PATCH 1 byte][ABI TAG 5 bytes]``,
+so 2.10.0 = ``0x020a000000000000``.
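+
+If you need to target a different minimum PyTorch version, the packing can be reproduced with
+a few lines of Python. The helper below is only an illustrative sketch; it is not part of
+``torch.utils.cpp_extension``:
+
+.. code-block:: python
+
+   # Pack [MAJ 1 byte][MIN 1 byte][PATCH 1 byte][ABI TAG 5 bytes] into a hex literal.
+   def torch_target_version(major, minor, patch=0, abi_tag=0):
+       packed = (major << 56) | (minor << 48) | (patch << 40) | abi_tag
+       return f"0x{packed:016x}"
+
+   assert torch_target_version(2, 10) == "0x020a000000000000"
+   # e.g. pass f"-DTORCH_TARGET_VERSION={torch_target_version(2, 10)}" in extra_compile_args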
+
+See the section below for examples of code using the LibTorch Stable ABI.
+
 
 Defining the custom op and adding backend implementations
 ----------------------------------------------------------
 
-First, let's write a C++ function that computes ``mymuladd``:
+First, let's write a C++ function that computes ``mymuladd`` using the LibTorch Stable ABI:
 
 .. code-block:: cpp
 
-  at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) {
-    TORCH_CHECK(a.sizes() == b.sizes());
-    TORCH_CHECK(a.dtype() == at::kFloat);
-    TORCH_CHECK(b.dtype() == at::kFloat);
-    TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-    TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-    at::Tensor a_contig = a.contiguous();
-    at::Tensor b_contig = b.contiguous();
-    at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-    const float* a_ptr = a_contig.data_ptr<float>();
-    const float* b_ptr = b_contig.data_ptr<float>();
-    float* result_ptr = result.data_ptr<float>();
+  #include
+  #include
+  #include
+  #include
+  #include
+
+  torch::stable::Tensor mymuladd_cpu(
+      const torch::stable::Tensor& a,
+      const torch::stable::Tensor& b,
+      double c) {
+    STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+    STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+    torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+    torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+    torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+    const float* a_ptr = a_contig.const_data_ptr<float>();
+    const float* b_ptr = b_contig.const_data_ptr<float>();
+    float* result_ptr = result.mutable_data_ptr<float>();
+
     for (int64_t i = 0; i < result.numel(); i++) {
       result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
     }
@@ -174,7 +218,7 @@ First, let's write a C++ function that computes ``mymuladd``:
   }
 
 In order to use this from PyTorch’s Python frontend, we need to register it
-as a PyTorch operator using the ``TORCH_LIBRARY`` API. This will automatically
+as a PyTorch operator using the ``STABLE_TORCH_LIBRARY`` macro. This will automatically
 bind the operator to Python.
 
 Operator registration is a two step-process:
@@ -188,7 +232,7 @@ Defining an operator
 To define an operator, follow these steps:
 
 1. select a namespace for an operator. We recommend the namespace be the name of your top-level
-   project; we’ll use "extension_cpp" in our tutorial.
+   project; we'll use "extension_cpp" in our tutorial.
 2. provide a schema string that specifies the input/output types of the operator and if an input
    Tensors will be mutated. We support more types in addition to Tensor and float; please see
    `The Custom Operators Manual `_
@@ -199,7 +243,7 @@ To define an operator, follow these steps:
 
 .. code-block:: cpp
 
-  TORCH_LIBRARY(extension_cpp, m) {
+  STABLE_TORCH_LIBRARY(extension_cpp, m) {
     // Note that "float" in the schema corresponds to the C++ double type
     // and the Python float type.
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
@@ -209,46 +253,91 @@ This makes the operator available from Python via ``torch.ops.extension_cpp.mymu
 
 Registering backend implementations for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-Use ``TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.
+Use ``STABLE_TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.
+Note that we wrap the function pointer with ``TORCH_BOX()`` - this is required for
+stable ABI functions to handle argument boxing/unboxing correctly.
 
 .. code-block:: cpp
 
-  TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-    m.impl("mymuladd", &mymuladd_cpu);
+  STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+    m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
   }
 
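+At this point the operator can be called from Python like any built-in operator. As a quick
+sanity check (a sketch, assuming the extension has been built and imported as ``extension_cpp``):
+
+.. code-block:: python
+
+   import torch
+   import extension_cpp  # importing the package registers the operators
+
+   a, b = torch.randn(3), torch.randn(3)
+   # "float" in the schema arrives here as a Python float.
+   out = torch.ops.extension_cpp.mymuladd(a, b, 2.0)
+   torch.testing.assert_close(out, a * b + 2.0)
+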
 If you also have a CUDA implementation of ``myaddmul``, you can register it
-in a separate ``TORCH_LIBRARY_IMPL`` block:
+in a separate ``STABLE_TORCH_LIBRARY_IMPL`` block:
 
 .. code-block:: cpp
 
+  #include
+  #include
+  #include
+  #include
+  #include
+  #include
+
   __global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < numel) result[idx] = a[idx] * b[idx] + c;
   }
 
-  at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
-    TORCH_CHECK(a.sizes() == b.sizes());
-    TORCH_CHECK(a.dtype() == at::kFloat);
-    TORCH_CHECK(b.dtype() == at::kFloat);
-    TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-    TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-    at::Tensor a_contig = a.contiguous();
-    at::Tensor b_contig = b.contiguous();
-    at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-    const float* a_ptr = a_contig.data_ptr<float>();
-    const float* b_ptr = b_contig.data_ptr<float>();
-    float* result_ptr = result.data_ptr<float>();
+  torch::stable::Tensor mymuladd_cuda(
+      const torch::stable::Tensor& a,
+      const torch::stable::Tensor& b,
+      double c) {
+    STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+    STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+    torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+    torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+    torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+    const float* a_ptr = a_contig.const_data_ptr<float>();
+    const float* b_ptr = b_contig.const_data_ptr<float>();
+    float* result_ptr = result.mutable_data_ptr<float>();
     int numel = a_contig.numel();
-    muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
+
+    // For now, we rely on the raw shim API to get the current CUDA stream.
+    // This will be improved in a future release.
+    // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+    // check the error code and throw an appropriate runtime_error otherwise.
+    void* stream_ptr = nullptr;
+    TORCH_ERROR_CODE_CHECK(
+        aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+    cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
+    muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
     return result;
   }
 
-  TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
-    m.impl("mymuladd", &mymuladd_cuda);
+  STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
+    m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
   }
 
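+Because the kernel is launched on PyTorch's current CUDA stream (queried through the shim
+above), the operator composes correctly with stream-based code on the Python side. A rough
+smoke test (a sketch, assuming a CUDA build of the extension imported as ``extension_cpp``):
+
+.. code-block:: python
+
+   import torch
+   import extension_cpp  # importing the package registers the operators
+
+   a = torch.randn(1024, device="cuda")
+   b = torch.randn(1024, device="cuda")
+
+   side_stream = torch.cuda.Stream()
+   side_stream.wait_stream(torch.cuda.current_stream())
+   with torch.cuda.stream(side_stream):
+       # The launch uses ``side_stream`` because it is the current stream here.
+       out = torch.ops.extension_cpp.mymuladd(a, b, 2.0)
+   torch.cuda.current_stream().wait_stream(side_stream)
+
+   torch.testing.assert_close(out, a * b + 2.0)
+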
+Reverting to the Non-Stable LibTorch API
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The LibTorch Stable ABI/API is still under active development, and certain APIs may not
+yet be available in ``torch/csrc/stable``, ``torch/headeronly``, or the C shims
+(``torch/csrc/stable/c/shim.h``).
+
+If you need an API that is not yet available in the stable ABI/API, you can revert to
+the regular ATen API by:
+
+1. Removing ``-DTORCH_TARGET_VERSION`` from your ``extra_compile_args``
+2. Using ``TORCH_LIBRARY`` instead of ``STABLE_TORCH_LIBRARY``
+3. Using ``TORCH_LIBRARY_IMPL`` instead of ``STABLE_TORCH_LIBRARY_IMPL``
+4. Reverting to ATen APIs (for example, using ``at::Tensor`` instead of ``torch::stable::Tensor``)
+
+Note that doing so means you will need to build separate wheels for each PyTorch
+version you want to support.
+
+For reference, see the `PyTorch 2.9.1 version of this tutorial `_,
+which uses the non-stable API, as well as `this commit of the extension-cpp repository `_.
+
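+If you expect to switch between the two modes while the stable surface catches up, one option
+is to make the flag conditional at build time, alongside the corresponding source-level changes
+above. This is a sketch; the ``USE_STABLE_ABI`` environment variable is hypothetical and not a
+``cpp_extension`` feature:
+
+.. code-block:: python
+
+   import os
+   from setuptools import setup
+   from torch.utils import cpp_extension
+
+   cxx_flags = ["-DPy_LIMITED_API=0x03090000"]
+   if os.environ.get("USE_STABLE_ABI", "1") == "1":
+       # Only defined when building against the LibTorch Stable ABI.
+       cxx_flags.append("-DTORCH_TARGET_VERSION=0x020a000000000000")
+
+   setup(name="extension_cpp",
+         ext_modules=[
+           cpp_extension.CppExtension(
+             "extension_cpp",
+             ["muladd.cpp"],
+             extra_compile_args={"cxx": cxx_flags},
+             py_limited_api=True)],
+         cmdclass={'build_ext': cpp_extension.BuildExtension},
+         options={"bdist_wheel": {"py_limited_api": "cp39"}})
+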
 Adding ``torch.compile`` support for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -327,7 +416,7 @@ three ways:
    for more details:
 
    .. code-block:: cpp
 
-
+      #include
 
       extern "C" {
@@ -380,8 +469,7 @@ three ways:
 
 Adding training (autograd) support for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 Use ``torch.library.register_autograd`` to add training support for an operator. Prefer
-this over directly using Python ``torch.autograd.Function`` or C++ ``torch::autograd::Function``;
-you must use those in a very specific way to avoid silent incorrectness (see
+this over directly using Python ``torch.autograd.Function`` (see
 `The Custom Operators Manual `_
 for more details).
 
@@ -421,35 +509,40 @@ custom operator and then call that from the backward:
 
 .. code-block:: cpp
 
   // New! a mymul_cpu kernel
-  at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
-    TORCH_CHECK(a.sizes() == b.sizes());
-    TORCH_CHECK(a.dtype() == at::kFloat);
-    TORCH_CHECK(b.dtype() == at::kFloat);
-    TORCH_CHECK(a.device().type() == at::DeviceType::CPU);
-    TORCH_CHECK(b.device().type() == at::DeviceType::CPU);
-    at::Tensor a_contig = a.contiguous();
-    at::Tensor b_contig = b.contiguous();
-    at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-    const float* a_ptr = a_contig.data_ptr<float>();
-    const float* b_ptr = b_contig.data_ptr<float>();
-    float* result_ptr = result.data_ptr<float>();
+  torch::stable::Tensor mymul_cpu(
+      const torch::stable::Tensor& a,
+      const torch::stable::Tensor& b) {
+    STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+    STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+    torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+    torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+    torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+    const float* a_ptr = a_contig.const_data_ptr<float>();
+    const float* b_ptr = b_contig.const_data_ptr<float>();
+    float* result_ptr = result.mutable_data_ptr<float>();
+
     for (int64_t i = 0; i < result.numel(); i++) {
       result_ptr[i] = a_ptr[i] * b_ptr[i];
     }
     return result;
   }
 
-  TORCH_LIBRARY(extension_cpp, m) {
+  STABLE_TORCH_LIBRARY(extension_cpp, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
     // New! defining the mymul operator
     m.def("mymul(Tensor a, Tensor b) -> Tensor");
   }
 
-  TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-    m.impl("mymuladd", &mymuladd_cpu);
+  STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+    m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
     // New! registering the cpu kernel for the mymul operator
-    m.impl("mymul", &mymul_cpu);
+    m.impl("mymul", TORCH_BOX(&mymul_cpu));
   }
 
 .. code-block:: python
@@ -531,21 +624,27 @@ Let's author a ``myadd_out(a, b, out)`` operator, which writes the contents of `
 
 .. code-block:: cpp
 
   // An example of an operator that mutates one of its inputs.
-  void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
-    TORCH_CHECK(a.sizes() == b.sizes());
-    TORCH_CHECK(b.sizes() == out.sizes());
-    TORCH_CHECK(a.dtype() == at::kFloat);
-    TORCH_CHECK(b.dtype() == at::kFloat);
-    TORCH_CHECK(out.dtype() == at::kFloat);
-    TORCH_CHECK(out.is_contiguous());
-    TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-    TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-    TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
-    at::Tensor a_contig = a.contiguous();
-    at::Tensor b_contig = b.contiguous();
-    const float* a_ptr = a_contig.data_ptr<float>();
-    const float* b_ptr = b_contig.data_ptr<float>();
-    float* result_ptr = out.data_ptr<float>();
+  void myadd_out_cpu(
+      const torch::stable::Tensor& a,
+      const torch::stable::Tensor& b,
+      torch::stable::Tensor& out) {
+    STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+    STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(out.is_contiguous());
+    STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+    STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+    STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CPU);
+
+    torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+    torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+    const float* a_ptr = a_contig.const_data_ptr<float>();
+    const float* b_ptr = b_contig.const_data_ptr<float>();
+    float* result_ptr = out.mutable_data_ptr<float>();
+
     for (int64_t i = 0; i < out.numel(); i++) {
       result_ptr[i] = a_ptr[i] + b_ptr[i];
     }
@@ -555,18 +654,18 @@ When defining the operator, we must specify that it mutates the out Tensor in th
 
 .. code-block:: cpp
 
-  TORCH_LIBRARY(extension_cpp, m) {
+  STABLE_TORCH_LIBRARY(extension_cpp, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
     m.def("mymul(Tensor a, Tensor b) -> Tensor");
     // New!
     m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
   }
 
-  TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-    m.impl("mymuladd", &mymuladd_cpu);
-    m.impl("mymul", &mymul_cpu);
+  STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+    m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
+    m.impl("mymul", TORCH_BOX(&mymul_cpu));
     // New!
-    m.impl("myadd_out", &myadd_out_cpu);
+    m.impl("myadd_out", TORCH_BOX(&myadd_out_cpu));
   }
 
 .. note::
@@ -577,6 +676,6 @@ When defining the operator, we must specify that it mutates the out Tensor in th
 Conclusion
 ----------
 In this tutorial, we went over the recommended approach to integrating Custom C++
-and CUDA operators with PyTorch. The ``TORCH_LIBRARY/torch.library`` APIs are fairly
+and CUDA operators with PyTorch. The ``STABLE_TORCH_LIBRARY/torch.library`` APIs are fairly
 low-level. For more information about how to use the API, see
 `The Custom Operators Manual `_.