
Commit 07ef7e7

Migrate extension-cpp to stable API/ABI
1 parent 0ec4969 commit 07ef7e7

7 files changed (+187, -109 lines)

.github/scripts/setup-env.sh

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ pip install --progress-bar=off -r requirements.txt
 echo '::endgroup::'

 echo '::group::Install extension-cpp'
-python setup.py develop
+pip install -e . --no-build-isolation
 echo '::endgroup::'

 echo '::group::Collect environment information'

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ jobs:
       - python-version: 3.13
         runner: linux.g5.4xlarge.nvidia.gpu
         gpu-arch-type: cuda
-        gpu-arch-version: "12.4"
+        gpu-arch-version: "12.9"
     fail-fast: false
   uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
   permissions:

README.md

Lines changed: 10 additions & 5 deletions
@@ -1,11 +1,16 @@
-# C++/CUDA Extensions in PyTorch
+# C++/CUDA Extensions in PyTorch with LibTorch Stable ABI
+
+An example of writing a C++/CUDA extension for PyTorch using the [LibTorch Stable ABI](https://pytorch.org/docs/main/notes/libtorch_stable_abi.html).
+See [here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.

-An example of writing a C++/CUDA extension for PyTorch. See
-[here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.
 This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
-custom op that has both custom CPU and CUDA kernels.
+custom op that has both custom CPU and CUDA kernels, it leverages the LibTorch
+Stable ABI to ensure that the extension built can be run with any version of
+PyTorch >= 2.10.0.
+
+The examples in this repo work with PyTorch 2.10+. For an example of how to use
+the non-stable subset of LibTorch, see [this previous commit](https://github.com/pytorch/extension-cpp/tree/0ec4969c7bc8e15a8456e5eb9d9ca0a7ec15bc95).

-The examples in this repo work with PyTorch 2.4+.

 To build:
 ```
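
The README above refers to the `extension_cpp.ops.mymuladd` custom op. As a minimal usage sketch (not part of this commit), assuming the extension has been built and installed into the current environment, the op can be exercised from Python roughly like this:

```python
# Illustrative only: assumes the extension was installed, e.g. with
# `pip install -e . --no-build-isolation` as in the CI change above.
import torch
import extension_cpp

a = torch.randn(8)
b = torch.randn(8)

# mymuladd computes a * b + c using the custom CPU/CUDA kernels.
out = extension_cpp.ops.mymuladd(a, b, 2.0)
print(torch.allclose(out, a * b + 2.0))  # expected: True
```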

extension_cpp/csrc/cuda/muladd.cu

Lines changed: 92 additions & 51 deletions
@@ -1,10 +1,18 @@
-#include <ATen/Operators.h>
-#include <torch/all.h>
-#include <torch/library.h>
+// LibTorch Stable ABI version of CUDA custom operators
+// This file uses the stable API for cross-version compatibility.
+// See: https://pytorch.org/docs/main/notes/libtorch_stable_abi.html
+
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
+#include <torch/csrc/stable/accelerator.h>
+#include <torch/headeronly/core/ScalarType.h>
+#include <torch/headeronly/macros/Macros.h>
+
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>

 #include <cuda.h>
 #include <cuda_runtime.h>
-#include <ATen/cuda/CUDAContext.h>

 namespace extension_cpp {

@@ -13,21 +21,33 @@ __global__ void muladd_kernel(int numel, const float* a, const float* b, float c
   if (idx < numel) result[idx] = a[idx] * b[idx] + c;
 }

-at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymuladd_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    double c) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();

   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
   return result;
 }
@@ -37,20 +57,30 @@ __global__ void mul_kernel(int numel, const float* a, const float* b, float* res
   if (idx < numel) result[idx] = a[idx] * b[idx];
 }

-at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymul_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
   return result;
 }
@@ -60,32 +90,43 @@ __global__ void add_kernel(int numel, const float* a, const float* b, float* res
   if (idx < numel) result[idx] = a[idx] + b[idx];
 }

-void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(b.sizes() == out.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_CHECK(out.dtype() == at::kFloat);
-  TORCH_CHECK(out.is_contiguous());
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = out.data_ptr<float>();
+// An example of an operator that mutates one of its inputs.
+void myadd_out_cuda(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    torch::stable::Tensor& out) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.is_contiguous());
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = out.mutable_data_ptr<float>();
+
   int numel = a_contig.numel();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  void* stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
   add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
 }

-
 // Registers CUDA implementations for mymuladd, mymul, myadd_out
-TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
-  m.impl("mymuladd", &mymuladd_cuda);
-  m.impl("mymul", &mymul_cuda);
-  m.impl("myadd_out", &myadd_out_cuda);
+STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
+  m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
+  m.impl("mymul", TORCH_BOX(&mymul_cuda));
+  m.impl("myadd_out", TORCH_BOX(&myadd_out_cuda));
 }

 }
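
The CUDA kernels above are registered for the existing `extension_cpp` schema via `STABLE_TORCH_LIBRARY_IMPL`, so dispatch still selects the CPU or CUDA implementation based on the input device. As a rough sketch of the call path from Python (not part of the commit; assumes a CUDA build of the extension):

```python
# Sketch: the boxed CUDA kernels registered above are reached through the
# regular dispatcher when the inputs live on a CUDA device.
import torch
import extension_cpp  # loads the .so and runs the STABLE_TORCH_LIBRARY initializers

a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")

out = torch.ops.extension_cpp.mymuladd(a, b, 3.0)  # dispatches to mymuladd_cuda
assert out.device.type == "cuda"
assert torch.allclose(out, a * b + 3.0)
```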

extension_cpp/csrc/muladd.cpp

Lines changed: 71 additions & 49 deletions
@@ -1,14 +1,19 @@
+// LibTorch Stable ABI version of custom operators
+// This file uses the stable API for cross-version compatibility.
+// See: https://pytorch.org/docs/main/notes/libtorch_stable_abi.html
+
 #include <Python.h>
-#include <ATen/Operators.h>
-#include <torch/all.h>
-#include <torch/library.h>

-#include <vector>
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
+#include <torch/headeronly/core/ScalarType.h>
+#include <torch/headeronly/macros/Macros.h>

 extern "C" {
   /* Creates a dummy empty _C module that can be imported from Python.
      The import from Python will load the .so consisting of this file
-     in this extension, so that the TORCH_LIBRARY static initializers
+     in this extension, so that the STABLE_TORCH_LIBRARY static initializers
      below are run. */
   PyObject* PyInit__C(void)
   {
@@ -26,75 +31,92 @@ extern "C" {

 namespace extension_cpp {

-at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymuladd_cpu(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    double c) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
   for (int64_t i = 0; i < result.numel(); i++) {
     result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
   }
   return result;
 }

-at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = result.data_ptr<float>();
+torch::stable::Tensor mymul_cpu(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
   for (int64_t i = 0; i < result.numel(); i++) {
     result_ptr[i] = a_ptr[i] * b_ptr[i];
   }
   return result;
 }

 // An example of an operator that mutates one of its inputs.
-void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
-  TORCH_CHECK(a.sizes() == b.sizes());
-  TORCH_CHECK(b.sizes() == out.sizes());
-  TORCH_CHECK(a.dtype() == at::kFloat);
-  TORCH_CHECK(b.dtype() == at::kFloat);
-  TORCH_CHECK(out.dtype() == at::kFloat);
-  TORCH_CHECK(out.is_contiguous());
-  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
-  at::Tensor a_contig = a.contiguous();
-  at::Tensor b_contig = b.contiguous();
-  const float* a_ptr = a_contig.data_ptr<float>();
-  const float* b_ptr = b_contig.data_ptr<float>();
-  float* result_ptr = out.data_ptr<float>();
+void myadd_out_cpu(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    torch::stable::Tensor& out) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.is_contiguous());
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CPU);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = out.mutable_data_ptr<float>();
+
   for (int64_t i = 0; i < out.numel(); i++) {
     result_ptr[i] = a_ptr[i] + b_ptr[i];
   }
 }

 // Defines the operators
-TORCH_LIBRARY(extension_cpp, m) {
+STABLE_TORCH_LIBRARY(extension_cpp, m) {
   m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
   m.def("mymul(Tensor a, Tensor b) -> Tensor");
   m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
 }

 // Registers CPU implementations for mymuladd, mymul, myadd_out
-TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-  m.impl("mymuladd", &mymuladd_cpu);
-  m.impl("mymul", &mymul_cpu);
-  m.impl("myadd_out", &myadd_out_cpu);
+STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+  m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
+  m.impl("mymul", TORCH_BOX(&mymul_cpu));
+  m.impl("myadd_out", TORCH_BOX(&myadd_out_cpu));
 }

 }
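
Since the schema definitions (`m.def(...)`) are unchanged by the migration, existing Python-level tests should continue to apply. As an illustrative sketch (not part of this commit), one way to sanity-check the registered ops is `torch.library.opcheck`:

```python
# Sketch: validate the schema and dispatcher behavior of the registered op.
# Assumes the extension is importable; device can be "cpu" or "cuda".
import torch
from torch.library import opcheck
import extension_cpp  # noqa: F401  (importing registers the ops)

def sample_args(device="cpu"):
    a = torch.randn(8, device=device)
    b = torch.randn(8, device=device)
    return (a, b, 1.5)

opcheck(torch.ops.extension_cpp.mymuladd.default, sample_args("cpu"))
```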

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [build-system]
 requires = [
     "setuptools",
-    "torch",
+    "torch>=2.10.0",
 ]
 build-backend = "setuptools.build_meta"
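
The `torch>=2.10.0` build requirement matches the README's new version floor. The commit's seventh changed file is not shown above; purely as an illustration of how such a build requirement is typically consumed, a minimal `setup.py` for a CPU-only build might look like the sketch below. The options here are assumptions for illustration, not this repo's actual configuration.

```python
# Hypothetical minimal setup.py sketch; not the file changed in this commit.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="extension_cpp",
    ext_modules=[
        CppExtension(
            name="extension_cpp._C",
            sources=["extension_cpp/csrc/muladd.cpp"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
```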

0 commit comments
