1- #include <ATen/Operators.h>
2- #include <torch/all.h>
3- #include <torch/library.h>
1+ // LibTorch Stable ABI version of CUDA custom operators
2+ // This file uses the stable API for cross-version compatibility.
3+ // See: https://pytorch.org/docs/main/notes/libtorch_stable_abi.html
4+
5+ #include <torch/csrc/stable/library.h>
6+ #include <torch/csrc/stable/ops.h>
7+ #include <torch/csrc/stable/tensor.h>
8+ #include <torch/csrc/stable/accelerator.h>
9+ #include <torch/headeronly/core/ScalarType.h>
10+ #include <torch/headeronly/macros/Macros.h>
11+
12+ #include <torch/csrc/inductor/aoti_torch/c/shim.h>
413
514#include <cuda.h>
615#include <cuda_runtime.h>
7- #include <ATen/cuda/CUDAContext.h>
816
917namespace extension_cpp {
1018
@@ -13,21 +21,33 @@ __global__ void muladd_kernel(int numel, const float* a, const float* b, float c
1321 if (idx < numel) result[idx] = a[idx] * b[idx] + c;
1422}
1523
16- at::Tensor mymuladd_cuda (const at::Tensor& a, const at::Tensor& b, double c) {
17- TORCH_CHECK (a.sizes () == b.sizes ());
18- TORCH_CHECK (a.dtype () == at::kFloat );
19- TORCH_CHECK (b.dtype () == at::kFloat );
20- TORCH_INTERNAL_ASSERT (a.device ().type () == at::DeviceType::CUDA);
21- TORCH_INTERNAL_ASSERT (b.device ().type () == at::DeviceType::CUDA);
22- at::Tensor a_contig = a.contiguous ();
23- at::Tensor b_contig = b.contiguous ();
24- at::Tensor result = at::empty (a_contig.sizes (), a_contig.options ());
25- const float * a_ptr = a_contig.data_ptr <float >();
26- const float * b_ptr = b_contig.data_ptr <float >();
27- float * result_ptr = result.data_ptr <float >();
24+ torch::stable::Tensor mymuladd_cuda (
25+ const torch::stable::Tensor& a,
26+ const torch::stable::Tensor& b,
27+ double c) {
28+ STD_TORCH_CHECK (a.sizes ().equals (b.sizes ()));
29+ STD_TORCH_CHECK (a.scalar_type () == torch::headeronly::ScalarType::Float);
30+ STD_TORCH_CHECK (b.scalar_type () == torch::headeronly::ScalarType::Float);
31+ STD_TORCH_CHECK (a.device ().type () == torch::headeronly::DeviceType::CUDA);
32+ STD_TORCH_CHECK (b.device ().type () == torch::headeronly::DeviceType::CUDA);
33+
34+ torch::stable::Tensor a_contig = torch::stable::contiguous (a);
35+ torch::stable::Tensor b_contig = torch::stable::contiguous (b);
36+ torch::stable::Tensor result = torch::stable::empty_like (a_contig);
37+
38+ const float * a_ptr = a_contig.const_data_ptr <float >();
39+ const float * b_ptr = b_contig.const_data_ptr <float >();
40+ float * result_ptr = result.mutable_data_ptr <float >();
2841
2942 int numel = a_contig.numel ();
30- cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
43+
44+ // For now, we rely on the raw shim API to get the current CUDA stream.
45+ // This will be improved in a future release.
46+ void * stream_ptr = nullptr ;
47+ TORCH_ERROR_CODE_CHECK (
48+ aoti_torch_get_current_cuda_stream (a.get_device_index (), &stream_ptr));
49+ cudaStream_t stream = static_cast <cudaStream_t>(stream_ptr);
50+
3151 muladd_kernel<<<(numel+255 )/256 , 256 , 0 , stream>>> (numel, a_ptr, b_ptr, c, result_ptr);
3252 return result;
3353}
@@ -37,20 +57,30 @@ __global__ void mul_kernel(int numel, const float* a, const float* b, float* res
3757 if (idx < numel) result[idx] = a[idx] * b[idx];
3858}
3959
40- at::Tensor mymul_cuda (const at::Tensor& a, const at::Tensor& b) {
41- TORCH_CHECK (a.sizes () == b.sizes ());
42- TORCH_CHECK (a.dtype () == at::kFloat );
43- TORCH_CHECK (b.dtype () == at::kFloat );
44- TORCH_INTERNAL_ASSERT (a.device ().type () == at::DeviceType::CUDA);
45- TORCH_INTERNAL_ASSERT (b.device ().type () == at::DeviceType::CUDA);
46- at::Tensor a_contig = a.contiguous ();
47- at::Tensor b_contig = b.contiguous ();
48- at::Tensor result = at::empty (a_contig.sizes (), a_contig.options ());
49- const float * a_ptr = a_contig.data_ptr <float >();
50- const float * b_ptr = b_contig.data_ptr <float >();
51- float * result_ptr = result.data_ptr <float >();
60+ torch::stable::Tensor mymul_cuda (
61+ const torch::stable::Tensor& a,
62+ const torch::stable::Tensor& b) {
63+ STD_TORCH_CHECK (a.sizes ().equals (b.sizes ()));
64+ STD_TORCH_CHECK (a.scalar_type () == torch::headeronly::ScalarType::Float);
65+ STD_TORCH_CHECK (b.scalar_type () == torch::headeronly::ScalarType::Float);
66+ STD_TORCH_CHECK (a.device ().type () == torch::headeronly::DeviceType::CUDA);
67+ STD_TORCH_CHECK (b.device ().type () == torch::headeronly::DeviceType::CUDA);
68+
69+ torch::stable::Tensor a_contig = torch::stable::contiguous (a);
70+ torch::stable::Tensor b_contig = torch::stable::contiguous (b);
71+ torch::stable::Tensor result = torch::stable::empty_like (a_contig);
72+
73+ const float * a_ptr = a_contig.const_data_ptr <float >();
74+ const float * b_ptr = b_contig.const_data_ptr <float >();
75+ float * result_ptr = result.mutable_data_ptr <float >();
76+
5277 int numel = a_contig.numel ();
53- cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
78+
79+ void * stream_ptr = nullptr ;
80+ TORCH_ERROR_CODE_CHECK (
81+ aoti_torch_get_current_cuda_stream (a.get_device_index (), &stream_ptr));
82+ cudaStream_t stream = static_cast <cudaStream_t>(stream_ptr);
83+
5484 mul_kernel<<<(numel+255 )/256 , 256 , 0 , stream>>> (numel, a_ptr, b_ptr, result_ptr);
5585 return result;
5686}
@@ -60,32 +90,43 @@ __global__ void add_kernel(int numel, const float* a, const float* b, float* res
6090 if (idx < numel) result[idx] = a[idx] + b[idx];
6191}
6292
63- void myadd_out_cuda (const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
64- TORCH_CHECK (a.sizes () == b.sizes ());
65- TORCH_CHECK (b.sizes () == out.sizes ());
66- TORCH_CHECK (a.dtype () == at::kFloat );
67- TORCH_CHECK (b.dtype () == at::kFloat );
68- TORCH_CHECK (out.dtype () == at::kFloat );
69- TORCH_CHECK (out.is_contiguous ());
70- TORCH_INTERNAL_ASSERT (a.device ().type () == at::DeviceType::CUDA);
71- TORCH_INTERNAL_ASSERT (b.device ().type () == at::DeviceType::CUDA);
72- TORCH_INTERNAL_ASSERT (out.device ().type () == at::DeviceType::CUDA);
73- at::Tensor a_contig = a.contiguous ();
74- at::Tensor b_contig = b.contiguous ();
75- const float * a_ptr = a_contig.data_ptr <float >();
76- const float * b_ptr = b_contig.data_ptr <float >();
77- float * result_ptr = out.data_ptr <float >();
93+ // An example of an operator that mutates one of its inputs.
94+ void myadd_out_cuda (
95+ const torch::stable::Tensor& a,
96+ const torch::stable::Tensor& b,
97+ torch::stable::Tensor& out) {
98+ STD_TORCH_CHECK (a.sizes ().equals (b.sizes ()));
99+ STD_TORCH_CHECK (b.sizes ().equals (out.sizes ()));
100+ STD_TORCH_CHECK (a.scalar_type () == torch::headeronly::ScalarType::Float);
101+ STD_TORCH_CHECK (b.scalar_type () == torch::headeronly::ScalarType::Float);
102+ STD_TORCH_CHECK (out.scalar_type () == torch::headeronly::ScalarType::Float);
103+ STD_TORCH_CHECK (out.is_contiguous ());
104+ STD_TORCH_CHECK (a.device ().type () == torch::headeronly::DeviceType::CUDA);
105+ STD_TORCH_CHECK (b.device ().type () == torch::headeronly::DeviceType::CUDA);
106+ STD_TORCH_CHECK (out.device ().type () == torch::headeronly::DeviceType::CUDA);
107+
108+ torch::stable::Tensor a_contig = torch::stable::contiguous (a);
109+ torch::stable::Tensor b_contig = torch::stable::contiguous (b);
110+
111+ const float * a_ptr = a_contig.const_data_ptr <float >();
112+ const float * b_ptr = b_contig.const_data_ptr <float >();
113+ float * result_ptr = out.mutable_data_ptr <float >();
114+
78115 int numel = a_contig.numel ();
79- cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
116+
117+ void * stream_ptr = nullptr ;
118+ TORCH_ERROR_CODE_CHECK (
119+ aoti_torch_get_current_cuda_stream (a.get_device_index (), &stream_ptr));
120+ cudaStream_t stream = static_cast <cudaStream_t>(stream_ptr);
121+
80122 add_kernel<<<(numel+255 )/256 , 256 , 0 , stream>>> (numel, a_ptr, b_ptr, result_ptr);
81123}
82124
83-
84125// Registers CUDA implementations for mymuladd, mymul, myadd_out
85- TORCH_LIBRARY_IMPL (extension_cpp, CUDA, m) {
86- m.impl("mymuladd", &mymuladd_cuda);
87- m.impl("mymul", &mymul_cuda);
88- m.impl("myadd_out", &myadd_out_cuda);
126+ STABLE_TORCH_LIBRARY_IMPL (extension_cpp, CUDA, m) {
127+ m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
128+ m.impl("mymul", TORCH_BOX(&mymul_cuda));
129+ m.impl("myadd_out", TORCH_BOX(&myadd_out_cuda));
89130}
90131
91132}
0 commit comments