diff --git a/docs/source/argument_intents.rst b/docs/source/argument_intents.rst index 6fa8ce4e..f595020f 100644 --- a/docs/source/argument_intents.rst +++ b/docs/source/argument_intents.rst @@ -112,3 +112,26 @@ Notes returns that value directly (not a tuple). - If C++ returns ``void`` and there are multiple ``out_return`` parameters, Numbast returns ``types.Tuple((out1, out2, ...))``. + +Pointer return materialization +------------------------------ + +Some APIs return a borrowed pointer to a fixed-size sequence owned elsewhere: + +.. code-block:: c++ + + __device__ const float4 *get_transform(Handle h); + +Use ``Function Return Materializations`` to copy a configured number of +pointee values into the Numba return value: + +.. code-block:: yaml + + Function Return Materializations: + get_transform: + kind: pointer + length: 3 + +The generated Numba signature returns ``UniTuple(float32x4, 3)``. The source +pointer is still the C++ return value; Numbast does not allocate or own that +storage beyond copying the configured elements during the call. diff --git a/docs/source/generated/static_binding_schema_reference.rst b/docs/source/generated/static_binding_schema_reference.rst index 153dc220..ca590520 100644 --- a/docs/source/generated/static_binding_schema_reference.rst +++ b/docs/source/generated/static_binding_schema_reference.rst @@ -278,6 +278,23 @@ Optional keys 0: in +``Function Return Materializations`` : ``object`` + Per-function return materialization overrides for borrowed fixed-size pointer returns. Function keys map to an + integer length or an object with kind/intent pointer metadata and a length, size, or count. + + Default: ``{}``. + + Example: + + .. code-block:: yaml + + Function Return Materializations: + get_transform: + kind: pointer + length: 3 + get_coefficients: 8 + + Optional nested keys ^^^^^^^^^^^^^^^^^^^^ @@ -615,3 +632,59 @@ Raw schema type: string enum: ["in", "inout_ptr", "out_ptr", "out_return"] required: ["intent"] + Function Return Materializations: + type: object + default: {} + description: > + Per-function return materialization overrides for borrowed fixed-size pointer returns. Function keys map to an + integer length or an object with kind/intent pointer metadata and a length, size, or count. + + examples: + - get_transform: + kind: pointer + length: 3 + get_coefficients: 8 + additionalProperties: + oneOf: + - type: integer + minimum: 1 + - type: object + additionalProperties: false + properties: + kind: + type: string + enum: + - pointer + - ptr + - pointer_return + - ptr_return + - borrowed_ptr + - borrowed_pointer + - fixed_size_pointer + - borrowed_fixed_size_ptr + - borrowed_fixed_size_pointer + intent: + type: string + enum: + - pointer + - ptr + - pointer_return + - ptr_return + - borrowed_ptr + - borrowed_pointer + - fixed_size_pointer + - borrowed_fixed_size_ptr + - borrowed_fixed_size_pointer + length: + type: integer + minimum: 1 + size: + type: integer + minimum: 1 + count: + type: integer + minimum: 1 + anyOf: + - required: ["length"] + - required: ["size"] + - required: ["count"] diff --git a/numbast/src/numbast/__init__.py b/numbast/src/numbast/__init__.py index 2e949470..a59ca0ad 100644 --- a/numbast/src/numbast/__init__.py +++ b/numbast/src/numbast/__init__.py @@ -16,6 +16,7 @@ bind_cxx_function_templates, ) from numbast.enum import bind_cxx_enum, bind_cxx_enums +from numbast.return_materialization import PointerReturnMaterialization from numbast.shim_writer import MemoryShimWriter, FileShimWriter import importlib.metadata @@ -40,6 +41,7 @@ "bind_cxx_class_template", "bind_cxx_class_templates", "clear_concrete_type_caches", + "PointerReturnMaterialization", "MemoryShimWriter", "FileShimWriter", ] diff --git a/numbast/src/numbast/callconv.py b/numbast/src/numbast/callconv.py index 04f9bb19..6a54b312 100644 --- a/numbast/src/numbast/callconv.py +++ b/numbast/src/numbast/callconv.py @@ -3,6 +3,7 @@ from numbast.args import prepare_ir_types from numbast.intent import IntentPlan +from numbast.return_materialization import PointerReturnMaterialization from numbast.types import get_numba_type_alignof # NBST:BEGIN_CALLCONV @@ -78,6 +79,7 @@ def __init__( intent_plan: IntentPlan | None = None, out_return_types: list[types.Type] | None = None, cxx_return_type: types.Type | None = None, + return_materialization: PointerReturnMaterialization | None = None, ): """ Initialize a FunctionCallConv with shim information and optional ABI/intent hints. @@ -90,6 +92,10 @@ def __init__( intent_plan (IntentPlan | None): Optional plan describing visible parameter indices, which parameters should be passed as pointers, and which parameters are out-returns; when present it drives argument mapping and out-return handling. out_return_types (list[types.Type] | None): Types of the out-return values in the order declared by the IntentPlan; required when the intent_plan defines out-return indices. cxx_return_type (types.Type | None): The C++ ABI return type to use for allocating/shimming the return slot; if None, the signature's return type is used. + return_materialization (PointerReturnMaterialization | None): + Optional borrowed-pointer return materialization plan. When + provided, the C++ return value must be a pointer and Numbast + copies `length` pointee elements into a UniTuple return value. """ super().__init__(itanium_mangled_name, shim_writer, shim_code) self._arg_is_ref = list(arg_is_ref) if arg_is_ref is not None else None @@ -98,6 +104,41 @@ def __init__( list(out_return_types) if out_return_types is not None else None ) self._cxx_return_type = cxx_return_type + self._return_materialization = return_materialization + + def _materialize_pointer_return( + self, + builder, + context, + cxx_return_type, + retval_ptr, + retval_align, + ): + if self._return_materialization is None: + raise ValueError("return materialization was not configured") + if not isinstance(cxx_return_type, types.CPointer): + raise ValueError( + "pointer return materialization requires a C++ pointer return " + f"type, got {cxx_return_type}" + ) + + pointee_numba_ty = cxx_return_type.dtype + pointee_ir_ty = context.get_value_type(pointee_numba_ty) + pointee_align = _get_alloca_alignment( + context, pointee_ir_ty, pointee_numba_ty + ) + source_ptr = builder.load(retval_ptr, align=retval_align) + + ret_vals: list[ir.Value] = [] + for idx in range(self._return_materialization.length): + offset = ir.Constant(ir.IntType(32), idx) + elem_ptr = builder.gep(source_ptr, [offset], inbounds=True) + ret_vals.append(builder.load(elem_ptr, align=pointee_align)) + + return_type = types.UniTuple( + pointee_numba_ty, self._return_materialization.length + ) + return context.make_tuple(builder, return_type, ret_vals) def _lower_impl(self, builder, context, sig, args): # Numba-visible return type may differ from the underlying C++ return type @@ -125,6 +166,13 @@ def _lower_impl(self, builder, context, sig, args): if self._cxx_return_type is not None else sig.return_type ) + if ( + self._return_materialization is not None + and cxx_return_type == types.void + ): + raise ValueError( + "pointer return materialization cannot be used with void returns" + ) # 1. Prepare return value pointer if cxx_return_type == types.void: # Void return type in C++ is shimmed as int& ignored @@ -248,18 +296,29 @@ def _lower_impl(self, builder, context, sig, args): builder.call(fn, (retval_ptr, *ptrs)) # 5. Return + materialized_return = None + if self._return_materialization is not None: + materialized_return = self._materialize_pointer_return( + builder, context, cxx_return_type, retval_ptr, retval_align + ) + if ( self._intent_plan is None or not self._intent_plan.out_return_indices ): if cxx_return_type == types.void: return None + if materialized_return is not None: + return materialized_return return builder.load(retval_ptr, align=retval_align) # out_return enabled: return either a value or a tuple (ret, out1, out2, ...) ret_vals: list[ir.Value] = [] if cxx_return_type != types.void: - ret_vals.append(builder.load(retval_ptr, align=retval_align)) + if materialized_return is not None: + ret_vals.append(materialized_return) + else: + ret_vals.append(builder.load(retval_ptr, align=retval_align)) for out_return in out_return_ptrs: ret_vals.append( builder.load(out_return.ptr, align=out_return.align) diff --git a/numbast/src/numbast/function.py b/numbast/src/numbast/function.py index 23049b14..87f24d66 100644 --- a/numbast/src/numbast/function.py +++ b/numbast/src/numbast/function.py @@ -15,6 +15,10 @@ from numbast.types import to_numba_type, to_numba_arg_type from numbast.intent import compute_intent_plan +from numbast.return_materialization import ( + PointerReturnMaterialization, + parse_return_materialization, +) from numbast.utils import ( deduplicate_overloads, make_function_shim, @@ -42,6 +46,22 @@ def func(): func_obj_registry: dict[str, object] = defaultdict(make_new_func_obj) +def _visible_cxx_return_type( + cxx_return_type: nbtypes.Type, + return_materialization: PointerReturnMaterialization | None, +) -> nbtypes.Type: + if return_materialization is None: + return cxx_return_type + if not isinstance(cxx_return_type, nbtypes.CPointer): + raise ValueError( + "pointer return materialization requires a C++ pointer return " + f"type, got {cxx_return_type}" + ) + return nbtypes.UniTuple( + cxx_return_type.dtype, return_materialization.length + ) + + def bind_cxx_operator_overload_function( shim_writer: ShimWriter, func_decl: Function, @@ -129,6 +149,7 @@ def bind_cxx_non_operator_function( exclude: set[str], *, arg_intent: dict | None = None, + return_materializations: dict | None = None, ) -> object: """ Create a Python-callable binding for a C++ non-operator function. @@ -147,6 +168,8 @@ def bind_cxx_non_operator_function( Set of function names to exclude from binding. arg_intent : dict | None, optional Optional per-function intent overrides that specify visibility and in/out semantics for reference parameters. + return_materializations : dict | None, optional + Optional per-function return materialization specs for borrowed fixed-size pointer returns. Returns ------- @@ -164,11 +187,19 @@ def bind_cxx_non_operator_function( cxx_return_type = to_numba_type( func_decl.return_type.unqualified_non_ref_type_name ) + return_materialization = parse_return_materialization( + return_materializations.get(func_decl.name) + if return_materializations + else None + ) + visible_cxx_return_type = _visible_cxx_return_type( + cxx_return_type, return_materialization + ) overrides = arg_intent.get(func_decl.name) if arg_intent else None if overrides is None: # Backward-compatible default: refs are input-only values. - return_type = cxx_return_type + return_type = visible_cxx_return_type param_types = [to_numba_arg_type(arg) for arg in func_decl.param_types] arg_is_ref = [ bool(t.is_left_reference() or t.is_right_reference()) @@ -211,10 +242,10 @@ def bind_cxx_non_operator_function( return_type = nbtypes.Tuple(tuple(out_return_types)) else: return_type = nbtypes.Tuple( - tuple([cxx_return_type, *out_return_types]) + tuple([visible_cxx_return_type, *out_return_types]) ) else: - return_type = cxx_return_type + return_type = visible_cxx_return_type # In intentful mode, pass-through pointers are controlled by intent_plan, # not by whether the C++ parameter is a reference. @@ -235,7 +266,11 @@ class func_typing(ConcreteTemplate): register_global(func, nbtypes.Function(func_typing)) - return_type_name = func_decl.return_type.unqualified_non_ref_type_name + return_type_name = ( + func_decl.return_type.name + if return_materialization is not None + else func_decl.return_type.unqualified_non_ref_type_name + ) mangled_name = deduplicate_overloads(func_decl.mangled_name) shim_func_name = f"{mangled_name}_nbst" @@ -251,6 +286,7 @@ class func_typing(ConcreteTemplate): intent_plan=intent_plan, out_return_types=out_return_types, cxx_return_type=cxx_return_type, + return_materialization=return_materialization, ) # Lowering @@ -269,6 +305,7 @@ def bind_cxx_function( exclude: set[str] = set(), *, arg_intent: dict | None = None, + return_materializations: dict | None = None, ) -> object: """ Create Python bindings for a C++ function. @@ -282,6 +319,8 @@ def bind_cxx_function( arg_intent (dict | None): Optional explicit intent overrides that control which C++ reference parameters are exposed as inputs, outputs, or inout pointers and which parameters are promoted to out-returns. + return_materializations (dict | None): Optional per-function return + materialization specs for borrowed fixed-size pointer returns. Returns: object or None: The Numba-CUDA-callable Python binding object for the function, or `None` @@ -307,6 +346,7 @@ def bind_cxx_function( skip_prefix=skip_prefix, exclude=exclude, arg_intent=arg_intent, + return_materializations=return_materializations, ) return None @@ -320,6 +360,7 @@ def bind_cxx_functions( exclude: set[str] = set(), *, arg_intent: dict | None = None, + return_materializations: dict | None = None, ) -> list[object]: """Create bindings for a list of C++ functions. @@ -339,6 +380,9 @@ def bind_cxx_functions( exclude : set[str] A set of function names to exclude. Default to empty set. + return_materializations : dict | None + Optional per-function return materialization specs for borrowed fixed-size + pointer returns. Returns ------- @@ -355,6 +399,7 @@ def bind_cxx_functions( skip_non_device=skip_non_device, exclude=exclude, arg_intent=arg_intent, + return_materializations=return_materializations, ) # overloaded operator (e.g. "+") do not need to have a separate API # as they are called directly from the Python operator. diff --git a/numbast/src/numbast/return_materialization.py b/numbast/src/numbast/return_materialization.py new file mode 100644 index 00000000..fa03ea3b --- /dev/null +++ b/numbast/src/numbast/return_materialization.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +# NBST:BEGIN_RETURN_MATERIALIZATION_DEFS +from dataclasses import dataclass +from typing import Any, Mapping + + +@dataclass(frozen=True) +class PointerReturnMaterialization: + """ + Describes a borrowed pointer return copied into a Numba tuple. + """ + + length: int + + def __post_init__(self): + if type(self.length) is not int: + raise TypeError( + "pointer return materialization length must be an int" + ) + if self.length <= 0: + raise ValueError( + "pointer return materialization length must be positive" + ) + + +def parse_return_materialization( + raw: Any, +) -> PointerReturnMaterialization | None: + """ + Normalize user-facing return materialization config. + """ + if raw is None: + return None + if isinstance(raw, PointerReturnMaterialization): + return raw + if type(raw) is int: + return PointerReturnMaterialization(raw) + if not isinstance(raw, Mapping): + raise TypeError( + "return materialization must be an int, mapping, " + f"PointerReturnMaterialization, or None; got {type(raw)}" + ) + + kind = raw.get("kind", raw.get("intent", "pointer")) + allowed_kinds = { + "pointer", + "ptr", + "pointer_return", + "ptr_return", + "borrowed_ptr", + "borrowed_pointer", + "fixed_size_pointer", + "borrowed_fixed_size_ptr", + "borrowed_fixed_size_pointer", + } + if kind not in allowed_kinds: + raise ValueError( + f"unsupported return materialization kind {kind!r}; " + f"expected one of {sorted(allowed_kinds)}" + ) + + length = raw.get("length", raw.get("size", raw.get("count", None))) + if length is None: + raise ValueError( + "pointer return materialization requires a length, size, or count" + ) + return PointerReturnMaterialization(length) + + +# NBST:END_RETURN_MATERIALIZATION_DEFS diff --git a/numbast/src/numbast/static/callconv.py b/numbast/src/numbast/static/callconv.py index 229ddcf2..3596ea85 100644 --- a/numbast/src/numbast/static/callconv.py +++ b/numbast/src/numbast/static/callconv.py @@ -4,6 +4,7 @@ from numbast import callconv from numbast import args from numbast import intent_defs as intent_mod +from numbast import return_materialization as return_materialization_mod from numbast.types import NUMBA_TYPE_ALIGNOF_MAPS @@ -39,6 +40,13 @@ def _extract_section(src: str, begin: str, end: str) -> str: _INTENT_DEFS_SRC, "# NBST:BEGIN_INTENT_DEFS", "# NBST:END_INTENT_DEFS" ) +_RETURN_MATERIALIZATION_DEFS_SRC = inspect.getsource(return_materialization_mod) +RETURN_MATERIALIZATION_SRC = _extract_section( + _RETURN_MATERIALIZATION_DEFS_SRC, + "# NBST:BEGIN_RETURN_MATERIALIZATION_DEFS", + "# NBST:END_RETURN_MATERIALIZATION_DEFS", +) + ARGS_SRC = inspect.getsource(args) _CALLCONV_SRC = inspect.getsource(callconv) _CALLCONV_SRC_SECTION = _extract_section( @@ -71,6 +79,8 @@ def get_numba_type_alignof(numba_type): CALLCONV_SRC = ( INTENT_SRC + "\n" + + RETURN_MATERIALIZATION_SRC + + "\n" + ARGS_SRC + "\n" + _render_numba_type_alignof_src() diff --git a/numbast/src/numbast/static/function.py b/numbast/src/numbast/static/function.py index 3627bb99..70eba673 100644 --- a/numbast/src/numbast/static/function.py +++ b/numbast/src/numbast/static/function.py @@ -18,6 +18,10 @@ ) from numbast.static.types import to_numba_type_str, to_numba_arg_type_str from numbast.intent import ArgIntent, compute_intent_plan +from numbast.return_materialization import ( + PointerReturnMaterialization, + parse_return_materialization, +) from numbast.utils import make_function_shim, _apply_prefix_removal from numbast.errors import TypeNotFoundError, MangledFunctionNameConflictError @@ -37,6 +41,34 @@ """A set of created function API names.""" +def _visible_cxx_return_type_str( + cxx_return_type: str, + return_materialization: PointerReturnMaterialization | None, +) -> str: + if return_materialization is None: + return cxx_return_type + if not cxx_return_type.startswith( + "CPointer(" + ) or not cxx_return_type.endswith(")"): + raise ValueError( + "pointer return materialization requires a C++ pointer return " + f"type, got {cxx_return_type}" + ) + BaseRenderer._try_import_numba_type("UniTuple") + pointee_type = cxx_return_type[len("CPointer(") : -1] + return f"UniTuple({pointee_type}, {return_materialization.length})" + + +def _render_return_materialization( + return_materialization: PointerReturnMaterialization | None, +) -> str: + if return_materialization is None: + return "None" + return ( + f"PointerReturnMaterialization(length={return_materialization.length})" + ) + + def _matches_any_regex_pattern(name: str, patterns: list[str]) -> bool: """Check if a function name matches any of the provided regex patterns. @@ -114,6 +146,7 @@ def impl(context, builder, sig, args): intent_plan={intent_plan}, out_return_types={out_return_types}, cxx_return_type={cxx_return_type}, + return_materialization={return_materialization}, ) return callconv(builder, context, sig, args) """ @@ -146,6 +179,7 @@ def __init__( header_path: str, use_cooperative: bool, function_argument_intents: dict | None = None, + function_return_materializations: dict | None = None, ): """ Initialize the renderer for a static (non-operator) C++/CUDA function binding and compute type/intent metadata used for shim and lowering generation. @@ -172,11 +206,22 @@ def __init__( overrides = None if function_argument_intents: overrides = function_argument_intents.get(self._decl.name) + self._return_materialization = parse_return_materialization( + function_return_materializations.get(self._decl.name) + if function_return_materializations + else None + ) self._cxx_return_type = to_numba_type_str( self._decl.return_type.unqualified_non_ref_type_name ) self._cxx_return_type_str = str(self._cxx_return_type) + self._visible_cxx_return_type_str = _visible_cxx_return_type_str( + self._cxx_return_type_str, self._return_materialization + ) + self._return_materialization_rendered = _render_return_materialization( + self._return_materialization + ) if overrides is None: self._argument_numba_types = [ @@ -189,10 +234,14 @@ def __init__( bool(t.is_left_reference() or t.is_right_reference()) for t in self._decl.param_types ] - self._return_numba_type_str = self._cxx_return_type_str + self._return_numba_type_str = self._visible_cxx_return_type_str self._intent_plan_rendered = "None" self._out_return_types_rendered = "None" - self._cxx_return_type_rendered = "None" + self._cxx_return_type_rendered = ( + self._cxx_return_type_str + if self._return_materialization is not None + else "None" + ) else: plan = compute_intent_plan( params=self._decl.params, @@ -235,11 +284,11 @@ def __init__( self._return_numba_type_str = f"types.Tuple(({outs},))" else: outs = ", ".join( - [self._cxx_return_type_str, *out_return_types] + [self._visible_cxx_return_type_str, *out_return_types] ) self._return_numba_type_str = f"types.Tuple(({outs},))" else: - self._return_numba_type_str = self._cxx_return_type_str + self._return_numba_type_str = self._visible_cxx_return_type_str def _tuple_literal(items: list[str]) -> str: """ @@ -314,7 +363,11 @@ def _render_shim_function(self): self._c_ext_shim_rendered = make_function_shim( shim_name=self._deduplicated_shim_name, func_name=self._decl.name, - return_type=self._decl.return_type.unqualified_non_ref_type_name, + return_type=( + self._decl.return_type.name + if self._return_materialization is not None + else self._decl.return_type.unqualified_non_ref_type_name + ), params=self._decl.params, ) @@ -351,6 +404,7 @@ def _render_lowering(self): intent_plan=self._intent_plan_rendered, out_return_types=self._out_return_types_rendered, cxx_return_type=self._cxx_return_type_rendered, + return_materialization=self._return_materialization_rendered, ) def _render_scoped_lower(self): @@ -434,6 +488,7 @@ def __init__( decl: Function, header_path: str, function_argument_intents: dict | None = None, + function_return_materializations: dict | None = None, ): """ Initialize an overloaded-operator renderer and record the Python operator mapping. @@ -448,6 +503,7 @@ def __init__( header_path, use_cooperative=False, function_argument_intents=function_argument_intents, + function_return_materializations=function_return_materializations, ) self._py_op = decl.overloaded_operator_to_python_operator @@ -492,6 +548,7 @@ def __init__( use_cooperative: bool, function_prefix_removal: list[str] = [], function_argument_intents: dict | None = None, + function_return_materializations: dict | None = None, ): """ Initialize the non-operator function renderer, compute the Python-facing function name by removing configured prefixes, and update tracked function symbols accordingly. @@ -510,6 +567,7 @@ def __init__( header_path, use_cooperative, function_argument_intents=function_argument_intents, + function_return_materializations=function_return_materializations, ) self._python_func_name = _apply_prefix_removal( decl.name, function_prefix_removal @@ -594,6 +652,7 @@ def __init__( cooperative_launch_required: list[str] = [], function_prefix_removal: list[str] = [], function_argument_intents: dict | None = None, + function_return_materializations: dict | None = None, ): """ Initialize the renderer for a collection of CUDA/C++ function declarations and configure rendering options. @@ -620,6 +679,9 @@ def __init__( self._cooperative_launch_required = cooperative_launch_required self._function_prefix_removal = function_prefix_removal self._function_argument_intents = function_argument_intents or {} + self._function_return_materializations = ( + function_return_materializations or {} + ) self._func_typing_signature_cache: dict[str, list[str]] = defaultdict( list @@ -673,6 +735,9 @@ def _create_operator_renderer( decl, self._header_path, function_argument_intents=self._function_argument_intents, + function_return_materializations=( + self._function_return_materializations + ), ) except TypeNotFoundError as e: warn( @@ -700,6 +765,9 @@ def _create_function_renderer( use_cooperative, self._function_prefix_removal, function_argument_intents=self._function_argument_intents, + function_return_materializations=( + self._function_return_materializations + ), ) except TypeNotFoundError as e: warn( diff --git a/numbast/src/numbast/static/tests/conftest.py b/numbast/src/numbast/static/tests/conftest.py index 1556fb92..38a9b6b5 100644 --- a/numbast/src/numbast/static/tests/conftest.py +++ b/numbast/src/numbast/static/tests/conftest.py @@ -47,6 +47,7 @@ def _make_binding( datamodels: dict[str, type], cc: str = "sm_80", function_argument_intents: dict | None = None, + function_return_materializations: dict | None = None, ): clear_base_renderer_cache() clear_function_apis_registry() @@ -63,6 +64,9 @@ def _make_binding( separate_registry=False, ) cfg.function_argument_intents = function_argument_intents or {} + cfg.function_return_materializations = ( + function_return_materializations or {} + ) _static_binding_generator(cfg, tmpdir) basename = header_name.split(".")[0] diff --git a/numbast/src/numbast/static/tests/data/function_out.cuh b/numbast/src/numbast/static/tests/data/function_out.cuh index 0ba25f4c..23af66e1 100644 --- a/numbast/src/numbast/static/tests/data/function_out.cuh +++ b/numbast/src/numbast/static/tests/data/function_out.cuh @@ -10,3 +10,4 @@ void __device__ add_out(int &out, int x); int __device__ add_out_ret(int &out, int x); int __device__ add_in_ref(int &x); void __device__ add_inout_ref(int &x, int delta); +const float4 *__device__ get_transform(int handle); diff --git a/numbast/src/numbast/static/tests/data/src/function_out.cu b/numbast/src/numbast/static/tests/data/src/function_out.cu index 037f85f3..0fe20559 100644 --- a/numbast/src/numbast/static/tests/data/src/function_out.cu +++ b/numbast/src/numbast/static/tests/data/src/function_out.cu @@ -37,3 +37,14 @@ int __device__ add_in_ref(int &x) { return x + 5; } * @param delta Amount to add to `x`. */ void __device__ add_inout_ref(int &x, int delta) { x += delta; } + +static __device__ const float4 transform_rows[3] = { + {1.0f, 2.0f, 3.0f, 4.0f}, + {5.0f, 6.0f, 7.0f, 8.0f}, + {9.0f, 10.0f, 11.0f, 12.0f}, +}; + +const float4 *__device__ get_transform(int handle) { + (void)handle; + return transform_rows; +} diff --git a/numbast/src/numbast/static/tests/test_function_static_bindings.py b/numbast/src/numbast/static/tests/test_function_static_bindings.py index 418bd90a..790b67a5 100644 --- a/numbast/src/numbast/static/tests/test_function_static_bindings.py +++ b/numbast/src/numbast/static/tests/test_function_static_bindings.py @@ -90,6 +90,24 @@ def decl_out_ptr(make_binding): return bindings +@pytest.fixture(scope="function") +def decl_pointer_return(make_binding): + res = make_binding( + "function_out.cuh", + {}, + {}, + "sm_50", + function_return_materializations={ + "get_transform": {"kind": "pointer", "length": 3}, + }, + ) + bindings = res["bindings"] + + assert "get_transform" in bindings + + return bindings + + @pytest.fixture(scope="module") def impl_out(data_folder): """ @@ -230,3 +248,51 @@ def kernel(out_ptr_buf, in_val, inout_buf, out_val): assert out_ptr_buf[0] == 9 assert inout_buf[0] == 10 assert out_val.copy_to_host()[0] == 13 + + +def test_pointer_return_materializes_borrowed_rows( + decl_pointer_return, impl_out +): + get_transform = decl_pointer_return["get_transform"] + + @cuda.jit(link=[impl_out]) + def kernel(out): + rows = get_transform(0) + out[0] = rows[0].x + out[1] = rows[0].y + out[2] = rows[0].z + out[3] = rows[0].w + out[4] = rows[1].x + out[5] = rows[1].y + out[6] = rows[1].z + out[7] = rows[1].w + out[8] = rows[2].x + out[9] = rows[2].y + out[10] = rows[2].z + out[11] = rows[2].w + + out = device_array((12,), "float32") + kernel[1, 1](out) + + np.testing.assert_allclose( + out.copy_to_host(), np.arange(1, 13, dtype=np.float32), rtol=0, atol=0 + ) + + +def test_pointer_return_static_binding_source(make_binding): + src = make_binding( + "function_out.cuh", + {}, + {}, + "sm_50", + function_return_materializations={ + "get_transform": {"kind": "pointer", "length": 3}, + }, + )["src"] + + assert "signature(UniTuple(float32x4, 3), int32)" in src + assert ( + "return_materialization=PointerReturnMaterialization(length=3)" in src + ) + assert "cxx_return_type=CPointer(float32x4)" in src + assert "const float4 * &retval" in src diff --git a/numbast/src/numbast/tools/static_binding_generator.py b/numbast/src/numbast/tools/static_binding_generator.py index 68ed1484..74039668 100644 --- a/numbast/src/numbast/tools/static_binding_generator.py +++ b/numbast/src/numbast/tools/static_binding_generator.py @@ -80,6 +80,7 @@ class Config: skip_prefix: str | None separate_registry: bool function_argument_intents: dict + function_return_materializations: dict def __init__(self, config_dict: dict): """ @@ -145,6 +146,9 @@ def __init__(self, config_dict: dict): self.function_argument_intents = ( config_dict.get("Function Argument Intents", {}) or {} ) + self.function_return_materializations = ( + config_dict.get("Function Return Materializations", {}) or {} + ) # TODO: support multiple GPU architectures if len(self.gpu_arch) > 1: @@ -194,12 +198,14 @@ def from_params( skip_prefix: str | None = None, separate_registry: bool = False, function_argument_intents: dict | None = None, + function_return_materializations: dict | None = None, ) -> "Config": """ Construct a Config from explicit parameters instead of a YAML file. Parameters: function_argument_intents (dict | None): Mapping from function names to argument-intent specifications used by renderers; defaults to an empty dict if omitted. + function_return_materializations (dict | None): Mapping from function names to fixed-size pointer-return materialization specifications; defaults to an empty dict if omitted. cooperative_launch_required_functions_regex (list[str] | None): Regular expression patterns that identify functions requiring cooperative launch handling; defaults to an empty list if omitted. api_prefix_removal (dict[str, list[str]] | None): Mapping of API names to lists of symbol-name prefixes to remove when generating bindings; defaults to an empty dict if omitted. module_callbacks (dict[str, str] | None): Mapping of callback identifiers to their fully qualified callable names to be invoked from the generated module; defaults to an empty dict if omitted. @@ -234,6 +240,8 @@ def from_params( "Skip Prefix": skip_prefix, "Use Separate Registry": separate_registry, "Function Argument Intents": function_argument_intents or {}, + "Function Return Materializations": function_return_materializations + or {}, } # Convert types and datamodels back to string format for the dict @@ -398,6 +406,7 @@ def _generate_functions( function_prefix_removal: list[str], skip_prefix: str | None, function_argument_intents: dict | None = None, + function_return_materializations: dict | None = None, ) -> str: """ Render the function-binding source for the given function declarations. @@ -423,6 +432,7 @@ def _generate_functions( function_prefix_removal=function_prefix_removal, skip_prefix=skip_prefix, function_argument_intents=function_argument_intents or {}, + function_return_materializations=function_return_materializations or {}, ) return SFR.render_as_str(with_imports=False, with_shim_stream=False) @@ -640,6 +650,7 @@ def _static_binding_generator( config.api_prefix_removal.get("Function", []), config.skip_prefix, config.function_argument_intents, + config.function_return_materializations, ) class_template_bindings = _generate_class_templates( class_templates, diff --git a/numbast/src/numbast/tools/static_binding_generator.schema.yaml b/numbast/src/numbast/tools/static_binding_generator.schema.yaml index 7970bb14..b17d0c3f 100644 --- a/numbast/src/numbast/tools/static_binding_generator.schema.yaml +++ b/numbast/src/numbast/tools/static_binding_generator.schema.yaml @@ -229,3 +229,58 @@ properties: type: string enum: ["in", "inout_ptr", "out_ptr", "out_return"] required: ["intent"] + Function Return Materializations: + type: object + default: {} + description: > + Per-function return materialization overrides for borrowed fixed-size pointer returns. Function keys map to an integer length or an object with kind/intent pointer metadata and a length, size, or count. + + examples: + - get_transform: + kind: pointer + length: 3 + get_coefficients: 8 + additionalProperties: + oneOf: + - type: integer + minimum: 1 + - type: object + additionalProperties: false + properties: + kind: + type: string + enum: + - pointer + - ptr + - pointer_return + - ptr_return + - borrowed_ptr + - borrowed_pointer + - fixed_size_pointer + - borrowed_fixed_size_ptr + - borrowed_fixed_size_pointer + intent: + type: string + enum: + - pointer + - ptr + - pointer_return + - ptr_return + - borrowed_ptr + - borrowed_pointer + - fixed_size_pointer + - borrowed_fixed_size_ptr + - borrowed_fixed_size_pointer + length: + type: integer + minimum: 1 + size: + type: integer + minimum: 1 + count: + type: integer + minimum: 1 + anyOf: + - required: ["length"] + - required: ["size"] + - required: ["count"] diff --git a/numbast/src/numbast/tools/tests/test_config_schema_docs.py b/numbast/src/numbast/tools/tests/test_config_schema_docs.py index db7343d7..bf2b819d 100644 --- a/numbast/src/numbast/tools/tests/test_config_schema_docs.py +++ b/numbast/src/numbast/tools/tests/test_config_schema_docs.py @@ -45,6 +45,7 @@ def test_static_binding_schema_has_expected_keys(): "Skip Prefix", "Use Separate Registry", "Function Argument Intents", + "Function Return Materializations", } assert expected_keys.issubset(set(properties)) diff --git a/numbast/tests/data/sample_function_out.cuh b/numbast/tests/data/sample_function_out.cuh index ef56004b..7156daca 100644 --- a/numbast/tests/data/sample_function_out.cuh +++ b/numbast/tests/data/sample_function_out.cuh @@ -13,3 +13,14 @@ __device__ int add_out_ret(int &out, int x) { } __device__ int add_in_ref(int &x) { return x + 5; } + +static __device__ const float4 transform_rows[3] = { + {1.0f, 2.0f, 3.0f, 4.0f}, + {5.0f, 6.0f, 7.0f, 8.0f}, + {9.0f, 10.0f, 11.0f, 12.0f}, +}; + +__device__ const float4 *get_transform(int handle) { + (void)handle; + return transform_rows; +} diff --git a/numbast/tests/test_callconv.py b/numbast/tests/test_callconv.py index b3cf53b0..4f72a74b 100644 --- a/numbast/tests/test_callconv.py +++ b/numbast/tests/test_callconv.py @@ -11,6 +11,7 @@ from numbast.callconv import FunctionCallConv, _get_alloca_alignment from numbast.intent_defs import ArgIntent, IntentPlan +from numbast.return_materialization import PointerReturnMaterialization from numbast.types import CTYPE_MAPS, get_numba_type_alignof @@ -26,6 +27,7 @@ def _lower_callconv_to_ir( intent_plan=None, out_return_types=None, cxx_return_type=None, + return_materialization=None, ): context = cuda_target.target_context sig = SimpleNamespace(return_type=return_type, args=tuple(args)) @@ -44,6 +46,7 @@ def _lower_callconv_to_ir( intent_plan=intent_plan, out_return_types=out_return_types, cxx_return_type=cxx_return_type, + return_materialization=return_materialization, ) callconv._lower_impl(builder, context, sig, tuple(fn.args)) builder.ret_void() @@ -120,3 +123,18 @@ def test_intent_plan_allocas_are_aligned_in_lowered_ir(): assert re.search( r"load \{float, float, float, float\}, .* align 16", llvm_ir ) + + +def test_pointer_return_materialization_loads_fixed_size_rows(): + float4 = CTYPE_MAPS["float4"] + + llvm_ir = _lower_callconv_to_ir( + return_type=cuda_types.UniTuple(float4, 3), + args=(cuda_types.int32,), + cxx_return_type=cuda_types.CPointer(float4), + return_materialization=PointerReturnMaterialization(length=3), + ) + + assert llvm_ir.count("load {float, float, float, float}, ") == 3 + assert llvm_ir.count("align 16") >= 3 + assert "getelementptr inbounds {float, float, float, float}" in llvm_ir diff --git a/numbast/tests/test_function.py b/numbast/tests/test_function.py index 83b23017..d7c9c7dd 100644 --- a/numbast/tests/test_function.py +++ b/numbast/tests/test_function.py @@ -11,6 +11,7 @@ from ast_canopy import parse_declarations_from_source from numbast import bind_cxx_functions, MemoryShimWriter +from numbast.function import overload_registry import pytest @@ -234,3 +235,58 @@ def kernel_ptr(out_ptr_buf, in_val, out_in_ref): kernel_ptr[1, 1](out_ptr_buf, np.int32(8), out_in_ref) assert out_ptr_buf[0] == 9 assert out_in_ref[0] == 13 + + +@pytest.fixture +def _sample_pointer_return_functions(): + DATA_FOLDER = os.path.join(os.path.dirname(__file__), "data") + p = os.path.join(DATA_FOLDER, "sample_function_out.cuh") + decls = parse_declarations_from_source(p, [p], "sm_80", verbose=True) + funcs = decls.functions + shim_writer = MemoryShimWriter(f'#include "{p}"') + + func_bindings = bind_cxx_functions( + shim_writer, + funcs, + return_materializations={ + "get_transform": {"kind": "pointer", "length": 3}, + }, + ) + + return func_bindings, shim_writer + + +def test_pointer_return_materializes_borrowed_rows( + _sample_pointer_return_functions, +): + func_bindings, shim_writer = _sample_pointer_return_functions + get_transform = find_binding(func_bindings, "get_transform") + + @cuda.jit(link=shim_writer.links()) + def kernel(out): + rows = get_transform(0) + out[0] = rows[0].x + out[1] = rows[0].y + out[2] = rows[0].z + out[3] = rows[0].w + out[4] = rows[1].x + out[5] = rows[1].y + out[6] = rows[1].z + out[7] = rows[1].w + out[8] = rows[2].x + out[9] = rows[2].y + out[10] = rows[2].z + out[11] = rows[2].w + + out = np.zeros(12, dtype=np.float32) + kernel[1, 1](out) + np.testing.assert_allclose( + out, np.arange(1, 13, dtype=np.float32), rtol=0, atol=0 + ) + + +def test_pointer_return_binding_signature(_sample_pointer_return_functions): + assert any( + str(sig) == "(int32,) -> UniTuple(float32x4 x 3)" + for sig in overload_registry["get_transform"] + )