69 changes: 69 additions & 0 deletions docs/source/argument_intents.rst
@@ -24,6 +24,10 @@ C++ source of truth
__device__ bool stats_update_and_get_zscore(
RunningStats &state, float x, float &zscore_out);

__device__ void stats_get_matrix_3x4(float out[3][4]);

__device__ void stats_get_vectors(float4 out[3]);

Example config
--------------

@@ -38,6 +42,34 @@ Example config
     stats_update_and_get_zscore:
       state: inout_ptr
       zscore_out: out_return
     stats_get_matrix_3x4:
       out:
         intent: out_array_return
         dtype: float
         length: 12
     stats_get_vectors:
       out:
         intent: out_array_return
         dtype: float4
         length: 3

Programmatic API
----------------

.. code-block:: python

   from numba.cuda.types import float32
   from numbast import bind_cxx_functions, out_array_return

   bindings = bind_cxx_functions(
       shim_writer,
       funcs,
       arg_intent={
           "stats_get_matrix_3x4": {
               "out": out_array_return(dtype=float32, length=12),
           },
       },
   )

Intent semantics
----------------
@@ -70,6 +102,28 @@ Intent semantics
- Parameter is removed from the visible Python call arguments.
- Numbast allocates temporary storage, passes it to C++, then returns the value to Python.
- If C++ also returns a non-``void`` value, generated return type is packed as a tuple.
- Use ``out_return`` for a single scalar or struct-like output object, such as
``float &out`` or ``MyType &out``. It loads one value from the hidden storage
after the call.
- ``out_return`` does not describe fixed-size buffer shape. Use
``out_array_return`` when the hidden output is an array/buffer and Numbast
must know the element ``dtype`` and fixed ``length`` to build a
``UniTuple``.

``out_array_return``
^^^^^^^^^^^^^^^^^^^^

- Pointer or fixed-size array output parameter is removed from the visible Python call arguments.
- Numbast allocates fixed-size native stack storage, passes the raw pointer to
C++ through the shim, loads each element after the call, and returns a fixed
``UniTuple``.
- ``dtype`` is the element type and ``length`` is the number of elements to load.
- Multidimensional native arrays are returned as flat tuples. For example,
``float out[3][4]`` uses ``length: 12`` and row-major indexing
``value[row * 4 + col]``.
- Static configs use C++ or registered type names such as ``float`` or
``float4``. Programmatic bindings can use Numba types such as ``float32`` or
registered C++ type names.
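
For readers consuming the flat tuple, the row-major convention in the bullets above can be sketched in plain Python. This is only an illustration: `flat_index`, `unflatten`, and the sample data are hypothetical helpers, not part of Numbast, and the tuple merely stands in for what a binding of `stats_get_matrix_3x4` would return.

```python
ROWS, COLS = 3, 4  # shape of the logical matrix ``float out[3][4]``


def flat_index(row, col, cols=COLS):
    """Row-major offset of (row, col) inside the flattened tuple."""
    return row * cols + col


def unflatten(values, rows=ROWS, cols=COLS):
    """Rebuild a tuple of rows from a flat, row-major sequence."""
    if len(values) != rows * cols:
        raise ValueError(f"expected {rows * cols} values, got {len(values)}")
    return tuple(
        tuple(values[r * cols:(r + 1) * cols]) for r in range(rows)
    )


# Hypothetical stand-in for the 12-element tuple returned by the binding.
flat = tuple(float(i) for i in range(ROWS * COLS))
matrix = unflatten(flat)
```

Keeping the reshaping on the caller side matches the documented behaviour that the binding itself returns a one-dimensional tuple.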

Generated Python signatures
---------------------------
@@ -94,11 +148,20 @@ Representative signatures for the example API:
float32,
)

# out_array_return:
signature(UniTuple(float32, 12)) # logical 3x4 matrix, flattened
signature(UniTuple(float32x4, 3))

Notes
-----

- ``inout_ptr``, ``out_ptr``, and ``out_return`` are only supported on C++
reference parameters (``T&`` / ``T&&``).
- ``out_array_return`` is supported on pointer/array output parameters such as
``float *out``, ``float out[12]``, ``float out[3][4]``, and ``float4 out[3]``.
- ``out_array_return`` returns a one-dimensional ``UniTuple``. For logical
multidimensional outputs, use the total element count as ``length`` and
flatten the indexing convention in the binding documentation.
- In ``Function Argument Intents``, parameter overrides can be keyed by
parameter name or 0-based parameter index.

Expand All @@ -112,3 +175,9 @@ Notes
returns that value directly (not a tuple).
- If C++ returns ``void`` and there are multiple ``out_return`` parameters,
Numbast returns ``types.Tuple((out1, out2, ...))``.
- ``out_array_return`` values participate in the same return packing rules as
``out_return``. A single ``void`` function output returns the ``UniTuple``
directly; multiple outputs or a non-``void`` C++ return are packed in an
outer heterogeneous ``types.Tuple``. For example, a function returning
``int`` with one ``out_array_return(dtype=float32, length=12)`` has return
type ``types.Tuple((int32, UniTuple(float32, 12)))``.
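
The return packing rules in the notes above can be summarized with a small sketch. It is an assumption-laden model, not Numbast's implementation: type names are plain strings standing in for Numba types, and `uni_tuple` and `pack_return_type` are hypothetical helpers introduced only for illustration.

```python
def uni_tuple(dtype, length):
    """String stand-in for a homogeneous fixed-length tuple type."""
    return f"UniTuple({dtype}, {length})"


def pack_return_type(cxx_return, out_types):
    """Model the documented packing rules.

    cxx_return: type name of the C++ return value, or None for ``void``.
    out_types:  type names of hidden outputs (``out_return`` or
                ``out_array_return``), in parameter order.
    """
    parts = ([] if cxx_return is None else [cxx_return]) + list(out_types)
    if not parts:
        return "void"
    if len(parts) == 1:
        # A single value is returned directly, even if it is a UniTuple.
        return parts[0]
    # Several values are packed in an outer heterogeneous tuple.
    return "Tuple((" + ", ".join(parts) + "))"


single = pack_return_type(None, [uni_tuple("float32", 12)])
mixed = pack_return_type("int32", [uni_tuple("float32", 12)])
```

Under these assumptions `single` is the bare `UniTuple`, while `mixed` reproduces the `int` plus `out_array_return` example from the note.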
22 changes: 22 additions & 0 deletions docs/source/generated/static_binding_schema_reference.rst
@@ -276,6 +276,12 @@ Optional keys
     my_function:
       result: out_ptr
       0: in
   Function Argument Intents:
     get_matrix:
       out:
         intent: out_array_return
         dtype: float
         length: 12
Comment on lines +279 to +284

Collaborator Author

The reason for this design is that ast_canopy does not provide fixed-size array return information.



Optional nested keys
@@ -603,6 +609,11 @@ Raw schema
  - my_function:
      result: out_ptr
      0: in
  - get_matrix:
      out:
        intent: out_array_return
        dtype: float
        length: 12
additionalProperties:
  type: object
  additionalProperties:
@@ -615,3 +626,14 @@ Raw schema
            type: string
            enum: ["in", "inout_ptr", "out_ptr", "out_return"]
        required: ["intent"]
      - type: object
        properties:
          intent:
            type: string
            const: out_array_return
          dtype:
            type: string
          length:
            type: integer
            minimum: 1
        required: ["intent", "dtype", "length"]
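
As a rough illustration of what the schema fragment above accepts, the following sketch checks a single parameter entry by hand. It is not the validator Numbast uses. `validate_intent_entry` is a hypothetical helper, and the bare-string branch is an assumption based on the `result: out_ptr` examples, since the schema clause for that form is not visible in this hunk.

```python
SCALAR_INTENTS = {"in", "inout_ptr", "out_ptr", "out_return"}


def validate_intent_entry(entry):
    """Return the intent name if ``entry`` looks valid, else raise ValueError."""
    # Assumed shorthand: a bare string naming a scalar intent.
    if isinstance(entry, str):
        if entry not in SCALAR_INTENTS:
            raise ValueError(f"unknown or incomplete intent: {entry!r}")
        return entry

    if not isinstance(entry, dict) or "intent" not in entry:
        raise ValueError("entry must be a string or a mapping with 'intent'")

    intent = entry["intent"]
    if intent in SCALAR_INTENTS:
        return intent
    if intent != "out_array_return":
        raise ValueError(f"unknown intent: {intent!r}")

    # out_array_return requires both dtype and length.
    missing = {"dtype", "length"} - entry.keys()
    if missing:
        raise ValueError(f"out_array_return is missing {sorted(missing)}")
    if not isinstance(entry["dtype"], str):
        raise ValueError("dtype must be a string type name")
    length = entry["length"]
    # bool is a subclass of int in Python, but is not a schema integer.
    if isinstance(length, bool) or not isinstance(length, int) or length < 1:
        raise ValueError("length must be an integer >= 1")
    return intent
```

A real deployment would more likely hand the raw schema to a JSON Schema validator; the manual version only makes the required keys and the `minimum: 1` constraint explicit.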
2 changes: 2 additions & 0 deletions numbast/src/numbast/__init__.py
@@ -16,6 +16,7 @@
bind_cxx_function_templates,
)
from numbast.enum import bind_cxx_enum, bind_cxx_enums
from numbast.intent_defs import out_array_return
from numbast.shim_writer import MemoryShimWriter, FileShimWriter

import importlib.metadata
@@ -40,6 +41,7 @@
"bind_cxx_class_template",
"bind_cxx_class_templates",
"clear_concrete_type_caches",
"out_array_return",
"MemoryShimWriter",
"FileShimWriter",
]
145 changes: 125 additions & 20 deletions numbast/src/numbast/callconv.py
@@ -17,6 +17,20 @@ class _OutReturnPtr(NamedTuple):
     numba_ty: types.Type
     ptr: ir.Value
     align: int
+    array_length: int | None = None
+    array_is_aggregate: bool = False
+
+
+def _get_out_array_return_specs(plan):
+    specs = getattr(plan, "out_array_return_specs", ())
+    if not specs:
+        return (None,) * len(plan.intents)
+    if len(specs) != len(plan.intents):
+        raise ValueError(
+            "IntentPlan out_array_return_specs length does not match intents: "
+            f"{len(specs)} != {len(plan.intents)}"
+        )
+    return tuple(specs)
Comment on lines +24 to +33

Contributor

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Fail fast when out_array_return intent lacks a spec.

Line 26 currently treats a missing `out_array_return_specs` as all-`None`. If an `IntentPlan` includes an `out_array_return` intent, lowering can then silently take the scalar out-return path and generate mismatched ABI/lowering artifacts. Please raise immediately when an `out_array_return` intent has no corresponding spec.

Suggested guard
 def _get_out_array_return_specs(plan):
     specs = getattr(plan, "out_array_return_specs", ())
-    if not specs:
-        return (None,) * len(plan.intents)
+    if not specs:
+        if any(getattr(i, "value", i) == "out_array_return" for i in plan.intents):
+            raise ValueError(
+                "IntentPlan contains out_array_return intent but out_array_return_specs is missing"
+            )
+        return (None,) * len(plan.intents)
     if len(specs) != len(plan.intents):
         raise ValueError(
             "IntentPlan out_array_return_specs length does not match intents: "
             f"{len(specs)} != {len(plan.intents)}"
         )
+    for idx, (intent, spec) in enumerate(zip(plan.intents, specs)):
+        if getattr(intent, "value", intent) == "out_array_return" and spec is None:
+            raise ValueError(
+                f"Missing out_array_return spec for intent index {idx}"
+            )
     return tuple(specs)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@numbast/src/numbast/callconv.py` around lines 24 - 33, The current
_get_out_array_return_specs returns all-None when out_array_return_specs is
missing, which can mask cases where an IntentPlan contains an out_array_return
intent; change it to fail fast: in _get_out_array_return_specs, if
out_array_return_specs is falsy, scan plan.intents for any intent that
represents the out_array_return (match the enum/constant used in your code,
e.g., Intent.out_array_return or intent.name == "out_array_return"); if any such
intent exists, raise a ValueError indicating a missing out_array_return spec,
otherwise return the tuple of Nones as before; keep the existing length check
(len(specs) != len(plan.intents)) and final return tuple(specs).
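
The fail-fast behaviour requested above can be exercised in isolation with a minimal sketch. Everything here is a stand-in: `ArraySpec` and `Plan` are hypothetical substitutes for the real spec and `IntentPlan` classes, intents are modelled as plain strings (the real code may use an enum, which the `getattr(..., "value", ...)` lookup also tolerates), and the function is a simplified variant of the suggested guard rather than the merged code.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ArraySpec:
    """Hypothetical stand-in for an out_array_return spec."""
    dtype: str
    length: int


@dataclass(frozen=True)
class Plan:
    """Hypothetical stand-in for IntentPlan; intents are plain strings."""
    intents: Tuple[str, ...]
    out_array_return_specs: Tuple[Optional[ArraySpec], ...] = ()


def get_out_array_return_specs(plan):
    specs = tuple(getattr(plan, "out_array_return_specs", ()) or ())
    if not specs:
        # No specs supplied: treat every position as "no array spec".
        specs = (None,) * len(plan.intents)
    elif len(specs) != len(plan.intents):
        raise ValueError(
            "out_array_return_specs length does not match intents: "
            f"{len(specs)} != {len(plan.intents)}"
        )
    # Fail fast if an array intent would fall through to the scalar path.
    for idx, (intent, spec) in enumerate(zip(plan.intents, specs)):
        name = getattr(intent, "value", intent)
        if name == "out_array_return" and spec is None:
            raise ValueError(
                f"Missing out_array_return spec for intent index {idx}"
            )
    return specs
```

Folding the "no specs at all" case into the per-index check keeps a single error path, at the cost of a slightly less specific message than the two-branch suggestion.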



def _get_alloca_alignment(context, value_ty, numba_ty=None):
@@ -198,22 +212,95 @@ def _lower_impl(self, builder, context, sig, args):
             for out_pos, orig_idx in enumerate(plan.out_return_indices):
                 orig_to_out[orig_idx] = out_pos

+            out_array_specs = _get_out_array_return_specs(plan)
+
             for orig_idx in range(n_orig):
                 out_pos = orig_to_out[orig_idx]
                 if out_pos is not None:
                     out_nbty = self._out_return_types[out_pos]
-                    vty = context.get_value_type(out_nbty)
-                    ptr = cgutils.alloca_once(builder, vty)
-                    ptr_align = _set_alloca_alignment(
-                        ptr, context, vty, out_nbty
-                    )
-                    ptrs.append(ptr)
-                    arg_pointer_types.append(ir.PointerType(vty))
-                    out_return_ptrs.append(
-                        _OutReturnPtr(
-                            numba_ty=out_nbty, ptr=ptr, align=ptr_align
+                    array_spec = out_array_specs[orig_idx]
+                    if array_spec is not None:
+                        elem_nbty = array_spec.dtype
+                        elem_vty = context.get_value_type(elem_nbty)
+                        length = int(array_spec.length)
+                        if array_spec.shim_arg_indirect:
+                            # Native pointer output parameters are shimmed as
+                            # pointer-to-pointer arguments. Store the raw stack
+                            # buffer pointer in a pointer slot, then pass that
+                            # slot to the shim; the shim dereferences once and
+                            # the native function receives the raw pointer.
+                            storage_ptr = cgutils.alloca_once(
+                                builder,
+                                elem_vty,
+                                size=length,
+                                name="out_array_return",
+                            )
+                            elem_align = _set_alloca_alignment(
+                                storage_ptr, context, elem_vty, elem_nbty
+                            )
+                            ptr_slot_ty = ir.PointerType(elem_vty)
+                            ptr_slot = cgutils.alloca_once(
+                                builder,
+                                ptr_slot_ty,
+                                name="out_array_return_ptr",
+                            )
+                            ptr_slot_align = _set_alloca_alignment(
+                                ptr_slot, context, ptr_slot_ty
+                            )
+                            builder.store(
+                                storage_ptr,
+                                ptr_slot,
+                                align=ptr_slot_align,
+                            )
+                            ptrs.append(ptr_slot)
+                            arg_pointer_types.append(
+                                ir.PointerType(ptr_slot_ty)
+                            )
+                            out_return_ptrs.append(
+                                _OutReturnPtr(
+                                    numba_ty=out_nbty,
+                                    ptr=storage_ptr,
+                                    align=elem_align,
+                                    array_length=length,
+                                )
+                            )
+                        else:
+                            array_vty = ir.ArrayType(elem_vty, length)
+                            storage_ptr = cgutils.alloca_once(
+                                builder,
+                                array_vty,
+                                name="out_array_return",
+                            )
+                            elem_align = _get_alloca_alignment(
+                                context, elem_vty, elem_nbty
+                            )
+                            _set_alloca_alignment(
+                                storage_ptr, context, array_vty
+                            )
+                            ptrs.append(storage_ptr)
+                            arg_pointer_types.append(ir.PointerType(array_vty))
+                            out_return_ptrs.append(
+                                _OutReturnPtr(
+                                    numba_ty=out_nbty,
+                                    ptr=storage_ptr,
+                                    align=elem_align,
+                                    array_length=length,
+                                    array_is_aggregate=True,
+                                )
+                            )
+                    else:
+                        vty = context.get_value_type(out_nbty)
+                        ptr = cgutils.alloca_once(builder, vty)
+                        ptr_align = _set_alloca_alignment(
+                            ptr, context, vty, out_nbty
+                        )
+                        ptrs.append(ptr)
+                        arg_pointer_types.append(ir.PointerType(vty))
+                        out_return_ptrs.append(
+                            _OutReturnPtr(
+                                numba_ty=out_nbty, ptr=ptr, align=ptr_align
+                            )
                         )
-                    )
                     continue

                 vis_pos = orig_to_vis[orig_idx]
@@ -261,20 +348,38 @@ def _lower_impl(self, builder, context, sig, args):
         if cxx_return_type != types.void:
             ret_vals.append(builder.load(retval_ptr, align=retval_align))
         for out_return in out_return_ptrs:
+            if out_return.array_length is None:
+                ret_vals.append(
+                    builder.load(out_return.ptr, align=out_return.align)
+                )
+                continue
+
+            elems = []
+            for i in range(out_return.array_length):
+                idx = ir.Constant(ir.IntType(32), i)
+                if out_return.array_is_aggregate:
+                    elem_ptr = builder.gep(
+                        out_return.ptr,
+                        [ir.Constant(ir.IntType(32), 0), idx],
+                        inbounds=True,
+                    )
+                else:
+                    elem_ptr = builder.gep(out_return.ptr, [idx], inbounds=True)
+                elems.append(builder.load(elem_ptr, align=out_return.align))
             ret_vals.append(
-                builder.load(out_return.ptr, align=out_return.align)
+                context.make_tuple(builder, out_return.numba_ty, elems)
             )

-        # If Numba-visible return is a tuple, use context.make_tuple.
-        # Otherwise (void + single out), return the single out value directly.
+        # A single out-return value may itself be a UniTuple. Return it directly
+        # instead of wrapping it as another tuple.
+        if len(ret_vals) == 1:
+            return ret_vals[0]
         if hasattr(sig.return_type, "types"):
             return context.make_tuple(builder, sig.return_type, ret_vals)
-        if len(ret_vals) != 1:
-            raise ValueError(
-                "Non-tuple return type requires exactly one return value; "
-                f"got {len(ret_vals)}"
-            )
-        return ret_vals[0]
+        raise ValueError(
+            "Multiple return values require a tuple return type; "
+            f"got {len(ret_vals)}"
+        )


 # NBST:END_CALLCONV