-
Notifications
You must be signed in to change notification settings - Fork 22
[codex] add fixed-size out array returns #341
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0b29c23
76e2e0c
2a3b30c
6f7fdcd
e5a9f9f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,20 @@ class _OutReturnPtr(NamedTuple): | |
| numba_ty: types.Type | ||
| ptr: ir.Value | ||
| align: int | ||
| array_length: int | None = None | ||
| array_is_aggregate: bool = False | ||
|
|
||
|
|
||
| def _get_out_array_return_specs(plan): | ||
| specs = getattr(plan, "out_array_return_specs", ()) | ||
| if not specs: | ||
| return (None,) * len(plan.intents) | ||
| if len(specs) != len(plan.intents): | ||
| raise ValueError( | ||
| "IntentPlan out_array_return_specs length does not match intents: " | ||
| f"{len(specs)} != {len(plan.intents)}" | ||
| ) | ||
| return tuple(specs) | ||
|
Comment on lines
+24
to
+33
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fail fast when Line 26 currently treats missing Suggested guard def _get_out_array_return_specs(plan):
specs = getattr(plan, "out_array_return_specs", ())
- if not specs:
- return (None,) * len(plan.intents)
+ if not specs:
+ if any(getattr(i, "value", i) == "out_array_return" for i in plan.intents):
+ raise ValueError(
+ "IntentPlan contains out_array_return intent but out_array_return_specs is missing"
+ )
+ return (None,) * len(plan.intents)
if len(specs) != len(plan.intents):
raise ValueError(
"IntentPlan out_array_return_specs length does not match intents: "
f"{len(specs)} != {len(plan.intents)}"
)
+ for idx, (intent, spec) in enumerate(zip(plan.intents, specs)):
+ if getattr(intent, "value", intent) == "out_array_return" and spec is None:
+ raise ValueError(
+ f"Missing out_array_return spec for intent index {idx}"
+ )
return tuple(specs)🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| def _get_alloca_alignment(context, value_ty, numba_ty=None): | ||
|
|
@@ -198,22 +212,95 @@ def _lower_impl(self, builder, context, sig, args): | |
| for out_pos, orig_idx in enumerate(plan.out_return_indices): | ||
| orig_to_out[orig_idx] = out_pos | ||
|
|
||
| out_array_specs = _get_out_array_return_specs(plan) | ||
|
|
||
| for orig_idx in range(n_orig): | ||
| out_pos = orig_to_out[orig_idx] | ||
| if out_pos is not None: | ||
| out_nbty = self._out_return_types[out_pos] | ||
| vty = context.get_value_type(out_nbty) | ||
| ptr = cgutils.alloca_once(builder, vty) | ||
| ptr_align = _set_alloca_alignment( | ||
| ptr, context, vty, out_nbty | ||
| ) | ||
| ptrs.append(ptr) | ||
| arg_pointer_types.append(ir.PointerType(vty)) | ||
| out_return_ptrs.append( | ||
| _OutReturnPtr( | ||
| numba_ty=out_nbty, ptr=ptr, align=ptr_align | ||
| array_spec = out_array_specs[orig_idx] | ||
| if array_spec is not None: | ||
| elem_nbty = array_spec.dtype | ||
| elem_vty = context.get_value_type(elem_nbty) | ||
| length = int(array_spec.length) | ||
| if array_spec.shim_arg_indirect: | ||
| # Native pointer output parameters are shimmed as | ||
| # pointer-to-pointer arguments. Store the raw stack | ||
| # buffer pointer in a pointer slot, then pass that | ||
| # slot to the shim; the shim dereferences once and | ||
| # the native function receives the raw pointer. | ||
| storage_ptr = cgutils.alloca_once( | ||
| builder, | ||
| elem_vty, | ||
| size=length, | ||
| name="out_array_return", | ||
| ) | ||
| elem_align = _set_alloca_alignment( | ||
| storage_ptr, context, elem_vty, elem_nbty | ||
| ) | ||
| ptr_slot_ty = ir.PointerType(elem_vty) | ||
| ptr_slot = cgutils.alloca_once( | ||
| builder, | ||
| ptr_slot_ty, | ||
| name="out_array_return_ptr", | ||
| ) | ||
| ptr_slot_align = _set_alloca_alignment( | ||
| ptr_slot, context, ptr_slot_ty | ||
| ) | ||
| builder.store( | ||
| storage_ptr, | ||
| ptr_slot, | ||
| align=ptr_slot_align, | ||
| ) | ||
| ptrs.append(ptr_slot) | ||
| arg_pointer_types.append( | ||
| ir.PointerType(ptr_slot_ty) | ||
| ) | ||
| out_return_ptrs.append( | ||
| _OutReturnPtr( | ||
| numba_ty=out_nbty, | ||
| ptr=storage_ptr, | ||
| align=elem_align, | ||
| array_length=length, | ||
| ) | ||
| ) | ||
| else: | ||
| array_vty = ir.ArrayType(elem_vty, length) | ||
| storage_ptr = cgutils.alloca_once( | ||
| builder, | ||
| array_vty, | ||
| name="out_array_return", | ||
| ) | ||
| elem_align = _get_alloca_alignment( | ||
| context, elem_vty, elem_nbty | ||
| ) | ||
| _set_alloca_alignment( | ||
| storage_ptr, context, array_vty | ||
| ) | ||
| ptrs.append(storage_ptr) | ||
| arg_pointer_types.append(ir.PointerType(array_vty)) | ||
| out_return_ptrs.append( | ||
| _OutReturnPtr( | ||
| numba_ty=out_nbty, | ||
| ptr=storage_ptr, | ||
| align=elem_align, | ||
| array_length=length, | ||
| array_is_aggregate=True, | ||
| ) | ||
| ) | ||
| else: | ||
| vty = context.get_value_type(out_nbty) | ||
| ptr = cgutils.alloca_once(builder, vty) | ||
| ptr_align = _set_alloca_alignment( | ||
| ptr, context, vty, out_nbty | ||
| ) | ||
| ptrs.append(ptr) | ||
| arg_pointer_types.append(ir.PointerType(vty)) | ||
| out_return_ptrs.append( | ||
| _OutReturnPtr( | ||
| numba_ty=out_nbty, ptr=ptr, align=ptr_align | ||
| ) | ||
| ) | ||
| ) | ||
| continue | ||
|
|
||
| vis_pos = orig_to_vis[orig_idx] | ||
|
|
@@ -261,20 +348,38 @@ def _lower_impl(self, builder, context, sig, args): | |
| if cxx_return_type != types.void: | ||
| ret_vals.append(builder.load(retval_ptr, align=retval_align)) | ||
| for out_return in out_return_ptrs: | ||
| if out_return.array_length is None: | ||
| ret_vals.append( | ||
| builder.load(out_return.ptr, align=out_return.align) | ||
| ) | ||
| continue | ||
|
|
||
| elems = [] | ||
| for i in range(out_return.array_length): | ||
| idx = ir.Constant(ir.IntType(32), i) | ||
| if out_return.array_is_aggregate: | ||
| elem_ptr = builder.gep( | ||
| out_return.ptr, | ||
| [ir.Constant(ir.IntType(32), 0), idx], | ||
| inbounds=True, | ||
| ) | ||
| else: | ||
| elem_ptr = builder.gep(out_return.ptr, [idx], inbounds=True) | ||
| elems.append(builder.load(elem_ptr, align=out_return.align)) | ||
| ret_vals.append( | ||
| builder.load(out_return.ptr, align=out_return.align) | ||
| context.make_tuple(builder, out_return.numba_ty, elems) | ||
| ) | ||
|
|
||
| # If Numba-visible return is a tuple, use context.make_tuple. | ||
| # Otherwise (void + single out), return the single out value directly. | ||
| # A single out-return value may itself be a UniTuple. Return it directly | ||
| # instead of wrapping it as another tuple. | ||
| if len(ret_vals) == 1: | ||
| return ret_vals[0] | ||
| if hasattr(sig.return_type, "types"): | ||
| return context.make_tuple(builder, sig.return_type, ret_vals) | ||
| if len(ret_vals) != 1: | ||
| raise ValueError( | ||
| "Non-tuple return type requires exactly one return value; " | ||
| f"got {len(ret_vals)}" | ||
| ) | ||
| return ret_vals[0] | ||
| raise ValueError( | ||
| "Multiple return values require a tuple return type; " | ||
| f"got {len(ret_vals)}" | ||
| ) | ||
|
|
||
|
|
||
| # NBST:END_CALLCONV | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason behind this design is because of the lack of fixed-size array return information from ast_canopy.