diff --git a/python/cuda_cccl/cuda/cccl/cooperative/experimental/_types.py b/python/cuda_cccl/cuda/cccl/cooperative/experimental/_types.py index 76d031d368a..9c372c04a97 100644 --- a/python/cuda_cccl/cuda/cccl/cooperative/experimental/_types.py +++ b/python/cuda_cccl/cuda/cccl/cooperative/experimental/_types.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import re -from functools import cached_property from io import StringIO from textwrap import dedent from types import FunctionType as PyFunctionType @@ -15,6 +14,8 @@ from numba.core import cgutils from numba.core.extending import intrinsic, overload from numba.core.typing import signature +from numba.cuda import LTOIR +from numba.cuda.cudadrv import driver as cuda_driver import cuda.cccl.cooperative.experimental._nvrtc as nvrtc from cuda.cccl.cooperative.experimental._common import find_unsigned @@ -553,6 +554,8 @@ def __init__( self.parameters = parameters self.type_definitions = type_definitions self.fake_return = fake_return + self._temp_storage_bytes = None + self._temp_storage_alignment = None def __repr__(self) -> str: return f"{self.struct_name}::{self.method_name}{self.template_parameters}: {self.parameters}" @@ -604,43 +607,22 @@ def specialize(self, template_arguments): fake_return=self.fake_return, ) - @cached_property - def _temp_storage_bytes_and_alignment(self): - algorithm_name = self.struct_name - includes = self.includes or [] - type_definitions = self.type_definitions or [] - - buf = StringIO() - w = buf.write - - w("#include \n") - for include in includes: - w(f"#include <{include}>\n") - for type_definition in type_definitions: - w(f"{type_definition.code}\n") - - w(f"using algorithm_t = cub::{algorithm_name};\n") - w("using temp_storage_t = typename algorithm_t::TempStorage;\n") - prefix = "__device__ constexpr unsigned temp_storage_" - w(f"{prefix}bytes = sizeof(temp_storage_t);\n") - w(f"{prefix}alignment = alignof(temp_storage_t);\n") - - src = buf.getvalue() - device = cuda.get_current_device() - cc_major, cc_minor = device.compute_capability - cc = cc_major * 10 + cc_minor - _, ptx = nvrtc.compile(cpp=src, cc=cc, rdc=True, code="ptx") - temp_storage_bytes = find_unsigned("temp_storage_bytes", ptx) - temp_storage_alignment = find_unsigned("temp_storage_alignment", ptx) - return (temp_storage_bytes, temp_storage_alignment) - @property def temp_storage_bytes(self): - return self._temp_storage_bytes_and_alignment[0] + if self._temp_storage_bytes is None: + raise RuntimeError( + "Temporary storage bytes not computed yet. Call get_lto_ir() first." + ) + return self._temp_storage_bytes @property def temp_storage_alignment(self): - return self._temp_storage_bytes_and_alignment[1] + if self._temp_storage_alignment is None: + raise RuntimeError( + "Temporary storage alignment not computed yet. " + "Call get_lto_ir() first." + ) + return self._temp_storage_alignment def get_lto_ir(self, threads=None): lto_irs = [] @@ -680,6 +662,9 @@ def get_lto_ir(self, threads=None): w(f"using algorithm_t = cub::{algorithm_name};\n") w("using temp_storage_t = typename algorithm_t::TempStorage;\n") + prefix = "__device__ constexpr unsigned temp_storage_" + w(f"{prefix}bytes = sizeof(temp_storage_t);\n") + w(f"{prefix}alignment = alignof(temp_storage_t);\n") src = buf.getvalue() @@ -775,8 +760,24 @@ def get_lto_ir(self, threads=None): cc = cc_major * 10 + cc_minor # N.B. Uncomment this to immediately print generated source to stdout. # print(src) - _, lto_fn = nvrtc.compile(cpp=src, cc=cc, rdc=True, code="lto") - lto_irs.append(lto_fn) + _, blob = nvrtc.compile(cpp=src, cc=cc, rdc=True, code="lto") + lto_irs.append(blob) + + # Convert the LTO into PTX in order to extract the size and alignment + # variables. + obj = LTOIR(name=self.c_name, data=blob) + linker = cuda_driver._Linker.new( + cc=device.compute_capability, + additional_flags=["-ptx"], + lto=obj, + ) + ltoir_bytes = obj.data + linker.add_ltoir(ltoir_bytes) + ptx = linker.get_linked_ptx() + ptx = ptx.decode("utf-8") + self._temp_storage_bytes = find_unsigned("temp_storage_bytes", ptx) + self._temp_storage_alignment = find_unsigned("temp_storage_alignment", ptx) + return lto_irs def codegen(self, func_to_overload):