[Relax][Frontend][TFLite] Implement DETECTION_POSTPROCESS tflite operator #19345
Aharrypotter wants to merge 2 commits into apache:main
…ator

This commit wires up the TFLite_Detection_PostProcess custom operator in the Relax TFLite frontend. Key changes include:
- Implemented conversion logic using multibox_transform_loc and all_class_non_max_suppression.
- Added support for both regular NMS and class-agnostic NMS paths via 'use_regular_nms'.
- Properly formatted outputs (boxes, classes, scores, num_detections) to match TFLite spec.
- Added strict validation for required custom options (num_classes, scales, etc.).
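The coordinate decoding that multibox_transform_loc is used for above can be pictured with a small plain-Python sketch. This is not code from the PR: it follows the center-size decoding that the TFLite reference kernel is understood to apply, and the function name decode_center_size_box, the (y, x, h, w) layout, and the scale ordering are assumptions made for illustration only.

```python
import math


def decode_center_size_box(encoding, anchor, scales):
    """Decode one box encoding against one anchor (hypothetical helper).

    Assumed layouts (not confirmed by the PR):
      encoding, anchor: (y_center, x_center, height, width)
      scales:           (y_scale, x_scale, h_scale, w_scale)
    Returns (ymin, xmin, ymax, xmax).
    """
    ty, tx, th, tw = encoding
    ay, ax, ah, aw = anchor
    y_scale, x_scale, h_scale, w_scale = scales

    # Center offsets are expressed in units of the anchor size.
    y_center = ty / y_scale * ah + ay
    x_center = tx / x_scale * aw + ax
    # Sizes are encoded in log space relative to the anchor size.
    half_h = 0.5 * math.exp(th / h_scale) * ah
    half_w = 0.5 * math.exp(tw / w_scale) * aw
    return (y_center - half_h, x_center - half_w,
            y_center + half_h, x_center + half_w)


# A zero encoding reproduces the anchor itself, converted to corner form.
anchor = (0.5, 0.5, 0.2, 0.4)
box = decode_center_size_box((0.0, 0.0, 0.0, 0.0), anchor, (10.0, 10.0, 5.0, 5.0))
```

Dividing by a scale is the same as multiplying by a variance of 1/scale, which is presumably how the scale options map onto the "variance scaling" mentioned in the PR description.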
Code Review
This pull request implements the DETECTION_POSTPROCESS operator in the TFLite frontend for Relax, including attribute validation and NMS logic. Feedback highlights several critical issues: invalid Python unpacking of relax.Call objects (e.g., from topk), premature access to struct_info.dtype before normalization, a potential shape mismatch in mask generation, and robustness concerns regarding dynamic batch sizes in slicing operations.
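As a rough illustration of the attribute validation mentioned in this summary, the sketch below checks a decoded options dictionary before conversion. It is a hypothetical stand-in, not the PR's code: validate_detection_options, the REQUIRED_KEYS tuple, the individual scale key names, and the use of ValueError are all assumptions (the source only names num_classes, max_detections, and scaling factors).

```python
# Assumed key names; the PR only says "num_classes, max_detections, and
# scaling factors" are required.
REQUIRED_KEYS = ("num_classes", "max_detections",
                 "y_scale", "x_scale", "h_scale", "w_scale")


def validate_detection_options(options):
    """Check a decoded custom-options dict and return a normalized copy."""
    missing = [key for key in REQUIRED_KEYS if key not in options]
    if missing:
        raise ValueError(
            f"DETECTION_POSTPROCESS is missing required options: {missing}"
        )

    normalized = dict(options)
    for key in ("num_classes", "max_detections"):
        normalized[key] = int(options[key])
        if normalized[key] <= 0:
            raise ValueError(f"{key} must be a positive integer, got {options[key]!r}")
    for key in ("y_scale", "x_scale", "h_scale", "w_scale"):
        normalized[key] = float(options[key])
        if normalized[key] <= 0.0:
            raise ValueError(f"{key} must be positive, got {options[key]!r}")

    # Treating use_regular_nms as optional with a False default is an
    # assumption of this sketch.
    normalized["use_regular_nms"] = bool(options.get("use_regular_nms", False))
    return normalized


opts = validate_detection_options(
    {"num_classes": 90, "max_detections": 10,
     "y_scale": 10.0, "x_scale": 10.0, "h_scale": 5.0, "w_scale": 5.0}
)
```

Failing early with the full list of missing keys tends to give a clearer error than letting a KeyError surface halfway through the conversion.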
multibox_transform_loc_attrs["keep_background"] = use_regular_nms

ret = relax.op.vision.multibox_transform_loc(
transformed_boxes, transformed_scores = relax.op.vision.multibox_transform_loc(
The relax.op.vision.multibox_transform_loc operator returns a relax.Call object, which is not iterable in Python. Attempting to unpack it into transformed_boxes, transformed_scores will raise a TypeError. You should assign the result to a single variable and then index it to retrieve the individual outputs.
res = relax.op.vision.multibox_transform_loc(
    # reshape cls_pred so it can be consumed by
    # multibox_transform_loc
    relax.op.permute_dims(cls_pred, [0, 2, 1]),
    loc_prob,
    anchor_expr,
    **multibox_transform_loc_attrs,
)
transformed_boxes = res[0]
transformed_scores = res[1]

    num_detections = nms_out[2]
    class_id_from_score = None
else:
    max_scores, class_id_from_score = relax.op.topk(
The relax.op.topk operator returns a relax.Call object, which is not iterable. Unpacking it directly into max_scores, class_id_from_score will fail. Please assign the result to a variable and index it.
res_topk = relax.op.topk(
    transformed_scores, k=1, axis=1, ret_type="both", largest=True
)
max_scores = res_topk[0]
class_id_from_score = res_topk[1]

masked_selected_scores = relax.op.where(
    valid_detection_mask,
    selected_scores,
    relax.const(-1.0, selected_scores.struct_info.dtype),
Accessing selected_scores.struct_info.dtype will raise an AttributeError because struct_info is not populated on relax.Expr objects until they are normalized by the BlockBuilder. Since this code is inside the conversion function and normalization happens later, you should use a known dtype string like "float32" or retrieve it from the TFLite tensor wrapper.
- relax.const(-1.0, selected_scores.struct_info.dtype),
+ relax.const(-1.0, "float32"),

    selected_scores,
    relax.const(-1.0, selected_scores.struct_info.dtype),
)
detection_scores, top_positions = relax.op.topk(
The relax.op.topk operator returns a relax.Call object, which is not iterable. Unpacking it directly will fail.
- detection_scores, top_positions = relax.op.topk(
+ res_topk_scores = relax.op.topk(
+     masked_selected_scores, k=max_detections, axis=1, ret_type="both", largest=True
+ )
+ detection_scores = res_topk_scores[0]
+ top_positions = res_topk_scores[1]

    detection_mask,
    detection_boxes,
    relax.op.zeros(
        (batch_size, max_detections, 4), dtype=detection_boxes.struct_info.dtype

detection_positions = relax.op.expand_dims(
    relax.op.arange(max_detections, dtype="int64"), axis=0
The detection_positions tensor used to create the valid_detection_mask must have the same shape as the second dimension of selected_scores (which is num_total_boxes_per_batch) to be used in relax.op.where. Using max_detections here will cause a shape mismatch error if max_detections is different from the total number of boxes returned by NMS.
num_total_boxes_per_batch = anchor_boxes * (num_classes if use_regular_nms else 1)
detection_positions = relax.op.expand_dims(
    relax.op.arange(num_total_boxes_per_batch, dtype="int64"), axis=0
)
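The shape requirement raised in this comment can be modeled in plain Python. The sketch below is illustrative only and does not use Relax: select_top_detections is a hypothetical name, and it assumes the mask is computed as "position < num_detections", which the quoted diff does not show. The point it makes is that the position range must span every NMS output slot (the length of each score row), while max_detections only limits how many entries are kept afterwards.

```python
def select_top_detections(selected_scores, num_detections, max_detections):
    """Mask out padded NMS slots, then keep the best max_detections per batch.

    selected_scores: one list of scores per batch, one entry per NMS output slot
    num_detections:  number of valid slots per batch
    Returns one (scores, positions) pair per batch.
    """
    results = []
    for scores, valid in zip(selected_scores, num_detections):
        # One position per output slot, NOT one per requested detection.
        positions = range(len(scores))
        # Equivalent of where(mask, scores, -1.0): padded slots sink to the bottom.
        masked = [s if p < valid else -1.0 for p, s in zip(positions, scores)]
        # Equivalent of topk(..., k=max_detections, largest=True).
        order = sorted(positions, key=lambda i: masked[i], reverse=True)
        order = order[:max_detections]
        results.append(([masked[i] for i in order], order))
    return results


# Four NMS slots, three of them valid, two detections requested.
top = select_top_detections([[0.9, 0.2, 0.8, 0.7]], [3], 2)
```

Here the fourth slot is padding, so its score of 0.7 is masked to -1.0 and the two kept detections come from positions 0 and 2.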
top_box_ids = relax.op.squeeze(
    relax.op.strided_slice(
        top_index_pairs, begin=[0, 0, 1], end=[batch_size, max_detections, 2]
Using batch_size in the end parameter of strided_slice can be problematic if the model has a dynamic batch size (where batch_size might be -1). It is safer to use the axes parameter to slice only the specific dimensions you need, leaving the batch dimension untouched.
- top_index_pairs, begin=[0, 0, 1], end=[batch_size, max_detections, 2]
+ top_index_pairs, begin=[0, 1], end=[max_detections, 2], axes=[1, 2]
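To see why restricting the slice to axes 1 and 2 sidesteps the batch dimension, here is a plain-Python model on nested lists. It is a sketch, not Relax code: slice_axes_1_2 is a hypothetical helper, and the (class_id, box_id) ordering of each index pair is an assumption about the NMS output layout.

```python
def slice_axes_1_2(tensor, begin, end):
    """Slice axes 1 and 2 of a nested-list tensor; axis 0 (batch) is left whole."""
    (begin1, begin2), (end1, end2) = begin, end
    return [[row[begin2:end2] for row in batch[begin1:end1]] for batch in tensor]


# Two batches, three index pairs each (assumed layout: class_id, box_id).
top_index_pairs = [
    [[0, 5], [1, 7], [0, 2]],
    [[2, 1], [2, 4], [1, 9]],
]
max_detections = 2

# Mirrors begin=[0, 1], end=[max_detections, 2], axes=[1, 2].
sliced = slice_axes_1_2(top_index_pairs, begin=[0, 1], end=[max_detections, 2])
# Mirrors the squeeze of the trailing length-1 axis.
top_box_ids = [[pair[0] for pair in batch] for batch in sliced]
```

Because the helper never receives a bound for axis 0, the number of batches can be anything, which is the property the suggested change relies on.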
Summary
DETECTION_POSTPROCESS operator conversion to Relax IR.
Changes
- convert_detection_postprocess in python/tvm/relax/frontend/tflite/tflite_frontend.py.
- multibox_transform_loc for coordinate decoding and variance scaling.
- use_regular_nms attribute to switch between all-class NMS and class-agnostic NMS paths.
- all_class_non_max_suppression for efficient box filtering.
- topk, gather_nd, and where operators to ensure the output tensors (boxes, classes, scores, num_detections) match the TFLite specification in terms of shape and layout.
- num_classes, max_detections, and scaling factors.
Validation
Verified with linting and pre-commit hooks:
Result: