diff --git a/docs/source/argument_intents.rst b/docs/source/argument_intents.rst index 6fa8ce4e..6fb9c2c4 100644 --- a/docs/source/argument_intents.rst +++ b/docs/source/argument_intents.rst @@ -24,6 +24,10 @@ C++ source of truth __device__ bool stats_update_and_get_zscore( RunningStats &state, float x, float &zscore_out); + __device__ void stats_get_matrix_3x4(float out[3][4]); + + __device__ void stats_get_vectors(float4 out[3]); + Example config -------------- @@ -38,6 +42,34 @@ Example config stats_update_and_get_zscore: state: inout_ptr zscore_out: out_return + stats_get_matrix_3x4: + out: + intent: out_array_return + dtype: float + length: 12 + stats_get_vectors: + out: + intent: out_array_return + dtype: float4 + length: 3 + +Programmatic API +---------------- + +.. code-block:: python + + from numba.cuda.types import float32 + from numbast import bind_cxx_functions, out_array_return + + bindings = bind_cxx_functions( + shim_writer, + funcs, + arg_intent={ + "stats_get_matrix_3x4": { + "out": out_array_return(dtype=float32, length=12), + }, + }, + ) Intent semantics ---------------- @@ -70,6 +102,28 @@ Intent semantics - Parameter is removed from the visible Python call arguments. - Numbast allocates temporary storage, passes it to C++, then returns the value to Python. - If C++ also returns a non-``void`` value, generated return type is packed as a tuple. +- Use ``out_return`` for a single scalar or struct-like output object, such as + ``float &out`` or ``MyType &out``. It loads one value from the hidden storage + after the call. +- ``out_return`` does not describe fixed-size buffer shape. Use + ``out_array_return`` when the hidden output is an array/buffer and Numbast + must know the element ``dtype`` and fixed ``length`` to build a + ``UniTuple``. + +``out_array_return`` +^^^^^^^^^^^^^^^^^^^^ + +- Pointer or fixed-size array output parameter is removed from the visible Python call arguments. 
+- Numbast allocates fixed-size native stack storage, passes the raw pointer to + C++ through the shim, loads each element after the call, and returns a fixed + ``UniTuple``. +- ``dtype`` is the element type and ``length`` is the number of elements to load. +- Multidimensional native arrays are returned as flat tuples. For example, + ``float out[3][4]`` uses ``length: 12`` and row-major indexing + ``value[row * 4 + col]``. +- Static configs use C++ or registered type names such as ``float`` or + ``float4``. Programmatic bindings can use Numba types such as ``float32`` or + registered C++ type names. Generated Python signatures --------------------------- @@ -94,11 +148,20 @@ Representative signatures for the example API: float32, ) + # out_array_return: + signature(UniTuple(float32, 12)) # logical 3x4 matrix, flattened + signature(UniTuple(float32x4, 3)) + Notes ----- - ``inout_ptr``, ``out_ptr``, and ``out_return`` are only supported on C++ reference parameters (``T&`` / ``T&&``). +- ``out_array_return`` is supported on pointer/array output parameters such as + ``float *out``, ``float out[12]``, ``float out[3][4]``, and ``float4 out[3]``. +- ``out_array_return`` returns a one-dimensional ``UniTuple``. For logical + multidimensional outputs, use the total element count as ``length`` and + document the flattened indexing convention in the binding documentation. - In ``Function Argument Intents``, parameter overrides can be keyed by parameter name or 0-based parameter index. @@ -112,3 +175,9 @@ Notes returns that value directly (not a tuple). - If C++ returns ``void`` and there are multiple ``out_return`` parameters, Numbast returns ``types.Tuple((out1, out2, ...))``. +- ``out_array_return`` values participate in the same return packing rules as + ``out_return``. A single ``void`` function output returns the ``UniTuple`` + directly; multiple outputs or a non-``void`` C++ return are packed in an + outer heterogeneous ``types.Tuple``. 
For example, a function returning + ``int`` with one ``out_array_return(dtype=float32, length=12)`` has return + type ``types.Tuple((int32, UniTuple(float32, 12)))``. diff --git a/docs/source/generated/static_binding_schema_reference.rst b/docs/source/generated/static_binding_schema_reference.rst index 153dc220..d9f4c61a 100644 --- a/docs/source/generated/static_binding_schema_reference.rst +++ b/docs/source/generated/static_binding_schema_reference.rst @@ -276,6 +276,12 @@ Optional keys my_function: result: out_ptr 0: in + Function Argument Intents: + get_matrix: + out: + intent: out_array_return + dtype: float + length: 12 Optional nested keys @@ -603,6 +609,11 @@ Raw schema - my_function: result: out_ptr 0: in + - get_matrix: + out: + intent: out_array_return + dtype: float + length: 12 additionalProperties: type: object additionalProperties: @@ -615,3 +626,14 @@ Raw schema type: string enum: ["in", "inout_ptr", "out_ptr", "out_return"] required: ["intent"] + - type: object + properties: + intent: + type: string + const: out_array_return + dtype: + type: string + length: + type: integer + minimum: 1 + required: ["intent", "dtype", "length"] diff --git a/numbast/src/numbast/__init__.py b/numbast/src/numbast/__init__.py index 2e949470..8943be6f 100644 --- a/numbast/src/numbast/__init__.py +++ b/numbast/src/numbast/__init__.py @@ -16,6 +16,7 @@ bind_cxx_function_templates, ) from numbast.enum import bind_cxx_enum, bind_cxx_enums +from numbast.intent_defs import out_array_return from numbast.shim_writer import MemoryShimWriter, FileShimWriter import importlib.metadata @@ -40,6 +41,7 @@ "bind_cxx_class_template", "bind_cxx_class_templates", "clear_concrete_type_caches", + "out_array_return", "MemoryShimWriter", "FileShimWriter", ] diff --git a/numbast/src/numbast/callconv.py b/numbast/src/numbast/callconv.py index 04f9bb19..cb7f7c28 100644 --- a/numbast/src/numbast/callconv.py +++ b/numbast/src/numbast/callconv.py @@ -17,6 +17,20 @@ class _OutReturnPtr(NamedTuple): 
numba_ty: types.Type ptr: ir.Value align: int + array_length: int | None = None + array_is_aggregate: bool = False + + +def _get_out_array_return_specs(plan): + specs = getattr(plan, "out_array_return_specs", ()) + if not specs: + return (None,) * len(plan.intents) + if len(specs) != len(plan.intents): + raise ValueError( + "IntentPlan out_array_return_specs length does not match intents: " + f"{len(specs)} != {len(plan.intents)}" + ) + return tuple(specs) def _get_alloca_alignment(context, value_ty, numba_ty=None): @@ -198,22 +212,95 @@ def _lower_impl(self, builder, context, sig, args): for out_pos, orig_idx in enumerate(plan.out_return_indices): orig_to_out[orig_idx] = out_pos + out_array_specs = _get_out_array_return_specs(plan) + for orig_idx in range(n_orig): out_pos = orig_to_out[orig_idx] if out_pos is not None: out_nbty = self._out_return_types[out_pos] - vty = context.get_value_type(out_nbty) - ptr = cgutils.alloca_once(builder, vty) - ptr_align = _set_alloca_alignment( - ptr, context, vty, out_nbty - ) - ptrs.append(ptr) - arg_pointer_types.append(ir.PointerType(vty)) - out_return_ptrs.append( - _OutReturnPtr( - numba_ty=out_nbty, ptr=ptr, align=ptr_align + array_spec = out_array_specs[orig_idx] + if array_spec is not None: + elem_nbty = array_spec.dtype + elem_vty = context.get_value_type(elem_nbty) + length = int(array_spec.length) + if array_spec.shim_arg_indirect: + # Native pointer output parameters are shimmed as + # pointer-to-pointer arguments. Store the raw stack + # buffer pointer in a pointer slot, then pass that + # slot to the shim; the shim dereferences once and + # the native function receives the raw pointer. 
+ storage_ptr = cgutils.alloca_once( + builder, + elem_vty, + size=length, + name="out_array_return", + ) + elem_align = _set_alloca_alignment( + storage_ptr, context, elem_vty, elem_nbty + ) + ptr_slot_ty = ir.PointerType(elem_vty) + ptr_slot = cgutils.alloca_once( + builder, + ptr_slot_ty, + name="out_array_return_ptr", + ) + ptr_slot_align = _set_alloca_alignment( + ptr_slot, context, ptr_slot_ty + ) + builder.store( + storage_ptr, + ptr_slot, + align=ptr_slot_align, + ) + ptrs.append(ptr_slot) + arg_pointer_types.append( + ir.PointerType(ptr_slot_ty) + ) + out_return_ptrs.append( + _OutReturnPtr( + numba_ty=out_nbty, + ptr=storage_ptr, + align=elem_align, + array_length=length, + ) + ) + else: + array_vty = ir.ArrayType(elem_vty, length) + storage_ptr = cgutils.alloca_once( + builder, + array_vty, + name="out_array_return", + ) + elem_align = _get_alloca_alignment( + context, elem_vty, elem_nbty + ) + _set_alloca_alignment( + storage_ptr, context, array_vty + ) + ptrs.append(storage_ptr) + arg_pointer_types.append(ir.PointerType(array_vty)) + out_return_ptrs.append( + _OutReturnPtr( + numba_ty=out_nbty, + ptr=storage_ptr, + align=elem_align, + array_length=length, + array_is_aggregate=True, + ) + ) + else: + vty = context.get_value_type(out_nbty) + ptr = cgutils.alloca_once(builder, vty) + ptr_align = _set_alloca_alignment( + ptr, context, vty, out_nbty + ) + ptrs.append(ptr) + arg_pointer_types.append(ir.PointerType(vty)) + out_return_ptrs.append( + _OutReturnPtr( + numba_ty=out_nbty, ptr=ptr, align=ptr_align + ) ) - ) continue vis_pos = orig_to_vis[orig_idx] @@ -261,20 +348,38 @@ def _lower_impl(self, builder, context, sig, args): if cxx_return_type != types.void: ret_vals.append(builder.load(retval_ptr, align=retval_align)) for out_return in out_return_ptrs: + if out_return.array_length is None: + ret_vals.append( + builder.load(out_return.ptr, align=out_return.align) + ) + continue + + elems = [] + for i in range(out_return.array_length): + idx = 
ir.Constant(ir.IntType(32), i) + if out_return.array_is_aggregate: + elem_ptr = builder.gep( + out_return.ptr, + [ir.Constant(ir.IntType(32), 0), idx], + inbounds=True, + ) + else: + elem_ptr = builder.gep(out_return.ptr, [idx], inbounds=True) + elems.append(builder.load(elem_ptr, align=out_return.align)) ret_vals.append( - builder.load(out_return.ptr, align=out_return.align) + context.make_tuple(builder, out_return.numba_ty, elems) ) - # If Numba-visible return is a tuple, use context.make_tuple. - # Otherwise (void + single out), return the single out value directly. + # A single out-return value may itself be a UniTuple. Return it directly + # instead of wrapping it as another tuple. + if len(ret_vals) == 1: + return ret_vals[0] if hasattr(sig.return_type, "types"): return context.make_tuple(builder, sig.return_type, ret_vals) - if len(ret_vals) != 1: - raise ValueError( - "Non-tuple return type requires exactly one return value; " - f"got {len(ret_vals)}" - ) - return ret_vals[0] + raise ValueError( + "Multiple return values require a tuple return type; " + f"got {len(ret_vals)}" + ) # NBST:END_CALLCONV diff --git a/numbast/src/numbast/class_template.py b/numbast/src/numbast/class_template.py index abd1dd1e..d83c93ca 100644 --- a/numbast/src/numbast/class_template.py +++ b/numbast/src/numbast/class_template.py @@ -53,7 +53,15 @@ to_c_type_str, to_numba_arg_type, ) -from numbast.intent import ArgIntent, IntentPlan, compute_intent_plan +from numbast.intent import ArgIntent, compute_intent_plan +from numbast.intent_utils import ( + compose_return_type, + get_out_array_return_specs, + normalize_out_array_return_specs, + out_return_types_for_plan, + prepend_receiver_to_intent_plan, + shim_arg_type_for_out_return, +) from numbast.utils import ( deduplicate_overloads, make_struct_ctor_shim, @@ -334,15 +342,8 @@ def bind_cxx_struct_regular_method( overrides=overrides, allow_out_return=True, ) - intent_plan = IntentPlan( - intents=(ArgIntent.in_,) + method_plan.intents, 
- visible_param_indices=(0,) - + tuple(i + 1 for i in method_plan.visible_param_indices), - out_return_indices=tuple( - i + 1 for i in method_plan.out_return_indices - ), - pass_ptr_mask=(False,) + method_plan.pass_ptr_mask, - ) + method_plan = normalize_out_array_return_specs(method_plan) + intent_plan = prepend_receiver_to_intent_plan(method_plan) param_types = [] for orig_idx in method_plan.visible_param_indices: @@ -354,24 +355,10 @@ def bind_cxx_struct_regular_method( else: param_types.append(base) - out_return_types = [ - to_numba_type( - method_decl.param_types[i].unqualified_non_ref_type_name - ) - for i in method_plan.out_return_indices - ] - if out_return_types: - if cxx_return_type == nbtypes.void: - if len(out_return_types) == 1: - return_type = out_return_types[0] - else: - return_type = nbtypes.Tuple(tuple(out_return_types)) - else: - return_type = nbtypes.Tuple( - tuple([cxx_return_type, *out_return_types]) - ) - else: - return_type = cxx_return_type + out_return_types = out_return_types_for_plan( + method_decl.param_types, method_plan + ) + return_type = compose_return_type(cxx_return_type, out_return_types) arg_is_ref = None # Lowering @@ -519,35 +506,14 @@ def generic( overrides=overrides, allow_out_return=True, ) - intent_plan = IntentPlan( - intents=(ArgIntent.in_,) + method_plan.intents, - visible_param_indices=(0,) - + tuple(i + 1 for i in method_plan.visible_param_indices), - out_return_indices=tuple( - i + 1 for i in method_plan.out_return_indices - ), - pass_ptr_mask=(False,) + method_plan.pass_ptr_mask, + method_plan = normalize_out_array_return_specs(method_plan) + intent_plan = prepend_receiver_to_intent_plan(method_plan) + out_return_types = out_return_types_for_plan( + templated_method.function.param_types, method_plan + ) + return_type = compose_return_type( + cxx_return_type, out_return_types ) - out_return_types = [ - to_numba_type( - templated_method.function.param_types[ - i - ].unqualified_non_ref_type_name - ) - for i in 
method_plan.out_return_indices - ] - if out_return_types: - if cxx_return_type == nbtypes.void: - if len(out_return_types) == 1: - return_type = out_return_types[0] - else: - return_type = nbtypes.Tuple(tuple(out_return_types)) - else: - return_type = nbtypes.Tuple( - tuple([cxx_return_type, *out_return_types]) - ) - else: - return_type = cxx_return_type lowering_key = (qualname, recvr, param_types) if lowering_key not in _TEMPLATED_METHOD_LOWERING_CACHE: @@ -592,6 +558,9 @@ def _impl( method_plan.out_return_indices ) } + out_array_specs = get_out_array_return_specs( + method_plan + ) # Reconstruct full C++ param order by merging visible # params with out_return slots, keeping a shim-aligned # pass_ptr_mask. @@ -602,7 +571,10 @@ def _impl( out_pos = out_return_map.get(orig_idx) if out_pos is not None: param_types_for_shim_list.append( - out_return_types[out_pos] + shim_arg_type_for_out_return( + out_return_types[out_pos], + out_array_specs[orig_idx], + ) ) pass_ptr_mask_for_shim_list.append(False) else: diff --git a/numbast/src/numbast/function.py b/numbast/src/numbast/function.py index 23049b14..ab108547 100644 --- a/numbast/src/numbast/function.py +++ b/numbast/src/numbast/function.py @@ -15,6 +15,11 @@ from numbast.types import to_numba_type, to_numba_arg_type from numbast.intent import compute_intent_plan +from numbast.intent_utils import ( + compose_return_type, + normalize_out_array_return_specs, + out_return_types_for_plan, +) from numbast.utils import ( deduplicate_overloads, make_function_shim, @@ -184,6 +189,7 @@ def bind_cxx_non_operator_function( overrides=overrides, allow_out_return=True, ) + intent_plan = normalize_out_array_return_specs(intent_plan) # Visible param types in original order param_types = [] @@ -196,25 +202,10 @@ def bind_cxx_non_operator_function( else: param_types.append(base) - out_return_types = [ - to_numba_type( - func_decl.param_types[i].unqualified_non_ref_type_name - ) - for i in intent_plan.out_return_indices - ] - - if 
out_return_types: - if cxx_return_type == nbtypes.void: - if len(out_return_types) == 1: - return_type = out_return_types[0] - else: - return_type = nbtypes.Tuple(tuple(out_return_types)) - else: - return_type = nbtypes.Tuple( - tuple([cxx_return_type, *out_return_types]) - ) - else: - return_type = cxx_return_type + out_return_types = out_return_types_for_plan( + func_decl.param_types, intent_plan + ) + return_type = compose_return_type(cxx_return_type, out_return_types) # In intentful mode, pass-through pointers are controlled by intent_plan, # not by whether the C++ parameter is a reference. diff --git a/numbast/src/numbast/function_template.py b/numbast/src/numbast/function_template.py index 945aa099..59d28051 100644 --- a/numbast/src/numbast/function_template.py +++ b/numbast/src/numbast/function_template.py @@ -22,8 +22,20 @@ from numbast.callconv import FunctionCallConv from numbast.deduction import deduce_templated_overloads from numbast.intent import ArgIntent, compute_intent_plan +from numbast.intent_utils import ( + compose_return_type, + get_out_array_return_specs, + normalize_out_array_return_specs, + out_return_types_for_plan, + shim_arg_type_for_out_return, +) from numbast.types import to_c_type_str, to_numba_type -from numbast.utils import deduplicate_overloads, get_return_type_strings +from numbast.utils import ( + _canonicalize_array_pointer_type, + _param_type_name_to_pointer_arg, + deduplicate_overloads, + get_return_type_strings, +) from numbast.shim_writer import ShimWriterBase from numbast.overload_selection import _select_templated_overload @@ -144,7 +156,9 @@ def _make_templated_function_shim_arg_strings( formal_default = f"{c_ty}* arg{i}" actual_default = f"*arg{i}" - cxx_ty = cxx_param.type_.unqualified_non_ref_type_name + cxx_ty = _canonicalize_array_pointer_type( + cxx_param.type_.unqualified_non_ref_type_name + ) if use_pass_ptr: if "*" in cxx_ty: actual_default = f"arg{i}" @@ -152,7 +166,12 @@ def 
_make_templated_function_shim_arg_strings( actual_default = f"*arg{i}" m = _CXX_ARRAY_TYPE_RE.match(cxx_ty) is_lref = cxx_param.type_.is_left_reference() - if m and is_lref: + if m and "*" in m.group("base") and isinstance(nb_ty, nbtypes.CPointer): + formal_parts.append( + _param_type_name_to_pointer_arg(cxx_ty, f"arg{i}") + ) + actual_parts.append(f"*arg{i}") + elif m and is_lref: if not isinstance(nb_ty, nbtypes.CPointer): raise TypingError( f"{cxx_param.name}: expected a pointer argument in Numba " @@ -264,27 +283,14 @@ def generic(self, args, kwds, overloads=overloads, overrides=overrides): overrides=overrides, allow_out_return=True, ) + func_plan = normalize_out_array_return_specs(func_plan) intent_plan = func_plan - out_return_types = [ - to_numba_type( - templated_func.function.param_types[ - i - ].unqualified_non_ref_type_name - ) - for i in func_plan.out_return_indices - ] - if out_return_types: - if cxx_return_type == nbtypes.void: - if len(out_return_types) == 1: - return_type = out_return_types[0] - else: - return_type = nbtypes.Tuple(tuple(out_return_types)) - else: - return_type = nbtypes.Tuple( - tuple([cxx_return_type, *out_return_types]) - ) - else: - return_type = cxx_return_type + out_return_types = out_return_types_for_plan( + templated_func.function.param_types, func_plan + ) + return_type = compose_return_type( + cxx_return_type, out_return_types + ) @lower(func, *param_types) def _impl( @@ -321,6 +327,7 @@ def _impl( func_plan.out_return_indices ) } + out_array_specs = get_out_array_return_specs(func_plan) visible_iter = iter(param_types_inner) visible_mask_iter = iter(func_plan.pass_ptr_mask) param_types_for_shim_list = [] @@ -329,7 +336,10 @@ def _impl( out_pos = out_return_map.get(orig_idx) if out_pos is not None: param_types_for_shim_list.append( - out_return_types[out_pos] + shim_arg_type_for_out_return( + out_return_types[out_pos], + out_array_specs[orig_idx], + ) ) pass_ptr_mask_for_shim_list.append(False) else: diff --git 
a/numbast/src/numbast/intent.py b/numbast/src/numbast/intent.py index c3db6bf2..d0cfc3d5 100644 --- a/numbast/src/numbast/intent.py +++ b/numbast/src/numbast/intent.py @@ -5,7 +5,12 @@ from typing import Any, Mapping -from numbast.intent_defs import ArgIntent, IntentPlan +from numbast.intent_defs import ( + ArgIntent, + IntentPlan, + OutArrayReturnSpec, + out_array_return, +) def _parse_arg_intent(cls, v: Any) -> ArgIntent: @@ -33,6 +38,8 @@ def _parse_arg_intent(cls, v: Any) -> ArgIntent: return ArgIntent.out_ptr if v2 == "out_return": return ArgIntent.out_return + if v2 == "out_array_return": + return ArgIntent.out_array_return raise ValueError(f"Unknown arg intent: {v!r}") @@ -60,6 +67,49 @@ def _is_ref_type(ast_type: Any) -> bool: return bool(is_ref) +def _is_pointer_type(ast_type: Any) -> bool: + type_name = getattr(ast_type, "unqualified_non_ref_type_name", "") + return "*" in str(type_name) + + +def _parse_override_value( + raw: Any, +) -> tuple[ArgIntent, OutArrayReturnSpec | None]: + if isinstance(raw, OutArrayReturnSpec): + return ArgIntent.out_array_return, raw + + if isinstance(raw, Mapping): + if "intent" not in raw: + raise ValueError( + "arg_intent object values must include an 'intent' key" + ) + intent = _parse_arg_intent(ArgIntent, raw["intent"]) + if intent == ArgIntent.out_array_return: + if "dtype" not in raw or "length" not in raw: + raise ValueError( + "out_array_return intent requires 'dtype' and 'length'" + ) + return intent, out_array_return( + dtype=raw["dtype"], length=raw["length"] + ) + return intent, None + + if not isinstance(raw, (str, ArgIntent)): + raise TypeError( + "arg_intent values must be strings, ArgIntent enums, " + "OutArrayReturnSpec objects, or intent objects" + ) + + intent = _parse_arg_intent(ArgIntent, raw) + if intent == ArgIntent.out_array_return: + raise ValueError( + "out_array_return requires dtype and length; use " + "out_array_return(dtype=..., length=...) 
or an object with " + "intent/dtype/length" + ) + return intent, None + + def compute_intent_plan( *, params: list[Any], @@ -95,14 +145,11 @@ def compute_intent_plan( ) normalized: list[ArgIntent] = [ArgIntent.in_] * len(params) + out_array_specs: list[OutArrayReturnSpec | None] = [None] * len(params) if overrides: # First apply index-based overrides, then name-based overrides so names win. for key, raw in overrides.items(): - if type(raw) not in (str, ArgIntent): - raise TypeError( - "arg_intent values must be strings or ArgIntent enums" - ) - intent = _parse_arg_intent(ArgIntent, raw) + intent, array_spec = _parse_override_value(raw) if isinstance(key, int): if key < 0 or key >= len(params): @@ -110,6 +157,7 @@ def compute_intent_plan( f"arg_intent index {key} out of range for {len(params)} params" ) normalized[key] = intent + out_array_specs[key] = array_spec elif isinstance(key, str): # Defer name lookup until after we process all keys. continue @@ -124,16 +172,14 @@ def compute_intent_plan( for key, raw in overrides.items(): if not isinstance(key, str): continue - if type(raw) not in (str, ArgIntent): - raise TypeError( - "arg_intent values must be strings or ArgIntent enums" - ) - intent = _parse_arg_intent(ArgIntent, raw) + intent, array_spec = _parse_override_value(raw) if key not in name_to_idx: raise ValueError( f"arg_intent specified unknown param name {key!r}; known params: {list(name_to_idx.keys())}" ) - normalized[name_to_idx[key]] = intent + idx = name_to_idx[key] + normalized[idx] = intent + out_array_specs[idx] = array_spec # Validation + derived plan visible_param_indices: list[int] = [] @@ -142,12 +188,25 @@ def compute_intent_plan( for i, (intent, ty) in enumerate(zip(normalized, param_types)): is_ref = _is_ref_type(ty) - if intent != ArgIntent.in_: + is_pointer = _is_pointer_type(ty) + if intent == ArgIntent.out_array_return: + spec = out_array_specs[i] + if spec is None: + raise ValueError( + f"arg_intent[{i}]='out_array_return' requires dtype 
and length" + ) + if not (is_ref or is_pointer): + raise ValueError( + f"arg_intent[{i}]='out_array_return' is only supported for pointer " + "parameters or array/reference output parameters" + ) + out_array_specs[i] = spec.with_shim_arg_indirect(is_pointer) + elif intent != ArgIntent.in_: if not is_ref: raise ValueError( f"arg_intent[{i}]={intent.value!r} is only supported for reference parameters (T&/T&&)" ) - if intent == ArgIntent.out_return: + if intent in (ArgIntent.out_return, ArgIntent.out_array_return): if not allow_out_return: raise ValueError( "out_return intent is not supported in this context" @@ -164,4 +223,5 @@ def compute_intent_plan( visible_param_indices=tuple(visible_param_indices), out_return_indices=tuple(out_return_indices), pass_ptr_mask=tuple(pass_ptr_mask), + out_array_return_specs=tuple(out_array_specs), ) diff --git a/numbast/src/numbast/intent_defs.py b/numbast/src/numbast/intent_defs.py index e68ef661..2a0782f7 100644 --- a/numbast/src/numbast/intent_defs.py +++ b/numbast/src/numbast/intent_defs.py @@ -4,8 +4,9 @@ from __future__ import annotations # NBST:BEGIN_INTENT_DEFS -from dataclasses import dataclass +from dataclasses import dataclass, replace from enum import Enum +from typing import Any class ArgIntent(str, Enum): @@ -20,12 +21,42 @@ class ArgIntent(str, Enum): (CPointer(T)) on the Numba side and passed through to the shim. - `out_return`: C++ reference parameter is *not* exposed as an argument; a temporary is allocated, passed to the shim, and then returned to the caller. + - `out_array_return`: C++ pointer/array output parameter is *not* exposed as + an argument; fixed-size stack storage is allocated, passed to the shim, and + returned to the caller as a UniTuple. """ in_ = "in" inout_ptr = "inout_ptr" out_ptr = "out_ptr" out_return = "out_return" + out_array_return = "out_array_return" + + +@dataclass(frozen=True) +class OutArrayReturnSpec: + """ + Metadata for a fixed-size output array returned as a Numba UniTuple. 
+ """ + + dtype: Any + length: int + shim_arg_indirect: bool | None = None + + def with_shim_arg_indirect(self, value: bool): + return replace(self, shim_arg_indirect=bool(value)) + + +def out_array_return(*, dtype: Any, length: int) -> OutArrayReturnSpec: + """ + Create an argument-intent spec for fixed-size native output arrays. + """ + length = int(length) + if length <= 0: + raise ValueError("out_array_return length must be positive") + if dtype is None: + raise ValueError("out_array_return dtype must be provided") + return OutArrayReturnSpec(dtype=dtype, length=length) @dataclass(frozen=True) @@ -38,6 +69,7 @@ class IntentPlan: visible_param_indices: tuple[int, ...] # subset of [0..N) out_return_indices: tuple[int, ...] # subset of [0..N) pass_ptr_mask: tuple[bool, ...] # aligned with visible params only + out_array_return_specs: tuple[OutArrayReturnSpec | None, ...] = () # NBST:END_INTENT_DEFS diff --git a/numbast/src/numbast/intent_utils.py b/numbast/src/numbast/intent_utils.py new file mode 100644 index 00000000..d7064a7f --- /dev/null +++ b/numbast/src/numbast/intent_utils.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import replace + +from numba import types as nbtypes + +from numbast.intent_defs import ArgIntent, IntentPlan, OutArrayReturnSpec +from numbast.types import to_numba_type + + +def get_out_array_return_specs( + plan: IntentPlan, +) -> tuple[OutArrayReturnSpec | None, ...]: + specs = getattr(plan, "out_array_return_specs", ()) + if not specs: + return (None,) * len(plan.intents) + if len(specs) != len(plan.intents): + raise ValueError( + "IntentPlan out_array_return_specs length does not match intents: " + f"{len(specs)} != {len(plan.intents)}" + ) + return tuple(specs) + + +def resolve_out_array_dtype(dtype) -> nbtypes.Type: + if isinstance(dtype, nbtypes.Type): + return dtype + if isinstance(dtype, str): + resolved = to_numba_type(dtype) + if not isinstance(resolved, nbtypes.Opaque): + return resolved + raise ValueError(f"Unknown out_array_return dtype: {dtype!r}") + + +def out_return_type_for_param(param_type, spec: OutArrayReturnSpec | None): + if spec is None: + return to_numba_type(param_type.unqualified_non_ref_type_name) + return nbtypes.UniTuple(resolve_out_array_dtype(spec.dtype), spec.length) + + +def out_return_types_for_plan(param_types, plan: IntentPlan): + specs = get_out_array_return_specs(plan) + return [ + out_return_type_for_param(param_types[i], specs[i]) + for i in plan.out_return_indices + ] + + +def normalize_out_array_return_specs(plan: IntentPlan) -> IntentPlan: + specs = get_out_array_return_specs(plan) + if not any(spec is not None for spec in specs): + return plan + normalized_specs = tuple( + replace(spec, dtype=resolve_out_array_dtype(spec.dtype)) + if spec is not None + else None + for spec in specs + ) + return IntentPlan( + intents=plan.intents, + visible_param_indices=plan.visible_param_indices, + out_return_indices=plan.out_return_indices, + pass_ptr_mask=plan.pass_ptr_mask, + out_array_return_specs=normalized_specs, + ) + + +def 
compose_return_type(cxx_return_type, out_return_types): + if not out_return_types: + return cxx_return_type + if cxx_return_type == nbtypes.void: + if len(out_return_types) == 1: + return out_return_types[0] + return nbtypes.Tuple(tuple(out_return_types)) + return nbtypes.Tuple(tuple([cxx_return_type, *out_return_types])) + + +def prepend_receiver_to_intent_plan(method_plan: IntentPlan) -> IntentPlan: + return IntentPlan( + intents=(ArgIntent.in_,) + method_plan.intents, + visible_param_indices=(0,) + + tuple(i + 1 for i in method_plan.visible_param_indices), + out_return_indices=tuple(i + 1 for i in method_plan.out_return_indices), + pass_ptr_mask=(False,) + method_plan.pass_ptr_mask, + out_array_return_specs=(None,) + + get_out_array_return_specs(method_plan), + ) + + +def shim_arg_type_for_out_return( + out_return_type, spec: OutArrayReturnSpec | None +): + if spec is None: + return out_return_type + return nbtypes.CPointer(resolve_out_array_dtype(spec.dtype)) diff --git a/numbast/src/numbast/static/function.py b/numbast/src/numbast/static/function.py index 3627bb99..756964c7 100644 --- a/numbast/src/numbast/static/function.py +++ b/numbast/src/numbast/static/function.py @@ -16,7 +16,11 @@ get_shim, get_callconv_utils, ) -from numbast.static.types import to_numba_type_str, to_numba_arg_type_str +from numbast.static.types import ( + to_numba_arg_type_str, + to_numba_out_array_type_str, + to_numba_type_str, +) from numbast.intent import ArgIntent, compute_intent_plan from numbast.utils import make_function_shim, _apply_prefix_removal from numbast.errors import TypeNotFoundError, MangledFunctionNameConflictError @@ -60,6 +64,64 @@ def _matches_any_regex_pattern(name: str, patterns: list[str]) -> bool: return False +def _tuple_literal(items: list[str]) -> str: + if not items: + return "()" + if len(items) == 1: + return f"({items[0]},)" + return f"({', '.join(items)})" + + +def _out_array_specs_for_plan(plan): + specs = getattr(plan, "out_array_return_specs", ()) + 
if not specs: + return (None,) * len(plan.intents) + if len(specs) != len(plan.intents): + raise ValueError( + "IntentPlan out_array_return_specs length does not match intents: " + f"{len(specs)} != {len(plan.intents)}" + ) + return tuple(specs) + + +def _out_return_type_str(param_type, spec) -> str: + if spec is None: + return to_numba_type_str(param_type.unqualified_non_ref_type_name) + return to_numba_out_array_type_str(spec.dtype, spec.length) + + +def _compose_return_type_str( + cxx_return_type_str: str, out_return_types: list[str] +): + if not out_return_types: + return cxx_return_type_str + if cxx_return_type_str == "void": + if len(out_return_types) == 1: + return out_return_types[0] + outs = ", ".join(out_return_types) + return f"types.Tuple(({outs},))" + outs = ", ".join([cxx_return_type_str, *out_return_types]) + return f"types.Tuple(({outs},))" + + +def _render_out_array_specs(plan) -> str: + specs = _out_array_specs_for_plan(plan) + rendered = [] + for spec in specs: + if spec is None: + rendered.append("None") + else: + dtype_str = to_numba_type_str(spec.dtype) + rendered.append( + "OutArrayReturnSpec(" + f"dtype={dtype_str}, " + f"length={int(spec.length)}, " + f"shim_arg_indirect={bool(spec.shim_arg_indirect)}" + ")" + ) + return _tuple_literal(rendered) + + class StaticFunctionRenderer(BaseRenderer): """Base class for function static bindings renderer. 
@@ -220,43 +282,20 @@ def __init__( ) out_return_types = [ - to_numba_type_str( - self._decl.param_types[i].unqualified_non_ref_type_name + _out_return_type_str( + self._decl.param_types[i], + _out_array_specs_for_plan(plan)[i], ) for i in plan.out_return_indices ] if out_return_types: self.Imports.add("from numba import types") - if self._cxx_return_type_str == "void": - if len(out_return_types) == 1: - self._return_numba_type_str = out_return_types[0] - else: - outs = ", ".join(out_return_types) - self._return_numba_type_str = f"types.Tuple(({outs},))" - else: - outs = ", ".join( - [self._cxx_return_type_str, *out_return_types] - ) - self._return_numba_type_str = f"types.Tuple(({outs},))" + self._return_numba_type_str = _compose_return_type_str( + self._cxx_return_type_str, out_return_types + ) else: self._return_numba_type_str = self._cxx_return_type_str - def _tuple_literal(items: list[str]) -> str: - """ - Builds a Python tuple literal from a list of string expressions. - - Parameters: - items (list[str]): String representations of tuple elements. - - Returns: - tuple_literal (str): A Python tuple literal. For an empty list returns "()"; for a single element returns "(element,)" (includes the trailing comma); otherwise returns "(elem1, elem2, ...)". 
- """ - if not items: - return "()" - if len(items) == 1: - return f"({items[0]},)" - return f"({', '.join(items)})" - intents_str = _tuple_literal( [ f"ArgIntent.{i.value if i != ArgIntent.in_ else 'in_'}" @@ -271,7 +310,8 @@ def _tuple_literal(items: list[str]) -> str: f"intents={intents_str}, " f"visible_param_indices={visible_str}, " f"out_return_indices={out_str}, " - f"pass_ptr_mask={mask_str}" + f"pass_ptr_mask={mask_str}, " + f"out_array_return_specs={_render_out_array_specs(plan)}" ")" ) if out_return_types: diff --git a/numbast/src/numbast/static/struct.py b/numbast/src/numbast/static/struct.py index d2dec854..dabd969c 100644 --- a/numbast/src/numbast/static/struct.py +++ b/numbast/src/numbast/static/struct.py @@ -20,9 +20,10 @@ get_callconv_utils, ) from numbast.static.types import ( - to_numba_type_str, - to_numba_arg_type_str, CTYPE_TO_NBTYPE_STR, + to_numba_arg_type_str, + to_numba_out_array_type_str, + to_numba_type_str, ) from numbast.intent import ArgIntent, IntentPlan, compute_intent_plan from numbast.utils import ( @@ -39,6 +40,64 @@ file_logger.addHandler(FileHandler(logger_path)) +def _tuple_literal(items: list[str]) -> str: + if not items: + return "()" + if len(items) == 1: + return f"({items[0]},)" + return f"({', '.join(items)})" + + +def _out_array_specs_for_plan(plan): + specs = getattr(plan, "out_array_return_specs", ()) + if not specs: + return (None,) * len(plan.intents) + if len(specs) != len(plan.intents): + raise ValueError( + "IntentPlan out_array_return_specs length does not match intents: " + f"{len(specs)} != {len(plan.intents)}" + ) + return tuple(specs) + + +def _out_return_type_str(param_type, spec) -> str: + if spec is None: + return to_numba_type_str(param_type.unqualified_non_ref_type_name) + return to_numba_out_array_type_str(spec.dtype, spec.length) + + +def _compose_return_type_str( + cxx_return_type_str: str, out_return_types: list[str] +): + if not out_return_types: + return cxx_return_type_str + if 
cxx_return_type_str == "void": + if len(out_return_types) == 1: + return out_return_types[0] + outs = ", ".join(out_return_types) + return f"types.Tuple(({outs},))" + outs = ", ".join([cxx_return_type_str, *out_return_types]) + return f"types.Tuple(({outs},))" + + +def _render_out_array_specs(plan) -> str: + specs = _out_array_specs_for_plan(plan) + rendered = [] + for spec in specs: + if spec is None: + rendered.append("None") + else: + dtype_str = to_numba_type_str(spec.dtype) + rendered.append( + "OutArrayReturnSpec(" + f"dtype={dtype_str}, " + f"length={int(spec.length)}, " + f"shim_arg_indirect={bool(spec.shim_arg_indirect)}" + ")" + ) + return _tuple_literal(rendered) + + class StaticStructMethodRenderer(BaseRenderer): """Base class for all struct methods TODO: merge all common code paths @@ -668,6 +727,8 @@ def __init__( i + 1 for i in method_plan.out_return_indices ), pass_ptr_mask=(False,) + method_plan.pass_ptr_mask, + out_array_return_specs=(None,) + + _out_array_specs_for_plan(method_plan), ) self._arg_is_ref = None @@ -691,44 +752,33 @@ def __init__( ) out_return_types = [ - to_numba_type_str( - self._method_decl.param_types[ - i - ].unqualified_non_ref_type_name + _out_return_type_str( + self._method_decl.param_types[i], + _out_array_specs_for_plan(method_plan)[i], ) for i in method_plan.out_return_indices ] if out_return_types: self.Imports.add("from numba import types") - if self._cxx_return_type_str == "void": - if len(out_return_types) == 1: - self._nb_return_type_str = out_return_types[0] - else: - outs = ", ".join(out_return_types) - self._nb_return_type_str = f"types.Tuple(({outs},))" - else: - outs = ", ".join( - [self._cxx_return_type_str, *out_return_types] - ) - self._nb_return_type_str = f"types.Tuple(({outs},))" + self._nb_return_type_str = _compose_return_type_str( + self._cxx_return_type_str, out_return_types + ) else: self._nb_return_type_str = self._cxx_return_type_str - intents_str = ( - "(" - + ", ".join( + intents_str = 
_tuple_literal( + [ f"ArgIntent.{i.value if i != ArgIntent.in_ else 'in_'}" for i in intent_plan.intents - ) - + ("," if len(intent_plan.intents) == 1 else "") - + ")" + ] ) self._intent_plan_rendered = ( "IntentPlan(" f"intents={intents_str}, " f"visible_param_indices={repr(intent_plan.visible_param_indices)}, " f"out_return_indices={repr(intent_plan.out_return_indices)}, " - f"pass_ptr_mask={repr(intent_plan.pass_ptr_mask)}" + f"pass_ptr_mask={repr(intent_plan.pass_ptr_mask)}, " + f"out_array_return_specs={_render_out_array_specs(intent_plan)}" ")" ) if out_return_types: diff --git a/numbast/src/numbast/static/tests/data/function_out.cuh b/numbast/src/numbast/static/tests/data/function_out.cuh index 0ba25f4c..14abf651 100644 --- a/numbast/src/numbast/static/tests/data/function_out.cuh +++ b/numbast/src/numbast/static/tests/data/function_out.cuh @@ -10,3 +10,6 @@ void __device__ add_out(int &out, int x); int __device__ add_out_ret(int &out, int x); int __device__ add_in_ref(int &x); void __device__ add_inout_ref(int &x, int delta); +void __device__ get_matrix(float out[12]); +void __device__ get_matrix_3x4(float out[3][4]); +void __device__ get_data(float4 out[3]); diff --git a/numbast/src/numbast/static/tests/data/src/function_out.cu b/numbast/src/numbast/static/tests/data/src/function_out.cu index 037f85f3..13fffe71 100644 --- a/numbast/src/numbast/static/tests/data/src/function_out.cu +++ b/numbast/src/numbast/static/tests/data/src/function_out.cu @@ -37,3 +37,23 @@ int __device__ add_in_ref(int &x) { return x + 5; } * @param delta Amount to add to `x`. 
*/ void __device__ add_inout_ref(int &x, int delta) { x += delta; } + +void __device__ get_matrix(float out[12]) { + for (int i = 0; i < 12; ++i) { + out[i] = static_cast(i) + 0.5f; + } +} + +void __device__ get_matrix_3x4(float out[3][4]) { + for (int row = 0; row < 3; ++row) { + for (int col = 0; col < 4; ++col) { + out[row][col] = static_cast(row * 4 + col) + 1.25f; + } + } +} + +void __device__ get_data(float4 out[3]) { + out[0] = make_float4(1.0f, 2.0f, 3.0f, 4.0f); + out[1] = make_float4(5.0f, 6.0f, 7.0f, 8.0f); + out[2] = make_float4(9.0f, 10.0f, 11.0f, 12.0f); +} diff --git a/numbast/src/numbast/static/tests/test_function_static_bindings.py b/numbast/src/numbast/static/tests/test_function_static_bindings.py index 418bd90a..041a956f 100644 --- a/numbast/src/numbast/static/tests/test_function_static_bindings.py +++ b/numbast/src/numbast/static/tests/test_function_static_bindings.py @@ -172,6 +172,49 @@ def test_generated_callconv_alignof_helper_is_standalone(make_binding): assert "def get_numba_type_alignof(numba_type):" in src +def test_out_array_return_static_binding_source(make_binding): + intents = { + "get_matrix": { + "out": { + "intent": "out_array_return", + "dtype": "float", + "length": 12, + } + }, + "get_matrix_3x4": { + "out": { + "intent": "out_array_return", + "dtype": "float", + "length": 12, + } + }, + "get_data": { + "out": { + "intent": "out_array_return", + "dtype": "float4", + "length": 3, + } + }, + } + res = make_binding("function_out.cuh", {}, {}, "sm_50", intents) + bindings = res["bindings"] + src = res["src"] + + assert "get_matrix" in bindings + assert "get_matrix_3x4" in bindings + assert "get_data" in bindings + assert "signature(UniTuple(float32, 12), )" in src + assert "signature(UniTuple(float32x4, 3), )" in src + assert ( + "OutArrayReturnSpec(dtype=float32, length=12, shim_arg_indirect=True)" + ) in src + assert ( + "OutArrayReturnSpec(dtype=float32x4, length=3, shim_arg_indirect=True)" + ) in src + assert "float (**out)[4]" 
in src + assert "get_matrix_3x4(*out);" in src + + def test_out_return_function_bindings(decl_out, impl_out): add_out = decl_out["add_out"] add_out_ret = decl_out["add_out_ret"] diff --git a/numbast/src/numbast/static/types.py b/numbast/src/numbast/static/types.py index e7fd6ab0..ca3a00b0 100644 --- a/numbast/src/numbast/static/types.py +++ b/numbast/src/numbast/static/types.py @@ -9,12 +9,15 @@ from numbast.errors import TypeNotFoundError -_DEFAULT_CTYPE_TO_NBTYPE_STR_MAP = { - k: str(v) for k, v in CTYPE_MAPS.items() -} | { - "bool": "bool_", - "void": "void", -} +_DEFAULT_CTYPE_TO_NBTYPE_STR_MAP = ( + {k: str(v) for k, v in CTYPE_MAPS.items()} + | {str(v): str(v) for v in CTYPE_MAPS.values()} + | { + "bool": "bool_", + "bool_": "bool_", + "void": "void", + } +) CTYPE_TO_NBTYPE_STR = copy.deepcopy(_DEFAULT_CTYPE_TO_NBTYPE_STR_MAP) @@ -80,6 +83,11 @@ def to_numba_type_str(ty: str): TypeNotFoundError: If `ty` has no known mapping to a Numba type. """ + if not isinstance(ty, str): + nb_type_str = str(ty) + BaseRenderer._try_import_numba_type(nb_type_str) + return nb_type_str + if ty == "__nv_bfloat16": BaseRenderer._try_import_numba_type("__nv_bfloat16") return "bfloat16" @@ -122,6 +130,12 @@ def to_numba_type_str(ty: str): return nb_type_str +def to_numba_out_array_type_str(dtype: str, length: int) -> str: + dtype_str = to_numba_type_str(dtype) + BaseRenderer._try_import_numba_type("UniTuple") + return f"UniTuple({dtype_str}, {int(length)})" + + def to_numba_arg_type_str(ast_type) -> str: """ Convert an AST Canopy Type to the corresponding Numba type string for use in function argument typing. 
diff --git a/numbast/src/numbast/struct.py b/numbast/src/numbast/struct.py index 7a421db8..43d68a36 100644 --- a/numbast/src/numbast/struct.py +++ b/numbast/src/numbast/struct.py @@ -23,7 +23,13 @@ from ast_canopy.decl import Struct, StructMethod from numbast.types import CTYPE_MAPS as C2N, to_numba_type, to_numba_arg_type -from numbast.intent import ArgIntent, IntentPlan, compute_intent_plan +from numbast.intent import compute_intent_plan +from numbast.intent_utils import ( + compose_return_type, + normalize_out_array_return_specs, + out_return_types_for_plan, + prepend_receiver_to_intent_plan, +) from numbast.utils import ( deduplicate_overloads, make_struct_regular_method_shim, @@ -312,15 +318,8 @@ def bind_cxx_struct_regular_method( overrides=overrides, allow_out_return=True, ) - intent_plan = IntentPlan( - intents=(ArgIntent.in_,) + method_plan.intents, - visible_param_indices=(0,) - + tuple(i + 1 for i in method_plan.visible_param_indices), - out_return_indices=tuple( - i + 1 for i in method_plan.out_return_indices - ), - pass_ptr_mask=(False,) + method_plan.pass_ptr_mask, - ) + method_plan = normalize_out_array_return_specs(method_plan) + intent_plan = prepend_receiver_to_intent_plan(method_plan) # Visible param types for @lower exclude receiver param_types = [] @@ -333,24 +332,10 @@ def bind_cxx_struct_regular_method( else: param_types.append(base) - out_return_types = [ - to_numba_type( - method_decl.param_types[i].unqualified_non_ref_type_name - ) - for i in method_plan.out_return_indices - ] - if out_return_types: - if cxx_return_type == nbtypes.void: - if len(out_return_types) == 1: - return_type = out_return_types[0] - else: - return_type = nbtypes.Tuple(tuple(out_return_types)) - else: - return_type = nbtypes.Tuple( - tuple([cxx_return_type, *out_return_types]) - ) - else: - return_type = cxx_return_type + out_return_types = out_return_types_for_plan( + method_decl.param_types, method_plan + ) + return_type = compose_return_type(cxx_return_type, 
out_return_types) arg_is_ref = None # Lowering diff --git a/numbast/src/numbast/tools/static_binding_generator.schema.yaml b/numbast/src/numbast/tools/static_binding_generator.schema.yaml index 7970bb14..5b55db4e 100644 --- a/numbast/src/numbast/tools/static_binding_generator.schema.yaml +++ b/numbast/src/numbast/tools/static_binding_generator.schema.yaml @@ -217,6 +217,11 @@ properties: - my_function: result: out_ptr 0: in + - get_matrix: + out: + intent: out_array_return + dtype: float + length: 12 additionalProperties: type: object additionalProperties: @@ -229,3 +234,14 @@ properties: type: string enum: ["in", "inout_ptr", "out_ptr", "out_return"] required: ["intent"] + - type: object + properties: + intent: + type: string + const: out_array_return + dtype: + type: string + length: + type: integer + minimum: 1 + required: ["intent", "dtype", "length"] diff --git a/numbast/src/numbast/types.py b/numbast/src/numbast/types.py index 127aaec0..d5d0f5fb 100644 --- a/numbast/src/numbast/types.py +++ b/numbast/src/numbast/types.py @@ -253,8 +253,15 @@ def to_numba_type(ty: str): base_ty, size = is_array_type.groups() return nbtypes.UniTuple(to_numba_type(base_ty), int(size)) + if ty in CTYPE_MAPS: + return CTYPE_MAPS[ty] + + for registered_type in CTYPE_MAPS.values(): + if str(registered_type) == ty: + return registered_type + # For any type that's unknown / not yet supported, return an opaque type. 
- return CTYPE_MAPS.get(ty, nbtypes.Opaque(ty)) + return nbtypes.Opaque(ty) def to_numba_arg_type(ast_type) -> nbtypes.Type: diff --git a/numbast/src/numbast/utils.py b/numbast/src/numbast/utils.py index ba6f1bfe..efaa9f29 100644 --- a/numbast/src/numbast/utils.py +++ b/numbast/src/numbast/utils.py @@ -11,6 +11,44 @@ OVERLOADS_CNT: dict[str, int] = defaultdict(int) # overload counter +_ARRAY_POINTER_TYPE_RE = re.compile( + r"^(?P.+?)(?P(?:\[\d+\])+)\s*(?P\*+)\s*$" +) + + +def _canonicalize_array_pointer_type(type_name: str) -> str: + """Normalize ast_canopy's ``T[N] *`` spelling to C declarator form.""" + match = _ARRAY_POINTER_TYPE_RE.match(type_name) + if not match: + return type_name + + base = match.group("base").strip() + sizes = match.group("sizes") + pointers = match.group("pointers") + return f"{base} ({pointers}){sizes}" + + +def _param_type_name_to_pointer_arg(type_name: str, arg_name: str) -> str: + array_pattern = r"(.*)(\[\d+\]+)" + type_name = _canonicalize_array_pointer_type(type_name) + + # For each of the arguments, elevate to pointer type. + match = re.match(array_pattern, type_name) + if match: + # Array type + base_ty, sizes = match.groups() + if "*" in base_ty: + # Pointer to array type: int (*arr)[10] + loc = base_ty.rfind("*") + return ( + base_ty[: loc + 1] + f"*{arg_name}" + base_ty[loc + 1 :] + sizes + ) + + # Regular array type: int arr[10] + return base_ty + f" (*{arg_name})" + sizes + + return f"{type_name}* {arg_name}" + def make_device_caller_with_nargs( name: str, nargs: int, wrapped: ExternFunction @@ -72,26 +110,10 @@ def paramvar_to_str(arg: pylibastcanopy.ParamVar): Performs necessary downcasting of array-typed ``ParamVar`` to pointer types. """ - array_pattern = r"(.*)(\[\d+\]+)" - - # For each of the arguments, elevate to pointer type. 
- match = re.match(array_pattern, arg.type_.unqualified_non_ref_type_name) - if match: - # Array type - base_ty, sizes = match.groups() - if "*" in base_ty: - # Pointer to array type: int (*arr)[10] - loc = base_ty.rfind("*") - fml_arg = ( - base_ty[: loc + 1] + f"*{arg.name}" + base_ty[loc + 1 :] + sizes - ) - else: - # Regular array type: int arr[10] - fml_arg = base_ty + f" (*{arg.name})" + sizes - else: - fml_arg = f"{arg.type_.unqualified_non_ref_type_name}* {arg.name}" - - return fml_arg + return _param_type_name_to_pointer_arg( + arg.type_.unqualified_non_ref_type_name, + arg.name, + ) def assemble_arglist_string(params: list[pylibastcanopy.ParamVar]) -> str: diff --git a/numbast/tests/data/sample_function_out.cuh b/numbast/tests/data/sample_function_out.cuh index ef56004b..63a4cea8 100644 --- a/numbast/tests/data/sample_function_out.cuh +++ b/numbast/tests/data/sample_function_out.cuh @@ -13,3 +13,23 @@ __device__ int add_out_ret(int &out, int x) { } __device__ int add_in_ref(int &x) { return x + 5; } + +__device__ void get_matrix(float out[12]) { + for (int i = 0; i < 12; ++i) { + out[i] = static_cast(i) + 0.5f; + } +} + +__device__ void get_matrix_3x4(float out[3][4]) { + for (int row = 0; row < 3; ++row) { + for (int col = 0; col < 4; ++col) { + out[row][col] = static_cast(row * 4 + col) + 1.25f; + } + } +} + +__device__ void get_data(float4 out[3]) { + out[0] = make_float4(1.0f, 2.0f, 3.0f, 4.0f); + out[1] = make_float4(5.0f, 6.0f, 7.0f, 8.0f); + out[2] = make_float4(9.0f, 10.0f, 11.0f, 12.0f); +} diff --git a/numbast/tests/test_callconv.py b/numbast/tests/test_callconv.py index b3cf53b0..4cb73a66 100644 --- a/numbast/tests/test_callconv.py +++ b/numbast/tests/test_callconv.py @@ -10,7 +10,7 @@ from numba.cuda.descriptor import cuda_target from numbast.callconv import FunctionCallConv, _get_alloca_alignment -from numbast.intent_defs import ArgIntent, IntentPlan +from numbast.intent_defs import ArgIntent, IntentPlan, OutArrayReturnSpec from numbast.types 
import CTYPE_MAPS, get_numba_type_alignof @@ -120,3 +120,63 @@ def test_intent_plan_allocas_are_aligned_in_lowered_ir(): assert re.search( r"load \{float, float, float, float\}, .* align 16", llvm_ir ) + + +def test_out_array_return_allocates_stack_storage_and_returns_unituple(): + plan = IntentPlan( + intents=(ArgIntent.out_array_return,), + visible_param_indices=(), + out_return_indices=(0,), + pass_ptr_mask=(), + out_array_return_specs=( + OutArrayReturnSpec( + dtype=cuda_types.float32, + length=12, + shim_arg_indirect=True, + ), + ), + ) + return_type = cuda_types.UniTuple(cuda_types.float32, 12) + + llvm_ir = _lower_callconv_to_ir( + return_type=return_type, + args=(), + intent_plan=plan, + out_return_types=(return_type,), + cxx_return_type=cuda_types.void, + ) + + assert re.search(r"alloca float, i64 12, align 4", llvm_ir) + assert "float**" in llvm_ir + assert llvm_ir.count("load float, float*") >= 12 + + +def test_out_array_return_honors_vector_element_alignment(): + float4 = CTYPE_MAPS["float4"] + plan = IntentPlan( + intents=(ArgIntent.out_array_return,), + visible_param_indices=(), + out_return_indices=(0,), + pass_ptr_mask=(), + out_array_return_specs=( + OutArrayReturnSpec( + dtype=float4, + length=3, + shim_arg_indirect=True, + ), + ), + ) + return_type = cuda_types.UniTuple(float4, 3) + + llvm_ir = _lower_callconv_to_ir( + return_type=return_type, + args=(), + intent_plan=plan, + out_return_types=(return_type,), + cxx_return_type=cuda_types.void, + ) + + assert "alloca {float, float, float, float}, i64 3, align 16" in llvm_ir + assert re.search( + r"load \{float, float, float, float\}, .* align 16", llvm_ir + ) diff --git a/numbast/tests/test_function.py b/numbast/tests/test_function.py index 83b23017..bff25b4d 100644 --- a/numbast/tests/test_function.py +++ b/numbast/tests/test_function.py @@ -6,11 +6,13 @@ import numpy as np from numba import cuda +from numba.cuda.types import float32 import cffi from ast_canopy import 
parse_declarations_from_source -from numbast import bind_cxx_functions, MemoryShimWriter +from numbast import bind_cxx_functions, MemoryShimWriter, out_array_return +from numbast.types import CTYPE_MAPS import pytest @@ -146,6 +148,13 @@ def _sample_out_functions(): arg_intent={ "add_out": {"out": "out_return"}, "add_out_ret": {"out": "out_return"}, + "get_matrix": {"out": out_array_return(dtype=float32, length=12)}, + "get_matrix_3x4": { + "out": out_array_return(dtype=float32, length=12) + }, + "get_data": { + "out": out_array_return(dtype=CTYPE_MAPS["float4"], length=3) + }, }, ) @@ -195,6 +204,40 @@ def kernel(out_single, out_pair): assert out_pair[0] == 10 assert out_pair[1] == 9 + get_matrix = find_binding(func_bindings, "get_matrix") + get_matrix_3x4 = find_binding(func_bindings, "get_matrix_3x4") + get_data = find_binding(func_bindings, "get_data") + + @cuda.jit(link=shim_writer.links()) + def kernel_arrays(out_matrix, out_matrix_3x4, out_data): + matrix = get_matrix() + for i in range(12): + out_matrix[i] = matrix[i] + + matrix_3x4 = get_matrix_3x4() + for i in range(12): + out_matrix_3x4[i] = matrix_3x4[i] + + data = get_data() + for i in range(3): + item = data[i] + out_data[i] = item.x + item.y + item.z + item.w + + out_matrix = np.zeros(12, dtype=np.float32) + out_matrix_3x4 = np.zeros(12, dtype=np.float32) + out_data = np.zeros(3, dtype=np.float32) + kernel_arrays[1, 1](out_matrix, out_matrix_3x4, out_data) + np.testing.assert_allclose( + out_matrix, np.arange(12, dtype=np.float32) + np.float32(0.5) + ) + np.testing.assert_allclose( + out_matrix_3x4, + np.arange(12, dtype=np.float32) + np.float32(1.25), + ) + np.testing.assert_allclose( + out_data, np.array([10, 26, 42], dtype=np.float32) + ) + DATA_FOLDER = os.path.join(os.path.dirname(__file__), "data") p = os.path.join(DATA_FOLDER, "sample_function_out.cuh") decls = parse_declarations_from_source(p, [p], "sm_80", verbose=True) diff --git a/numbast/tests/test_utils.py b/numbast/tests/test_utils.py 
new file mode 100644 index 00000000..6bba17a6 --- /dev/null +++ b/numbast/tests/test_utils.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from ast_canopy import parse_declarations_from_source + +from numbast.utils import make_function_shim + + +def test_make_function_shim_preserves_multidimensional_array_pointer(tmp_path): + header = tmp_path / "matrix.cuh" + header.write_text( + "__device__ void get_matrix_3x4(float out[3][4]);\n", + encoding="utf-8", + ) + + decls = parse_declarations_from_source(str(header), [str(header)], "sm_50") + func_decl = decls.functions[0] + + shim = make_function_shim( + "shim", + func_decl.name, + func_decl.return_type.unqualified_non_ref_type_name, + func_decl.params, + ) + + assert "float (**out)[4]" in shim + assert "get_matrix_3x4(*out);" in shim