diff --git a/requirements.txt b/requirements.txt index af3149e..f9c596d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ torch numpy +packaging diff --git a/setup.py b/setup.py index 0dde1e4..367f574 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,7 @@ import glob from setuptools import find_packages, setup +from packaging.version import Version from torch.utils.cpp_extension import ( CppExtension, @@ -18,10 +19,7 @@ library_name = "extension_cpp" -if torch.__version__ >= "2.6.0": - py_limited_api = True -else: - py_limited_api = False +py_limited_api = Version(torch.__version__) >= Version("2.6.0") def get_extensions():