Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 36 additions & 35 deletions python/cuda_cccl/cuda/cccl/cooperative/experimental/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import re
from functools import cached_property
from io import StringIO
from textwrap import dedent
from types import FunctionType as PyFunctionType
Expand All @@ -15,6 +14,8 @@
from numba.core import cgutils
from numba.core.extending import intrinsic, overload
from numba.core.typing import signature
from numba.cuda import LTOIR
from numba.cuda.cudadrv import driver as cuda_driver

import cuda.cccl.cooperative.experimental._nvrtc as nvrtc
from cuda.cccl.cooperative.experimental._common import find_unsigned
Expand Down Expand Up @@ -553,6 +554,8 @@ def __init__(
self.parameters = parameters
self.type_definitions = type_definitions
self.fake_return = fake_return
self._temp_storage_bytes = None
self._temp_storage_alignment = None

# Debug representation assembled from struct_name, method_name,
# template_parameters and parameters, in the form
# "<struct_name>::<method_name><template_parameters>: <parameters>".
def __repr__(self) -> str:
return f"{self.struct_name}::{self.method_name}{self.template_parameters}: {self.parameters}"
Expand Down Expand Up @@ -604,43 +607,22 @@ def specialize(self, template_arguments):
fake_return=self.fake_return,
)

# Determine the size and alignment of the algorithm's TempStorage type by
# generating a small C++ source, compiling it to PTX with nvrtc, and reading
# two named constants back out of the PTX text.
# Returns a (temp_storage_bytes, temp_storage_alignment) tuple. Because this is
# a cached_property, the compile runs at most once per instance.
@cached_property
def _temp_storage_bytes_and_alignment(self):
algorithm_name = self.struct_name
# Fall back to empty lists when includes / type_definitions are unset (falsy).
includes = self.includes or []
type_definitions = self.type_definitions or []

# Assemble the C++ source in memory; `w` is shorthand for buf.write.
buf = StringIO()
w = buf.write

w("#include <cuda/std/cstdint>\n")
for include in includes:
w(f"#include <{include}>\n")
# Each type definition contributes its `.code` text verbatim.
for type_definition in type_definitions:
w(f"{type_definition.code}\n")

# Alias the algorithm and its TempStorage, then emit sizeof/alignof as
# `__device__ constexpr unsigned` variables named temp_storage_bytes and
# temp_storage_alignment so they can be looked up by name in the PTX below.
w(f"using algorithm_t = cub::{algorithm_name};\n")
w("using temp_storage_t = typename algorithm_t::TempStorage;\n")
prefix = "__device__ constexpr unsigned temp_storage_"
w(f"{prefix}bytes = sizeof(temp_storage_t);\n")
w(f"{prefix}alignment = alignof(temp_storage_t);\n")

src = buf.getvalue()
# Compile for the current device; cc is encoded as major * 10 + minor.
device = cuda.get_current_device()
cc_major, cc_minor = device.compute_capability
cc = cc_major * 10 + cc_minor
# Only the second element of nvrtc.compile's result (the PTX) is used here.
_, ptx = nvrtc.compile(cpp=src, cc=cc, rdc=True, code="ptx")
# find_unsigned receives the variable name and the PTX text; presumably it
# parses the variable's initializer out of the PTX — confirm in _common.
temp_storage_bytes = find_unsigned("temp_storage_bytes", ptx)
temp_storage_alignment = find_unsigned("temp_storage_alignment", ptx)
return (temp_storage_bytes, temp_storage_alignment)

# Read-only accessor for the temporary-storage size.
# NOTE(review): two bodies appear back to back here. The first return indexes
# the cached (bytes, alignment) tuple; the lines after it raise RuntimeError
# while self._temp_storage_bytes is still None and otherwise return that
# attribute. Read literally, the first return makes the guard unreachable.
# Presumably only the guard-and-return form is meant to remain — confirm.
@property
def temp_storage_bytes(self):
return self._temp_storage_bytes_and_alignment[0]
if self._temp_storage_bytes is None:
raise RuntimeError(
"Temporary storage bytes not computed yet. Call get_lto_ir() first."
)
return self._temp_storage_bytes

@property
def temp_storage_alignment(self):
return self._temp_storage_bytes_and_alignment[1]
if self._temp_storage_alignment is None:
raise RuntimeError(
"Temporary storage alignment not computed yet. "
"Call get_lto_ir() first."

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: why not call get_lto_ir() on the user's behalf instead of raising here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's essentially an internal invariant: users wouldn't be calling this directly in a way where they could hit this. But we might accidentally trip the invariant during library development (e.g., when adding a new primitive), so it's better to fail fast.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess put differently, is there any disadvantage to making get_lto_ir() an implementation detail, and simply exposing these (cached) properties? The "user" here is us, i.e., library developers.

Either way, I'm going ahead and approving. Leaving it to your best judgement here!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I actually had it as a cached property like that last time, but switched to using the underscore + property here as we'll be adding more of these in single-phase.

)
return self._temp_storage_alignment

def get_lto_ir(self, threads=None):
lto_irs = []
Expand Down Expand Up @@ -680,6 +662,9 @@ def get_lto_ir(self, threads=None):

w(f"using algorithm_t = cub::{algorithm_name};\n")
w("using temp_storage_t = typename algorithm_t::TempStorage;\n")
prefix = "__device__ constexpr unsigned temp_storage_"
w(f"{prefix}bytes = sizeof(temp_storage_t);\n")
w(f"{prefix}alignment = alignof(temp_storage_t);\n")

src = buf.getvalue()

Expand Down Expand Up @@ -775,8 +760,24 @@ def get_lto_ir(self, threads=None):
cc = cc_major * 10 + cc_minor
# N.B. Uncomment this to immediately print generated source to stdout.
# print(src)
_, lto_fn = nvrtc.compile(cpp=src, cc=cc, rdc=True, code="lto")
lto_irs.append(lto_fn)
_, blob = nvrtc.compile(cpp=src, cc=cc, rdc=True, code="lto")
lto_irs.append(blob)

# Convert the LTO into PTX in order to extract the size and alignment
# variables.
obj = LTOIR(name=self.c_name, data=blob)
linker = cuda_driver._Linker.new(
cc=device.compute_capability,
additional_flags=["-ptx"],
lto=obj,
)
ltoir_bytes = obj.data
linker.add_ltoir(ltoir_bytes)
ptx = linker.get_linked_ptx()
ptx = ptx.decode("utf-8")
self._temp_storage_bytes = find_unsigned("temp_storage_bytes", ptx)
self._temp_storage_alignment = find_unsigned("temp_storage_alignment", ptx)

return lto_irs

def codegen(self, func_to_overload):
Expand Down