-#include <ATen/Operators.h>
-#include <torch/all.h>
-#include <torch/library.h>
+// LibTorch Stable ABI version of CUDA custom operators
+// This file uses the stable API for cross-version compatibility.
+// See: https://pytorch.org/docs/main/notes/libtorch_stable_abi.html
+
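+// Stable ABI surface: registration macros, the stable Tensor type and ops,
+// and accelerator (device/stream) utilities.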
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
+#include <torch/csrc/stable/accelerator.h>
+#include <torch/headeronly/core/ScalarType.h>
+#include <torch/headeronly/macros/Macros.h>
+
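+// The AOTI C shim declares aoti_torch_get_current_cuda_stream(), used below
+// in place of ATen's getCurrentCUDAStream().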
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>

 #include <cuda.h>
 #include <cuda_runtime.h>
-#include <ATen/cuda/CUDAContext.h>

 namespace extension_cpp {

@@ -13,21 +21,39 @@ __global__ void muladd_kernel(int numel, const float* a, const float* b, float c
   if (idx < numel) result[idx] = a[idx] * b[idx] + c;
 }

-at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymuladd_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    double c) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()), "Tensor sizes must match");
+  STD_TORCH_CHECK(
+      a.scalar_type() == torch::headeronly::ScalarType::Float,
+      "Input tensor a must be float32");
+  STD_TORCH_CHECK(
+      b.scalar_type() == torch::headeronly::ScalarType::Float,
+      "Input tensor b must be float32");
+  STD_TORCH_CHECK(
+      a.device().type() == torch::headeronly::DeviceType::CUDA,
+      "Input tensor a must be on CUDA");
+  STD_TORCH_CHECK(
+      b.device().type() == torch::headeronly::DeviceType::CUDA,
+      "Input tensor b must be on CUDA");
+
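+  // The kernel indexes raw pointers as dense arrays, so work on contiguous
+  // copies of the inputs.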
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();

   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
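+  // Fetch the current CUDA stream for a's device via the C shim; it is
+  // returned as an opaque void*, and any failing error code is surfaced by
+  // TORCH_ERROR_CODE_CHECK.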
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
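+  // One thread per element: ceil(numel / 256) blocks of 256 threads. Note
+  // that the double argument c narrows to the kernel's float parameter.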
   muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
   return result;
 }
@@ -37,20 +63,38 @@ __global__ void mul_kernel(int numel, const float* a, const float* b, float* res
   if (idx < numel) result[idx] = a[idx] * b[idx];
 }

-at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymul_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()), "Tensor sizes must match");
+  STD_TORCH_CHECK(
+      a.scalar_type() == torch::headeronly::ScalarType::Float,
+      "Input tensor a must be float32");
+  STD_TORCH_CHECK(
+      b.scalar_type() == torch::headeronly::ScalarType::Float,
+      "Input tensor b must be float32");
+  STD_TORCH_CHECK(
+      a.device().type() == torch::headeronly::DeviceType::CUDA,
+      "Input tensor a must be on CUDA");
+  STD_TORCH_CHECK(
+      b.device().type() == torch::headeronly::DeviceType::CUDA,
+      "Input tensor b must be on CUDA");
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
   return result;
 }
@@ -60,32 +104,55 @@ __global__ void add_kernel(int numel, const float* a, const float* b, float* res
   if (idx < numel) result[idx] = a[idx] + b[idx];
 }

-void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(b.sizes() == out.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_CHECK(out.dtype() == at::kFloat);
-  TORCH_CHECK(out.is_contiguous());
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = out.data_ptr<float>();
+// An example of an operator that mutates one of its inputs.
+void myadd_out_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    torch::stable::Tensor& out) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()), "Tensor sizes must match");
+  STD_TORCH_CHECK(b.sizes().equals(out.sizes()), "Output tensor size must match inputs");
+  STD_TORCH_CHECK(
+      a.scalar_type() == torch::headeronly::ScalarType::Float,
+      "Input tensor a must be float32");
+  STD_TORCH_CHECK(
+      b.scalar_type() == torch::headeronly::ScalarType::Float,
+      "Input tensor b must be float32");
+  STD_TORCH_CHECK(
+      out.scalar_type() == torch::headeronly::ScalarType::Float,
+      "Output tensor must be float32");
+  STD_TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous");
+  STD_TORCH_CHECK(
+      a.device().type() == torch::headeronly::DeviceType::CUDA,
+      "Input tensor a must be on CUDA");
+  STD_TORCH_CHECK(
+      b.device().type() == torch::headeronly::DeviceType::CUDA,
+      "Input tensor b must be on CUDA");
+  STD_TORCH_CHECK(
+      out.device().type() == torch::headeronly::DeviceType::CUDA,
+      "Output tensor must be on CUDA");
+
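+  // Only the inputs may be copied to contiguous buffers; out is written in
+  // place, which is why it was required to be contiguous above.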
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = out.mutable_data_ptr<float>();
+
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
 }

-
 // Registers CUDA implementations for mymuladd, mymul, myadd_out
-TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
-  m.impl("mymuladd", &mymuladd_cuda);
-  m.impl("mymul", &mymul_cuda);
-  m.impl("myadd_out", &myadd_out_cuda);
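+// TORCH_BOX wraps each kernel so it can be invoked through the stable ABI's
+// boxed calling convention.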
+STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
+  m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
+  m.impl("mymul", TORCH_BOX(&mymul_cuda));
+  m.impl("myadd_out", TORCH_BOX(&myadd_out_cuda));
 }

 }