Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/source/argument_intents.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,26 @@ Notes
returns that value directly (not a tuple).
- If C++ returns ``void`` and there are multiple ``out_return`` parameters,
Numbast returns ``types.Tuple((out1, out2, ...))``.

Pointer return materialization
------------------------------

Some APIs return a borrowed pointer to a fixed-size sequence owned elsewhere:

.. code-block:: c++

__device__ const float4 *get_transform(Handle h);

Use ``Function Return Materializations`` to copy a configured number of
pointee values into the Numba return value:

.. code-block:: yaml

Function Return Materializations:
get_transform:
kind: pointer
length: 3

The generated Numba signature returns ``UniTuple(float32x4, 3)``. The C++
function still returns the pointer itself; Numbast neither allocates nor owns
that storage. It only copies the configured number of elements out of it during the call.
73 changes: 73 additions & 0 deletions docs/source/generated/static_binding_schema_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,23 @@ Optional keys
0: in


``Function Return Materializations`` : ``object``
Per-function return materialization overrides for borrowed fixed-size pointer returns. Each key is a function name; its value is
either an integer length or an object with optional ``kind``/``intent`` pointer metadata and at least one of ``length``, ``size``, or ``count``.

Default: ``{}``.

Example:

.. code-block:: yaml

Function Return Materializations:
get_transform:
kind: pointer
length: 3
get_coefficients: 8


Optional nested keys
^^^^^^^^^^^^^^^^^^^^

Expand Down Expand Up @@ -615,3 +632,59 @@ Raw schema
type: string
enum: ["in", "inout_ptr", "out_ptr", "out_return"]
required: ["intent"]
Function Return Materializations:
type: object
default: {}
description: >
Per-function return materialization overrides for borrowed fixed-size pointer returns. Function keys map to an
integer length or an object with kind/intent pointer metadata and a length, size, or count.

examples:
- get_transform:
kind: pointer
length: 3
get_coefficients: 8
additionalProperties:
oneOf:
- type: integer
minimum: 1
- type: object
additionalProperties: false
properties:
kind:
type: string
enum:
- pointer
- ptr
- pointer_return
- ptr_return
- borrowed_ptr
- borrowed_pointer
- fixed_size_pointer
- borrowed_fixed_size_ptr
- borrowed_fixed_size_pointer
intent:
type: string
enum:
- pointer
- ptr
- pointer_return
- ptr_return
- borrowed_ptr
- borrowed_pointer
- fixed_size_pointer
- borrowed_fixed_size_ptr
- borrowed_fixed_size_pointer
length:
type: integer
minimum: 1
size:
type: integer
minimum: 1
count:
type: integer
minimum: 1
anyOf:
- required: ["length"]
- required: ["size"]
- required: ["count"]
2 changes: 2 additions & 0 deletions numbast/src/numbast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
bind_cxx_function_templates,
)
from numbast.enum import bind_cxx_enum, bind_cxx_enums
from numbast.return_materialization import PointerReturnMaterialization
from numbast.shim_writer import MemoryShimWriter, FileShimWriter

import importlib.metadata
Expand All @@ -40,6 +41,7 @@
"bind_cxx_class_template",
"bind_cxx_class_templates",
"clear_concrete_type_caches",
"PointerReturnMaterialization",
"MemoryShimWriter",
"FileShimWriter",
]
61 changes: 60 additions & 1 deletion numbast/src/numbast/callconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from numbast.args import prepare_ir_types
from numbast.intent import IntentPlan
from numbast.return_materialization import PointerReturnMaterialization
from numbast.types import get_numba_type_alignof

# NBST:BEGIN_CALLCONV
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
intent_plan: IntentPlan | None = None,
out_return_types: list[types.Type] | None = None,
cxx_return_type: types.Type | None = None,
return_materialization: PointerReturnMaterialization | None = None,
):
"""
Initialize a FunctionCallConv with shim information and optional ABI/intent hints.
Expand All @@ -90,6 +92,10 @@ def __init__(
intent_plan (IntentPlan | None): Optional plan describing visible parameter indices, which parameters should be passed as pointers, and which parameters are out-returns; when present it drives argument mapping and out-return handling.
out_return_types (list[types.Type] | None): Types of the out-return values in the order declared by the IntentPlan; required when the intent_plan defines out-return indices.
cxx_return_type (types.Type | None): The C++ ABI return type to use for allocating/shimming the return slot; if None, the signature's return type is used.
return_materialization (PointerReturnMaterialization | None):
Optional borrowed-pointer return materialization plan. When
provided, the C++ return value must be a pointer and Numbast
copies `length` pointee elements into a UniTuple return value.
"""
super().__init__(itanium_mangled_name, shim_writer, shim_code)
self._arg_is_ref = list(arg_is_ref) if arg_is_ref is not None else None
Expand All @@ -98,6 +104,41 @@ def __init__(
list(out_return_types) if out_return_types is not None else None
)
self._cxx_return_type = cxx_return_type
self._return_materialization = return_materialization

def _materialize_pointer_return(
    self,
    builder,
    context,
    cxx_return_type,
    retval_ptr,
    retval_align,
):
    """Copy the elements behind a pointer return value into a tuple.

    The pointer stored in ``retval_ptr`` is loaded first; then ``length``
    consecutive pointee elements (``length`` comes from the configured
    materialization plan) are loaded from it and packed into a ``UniTuple``
    of the pointee type.

    Raises:
        ValueError: If no materialization plan is configured, or if
            ``cxx_return_type`` is not a ``CPointer``.
    """
    plan = self._return_materialization
    if plan is None:
        raise ValueError("return materialization was not configured")
    if not isinstance(cxx_return_type, types.CPointer):
        raise ValueError(
            "pointer return materialization requires a C++ pointer return "
            f"type, got {cxx_return_type}"
        )

    elem_numba_ty = cxx_return_type.dtype
    elem_ir_ty = context.get_value_type(elem_numba_ty)
    elem_align = _get_alloca_alignment(context, elem_ir_ty, elem_numba_ty)

    # The return slot holds the pointer itself; load it before indexing.
    base_ptr = builder.load(retval_ptr, align=retval_align)

    index_ty = ir.IntType(32)
    # Each element is emitted as a GEP followed by a load, in index order.
    elements = [
        builder.load(
            builder.gep(
                base_ptr, [ir.Constant(index_ty, position)], inbounds=True
            ),
            align=elem_align,
        )
        for position in range(plan.length)
    ]

    tuple_ty = types.UniTuple(elem_numba_ty, plan.length)
    return context.make_tuple(builder, tuple_ty, elements)

def _lower_impl(self, builder, context, sig, args):
# Numba-visible return type may differ from the underlying C++ return type
Expand Down Expand Up @@ -125,6 +166,13 @@ def _lower_impl(self, builder, context, sig, args):
if self._cxx_return_type is not None
else sig.return_type
)
if (
self._return_materialization is not None
and cxx_return_type == types.void
):
raise ValueError(
"pointer return materialization cannot be used with void returns"
)
# 1. Prepare return value pointer
if cxx_return_type == types.void:
# Void return type in C++ is shimmed as int& ignored
Expand Down Expand Up @@ -248,18 +296,29 @@ def _lower_impl(self, builder, context, sig, args):
builder.call(fn, (retval_ptr, *ptrs))

# 5. Return
materialized_return = None
if self._return_materialization is not None:
materialized_return = self._materialize_pointer_return(
builder, context, cxx_return_type, retval_ptr, retval_align
)

if (
self._intent_plan is None
or not self._intent_plan.out_return_indices
):
if cxx_return_type == types.void:
return None
if materialized_return is not None:
return materialized_return
return builder.load(retval_ptr, align=retval_align)

# out_return enabled: return either a value or a tuple (ret, out1, out2, ...)
ret_vals: list[ir.Value] = []
if cxx_return_type != types.void:
ret_vals.append(builder.load(retval_ptr, align=retval_align))
if materialized_return is not None:
ret_vals.append(materialized_return)
else:
ret_vals.append(builder.load(retval_ptr, align=retval_align))
for out_return in out_return_ptrs:
ret_vals.append(
builder.load(out_return.ptr, align=out_return.align)
Expand Down
53 changes: 49 additions & 4 deletions numbast/src/numbast/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@

from numbast.types import to_numba_type, to_numba_arg_type
from numbast.intent import compute_intent_plan
from numbast.return_materialization import (
PointerReturnMaterialization,
parse_return_materialization,
)
from numbast.utils import (
deduplicate_overloads,
make_function_shim,
Expand Down Expand Up @@ -42,6 +46,22 @@ def func():
func_obj_registry: dict[str, object] = defaultdict(make_new_func_obj)


def _visible_cxx_return_type(
    cxx_return_type: nbtypes.Type,
    return_materialization: PointerReturnMaterialization | None,
) -> nbtypes.Type:
    """Return the return type exposed in the Numba signature.

    Without a materialization plan the C++ return type is passed through
    unchanged. With a plan, the C++ return type must be a ``CPointer`` and
    the visible type becomes a ``UniTuple`` of the pointee type holding
    ``return_materialization.length`` elements.

    Raises:
        ValueError: If a plan is supplied but ``cxx_return_type`` is not a
            ``CPointer``.
    """
    if return_materialization is not None:
        if isinstance(cxx_return_type, nbtypes.CPointer):
            return nbtypes.UniTuple(
                cxx_return_type.dtype, return_materialization.length
            )
        raise ValueError(
            "pointer return materialization requires a C++ pointer return "
            f"type, got {cxx_return_type}"
        )
    return cxx_return_type


def bind_cxx_operator_overload_function(
shim_writer: ShimWriter,
func_decl: Function,
Expand Down Expand Up @@ -129,6 +149,7 @@ def bind_cxx_non_operator_function(
exclude: set[str],
*,
arg_intent: dict | None = None,
return_materializations: dict | None = None,
) -> object:
"""
Create a Python-callable binding for a C++ non-operator function.
Expand All @@ -147,6 +168,8 @@ def bind_cxx_non_operator_function(
Set of function names to exclude from binding.
arg_intent : dict | None, optional
Optional per-function intent overrides that specify visibility and in/out semantics for reference parameters.
return_materializations : dict | None, optional
Optional per-function return materialization specs for borrowed fixed-size pointer returns.

Returns
-------
Expand All @@ -164,11 +187,19 @@ def bind_cxx_non_operator_function(
cxx_return_type = to_numba_type(
func_decl.return_type.unqualified_non_ref_type_name
)
return_materialization = parse_return_materialization(
return_materializations.get(func_decl.name)
if return_materializations
else None
)
visible_cxx_return_type = _visible_cxx_return_type(
cxx_return_type, return_materialization
)

overrides = arg_intent.get(func_decl.name) if arg_intent else None
if overrides is None:
# Backward-compatible default: refs are input-only values.
return_type = cxx_return_type
return_type = visible_cxx_return_type
param_types = [to_numba_arg_type(arg) for arg in func_decl.param_types]
arg_is_ref = [
bool(t.is_left_reference() or t.is_right_reference())
Expand Down Expand Up @@ -211,10 +242,10 @@ def bind_cxx_non_operator_function(
return_type = nbtypes.Tuple(tuple(out_return_types))
else:
return_type = nbtypes.Tuple(
tuple([cxx_return_type, *out_return_types])
tuple([visible_cxx_return_type, *out_return_types])
)
else:
return_type = cxx_return_type
return_type = visible_cxx_return_type

# In intentful mode, pass-through pointers are controlled by intent_plan,
# not by whether the C++ parameter is a reference.
Expand All @@ -235,7 +266,11 @@ class func_typing(ConcreteTemplate):

register_global(func, nbtypes.Function(func_typing))

return_type_name = func_decl.return_type.unqualified_non_ref_type_name
return_type_name = (
func_decl.return_type.name
if return_materialization is not None
else func_decl.return_type.unqualified_non_ref_type_name
)
mangled_name = deduplicate_overloads(func_decl.mangled_name)
shim_func_name = f"{mangled_name}_nbst"

Expand All @@ -251,6 +286,7 @@ class func_typing(ConcreteTemplate):
intent_plan=intent_plan,
out_return_types=out_return_types,
cxx_return_type=cxx_return_type,
return_materialization=return_materialization,
)

# Lowering
Expand All @@ -269,6 +305,7 @@ def bind_cxx_function(
exclude: set[str] = set(),
*,
arg_intent: dict | None = None,
return_materializations: dict | None = None,
) -> object:
"""
Create Python bindings for a C++ function.
Expand All @@ -282,6 +319,8 @@ def bind_cxx_function(
arg_intent (dict | None): Optional explicit intent overrides that control which C++ reference
parameters are exposed as inputs, outputs, or inout pointers and which parameters are
promoted to out-returns.
return_materializations (dict | None): Optional per-function return
materialization specs for borrowed fixed-size pointer returns.

Returns:
object or None: The Numba-CUDA-callable Python binding object for the function, or `None`
Expand All @@ -307,6 +346,7 @@ def bind_cxx_function(
skip_prefix=skip_prefix,
exclude=exclude,
arg_intent=arg_intent,
return_materializations=return_materializations,
)

return None
Expand All @@ -320,6 +360,7 @@ def bind_cxx_functions(
exclude: set[str] = set(),
*,
arg_intent: dict | None = None,
return_materializations: dict | None = None,
) -> list[object]:
"""Create bindings for a list of C++ functions.

Expand All @@ -339,6 +380,9 @@ def bind_cxx_functions(

exclude : set[str]
A set of function names to exclude. Default to empty set.
return_materializations : dict | None
Optional per-function return materialization specs for borrowed fixed-size
pointer returns.

Returns
-------
Expand All @@ -355,6 +399,7 @@ def bind_cxx_functions(
skip_non_device=skip_non_device,
exclude=exclude,
arg_intent=arg_intent,
return_materializations=return_materializations,
)
# overloaded operator (e.g. "+") do not need to have a separate API
# as they are called directly from the Python operator.
Expand Down
Loading
Loading