
[Relax][Frontend][TFLite] Implement DETECTION_POSTPROCESS tflite operator #19345

Open
Aharrypotter wants to merge 2 commits into apache:main from Aharrypotter:relax-onnx-detection-postprocess

Conversation

@Aharrypotter
Contributor

Summary

Changes

  • Operator Registration: Implemented convert_detection_postprocess in python/tvm/relax/frontend/tflite/tflite_frontend.py.
  • Core Logic:
    • Integrated multibox_transform_loc for coordinate decoding and variance scaling.
    • Supported use_regular_nms attribute to switch between all-class NMS and class-agnostic NMS paths.
    • Leveraged all_class_non_max_suppression for efficient box filtering.
  • Output Alignment: Used topk, gather_nd, and where operators to ensure the output tensors (boxes, classes, scores, num_detections) match the TFLite specification in terms of shape and layout.
  • Attribute Validation: Added strict validation for required custom options such as num_classes, max_detections, and scaling factors.
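
The attribute validation described in the last bullet could look roughly like the sketch below. It is not the PR's code: `validate_detection_options` and `REQUIRED_OPTIONS` are hypothetical names, the custom options are assumed to arrive as a plain dict, and the scale keys (`x_scale`, `y_scale`, `w_scale`, `h_scale`) are assumed from the TFLite custom-op option names.

```python
# Hypothetical sketch; the names below are illustrative, not taken from the PR.
REQUIRED_OPTIONS = (
    "num_classes",
    "max_detections",
    "x_scale",
    "y_scale",
    "w_scale",
    "h_scale",
)


def validate_detection_options(options):
    """Check required custom options and return a normalized copy."""
    missing = [key for key in REQUIRED_OPTIONS if key not in options]
    if missing:
        raise ValueError("missing required custom options: " + ", ".join(missing))

    validated = dict(options)
    # Counts must be positive integers.
    for key in ("num_classes", "max_detections"):
        validated[key] = int(options[key])
        if validated[key] <= 0:
            raise ValueError(f"{key} must be a positive integer, got {options[key]!r}")
    # Scaling factors must be positive floats.
    for key in ("x_scale", "y_scale", "w_scale", "h_scale"):
        validated[key] = float(options[key])
        if validated[key] <= 0.0:
            raise ValueError(f"{key} must be positive, got {options[key]!r}")
    # Assumed default: use the class-agnostic path when the flag is absent.
    validated["use_regular_nms"] = bool(options.get("use_regular_nms", False))
    return validated
```

Failing early with a message that names the missing keys keeps conversion errors close to the malformed model rather than surfacing later as a shape or type error.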

Validation

Verified with linting and pre-commit hooks:

# Lint check
python -m ruff check python/tvm/relax/frontend/tflite/tflite_frontend.py

# Pre-commit checks
python -m pre_commit run --files python/tvm/relax/frontend/tflite/tflite_frontend.py

Result:

  • Passed: All static checks and style guidelines are met.

…ator

This commit wires up the TFLite_Detection_PostProcess custom operator in the
Relax TFLite frontend. Key changes include:
- Implemented conversion logic using multibox_transform_loc and all_class_non_max_suppression.
- Added support for both regular NMS and class-agnostic NMS paths via 'use_regular_nms'.
- Properly formatted outputs (boxes, classes, scores, num_detections) to match TFLite spec.
- Added strict validation for required custom options (num_classes, scales, etc.).
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements the DETECTION_POSTPROCESS operator in the TFLite frontend for Relax, including attribute validation and NMS logic. Feedback highlights several critical issues: invalid Python unpacking of relax.Call objects (e.g., from topk), premature access to struct_info.dtype before normalization, a potential shape mismatch in mask generation, and robustness concerns regarding dynamic batch sizes in slicing operations.

multibox_transform_loc_attrs["keep_background"] = use_regular_nms

- ret = relax.op.vision.multibox_transform_loc(
+ transformed_boxes, transformed_scores = relax.op.vision.multibox_transform_loc(

critical

The relax.op.vision.multibox_transform_loc operator returns a relax.Call object, which is not iterable in Python. Attempting to unpack it into transformed_boxes, transformed_scores will raise a TypeError. You should assign the result to a single variable and then index it to retrieve the individual outputs.

        res = relax.op.vision.multibox_transform_loc(
            # reshape cls_pred so it can be consumed by
            # multibox_transform_loc
            relax.op.permute_dims(cls_pred, [0, 2, 1]),
            loc_prob,
            anchor_expr,
            **multibox_transform_loc_attrs,
        )
        transformed_boxes = res[0]
        transformed_scores = res[1]

num_detections = nms_out[2]
class_id_from_score = None
else:
max_scores, class_id_from_score = relax.op.topk(

critical

The relax.op.topk operator returns a relax.Call object, which is not iterable. Unpacking it directly into max_scores, class_id_from_score will fail. Please assign the result to a variable and index it.

            res_topk = relax.op.topk(
                transformed_scores, k=1, axis=1, ret_type="both", largest=True
            )
            max_scores = res_topk[0]
            class_id_from_score = res_topk[1]

masked_selected_scores = relax.op.where(
valid_detection_mask,
selected_scores,
relax.const(-1.0, selected_scores.struct_info.dtype),

critical

Accessing selected_scores.struct_info.dtype will raise an AttributeError because struct_info is not populated on relax.Expr objects until they are normalized by the BlockBuilder. Since this code is inside the conversion function and normalization happens later, you should use a known dtype string like "float32" or retrieve it from the TFLite tensor wrapper.

Suggested change
- relax.const(-1.0, selected_scores.struct_info.dtype),
+ relax.const(-1.0, "float32"),

selected_scores,
relax.const(-1.0, selected_scores.struct_info.dtype),
)
detection_scores, top_positions = relax.op.topk(

critical

The relax.op.topk operator returns a relax.Call object, which is not iterable. Unpacking it directly will fail.

Suggested change
- detection_scores, top_positions = relax.op.topk(
+ res_topk_scores = relax.op.topk(
+     masked_selected_scores, k=max_detections, axis=1, ret_type="both", largest=True
+ )
+ detection_scores = res_topk_scores[0]
+ top_positions = res_topk_scores[1]

detection_mask,
detection_boxes,
relax.op.zeros(
(batch_size, max_detections, 4), dtype=detection_boxes.struct_info.dtype

critical

Accessing detection_boxes.struct_info.dtype will fail as struct_info is None at this stage. Use "float32" instead.

Suggested change
- (batch_size, max_detections, 4), dtype=detection_boxes.struct_info.dtype
+ (batch_size, max_detections, 4), dtype="float32"

Comment on lines +3360 to +3361
detection_positions = relax.op.expand_dims(
relax.op.arange(max_detections, dtype="int64"), axis=0

high

The detection_positions tensor used to create valid_detection_mask must have a length equal to the second dimension of selected_scores (num_total_boxes_per_batch) so that the resulting mask can be used with selected_scores in relax.op.where. Using max_detections here causes a shape mismatch error whenever max_detections differs from the total number of boxes returned by NMS.

        num_total_boxes_per_batch = anchor_boxes * (num_classes if use_regular_nms else 1)
        detection_positions = relax.op.expand_dims(
            relax.op.arange(num_total_boxes_per_batch, dtype="int64"), axis=0
        )
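
The shape concern can be checked without TVM. The sketch below assumes `relax.op.where` follows NumPy-style broadcasting and uses a hypothetical helper, `broadcast_shape`, plus made-up example sizes; it only illustrates why a mask built from `max_detections` positions cannot be combined with scores of width `num_total_boxes_per_batch`.

```python
def broadcast_shape(lhs, rhs):
    """Return the NumPy-style broadcast of two shapes, or raise ValueError."""
    result = []
    for offset in range(1, max(len(lhs), len(rhs)) + 1):
        a = lhs[-offset] if offset <= len(lhs) else 1
        b = rhs[-offset] if offset <= len(rhs) else 1
        if a != b and a != 1 and b != 1:
            raise ValueError(f"cannot broadcast {lhs} with {rhs}")
        result.append(max(a, b))
    return tuple(reversed(result))


# Example sizes, assumed purely for illustration.
batch_size, anchor_boxes, num_classes, max_detections = 1, 1917, 90, 10
use_regular_nms = False
num_total_boxes_per_batch = anchor_boxes * (num_classes if use_regular_nms else 1)

selected_scores_shape = (batch_size, num_total_boxes_per_batch)

# Positions sized by the NMS output broadcast cleanly against the scores.
ok_shape = broadcast_shape((1, num_total_boxes_per_batch), selected_scores_shape)

# Positions sized by max_detections do not, unless the two sizes coincide.
try:
    broadcast_shape((1, max_detections), selected_scores_shape)
    mismatch = False
except ValueError:
    mismatch = True
```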

)
top_box_ids = relax.op.squeeze(
relax.op.strided_slice(
top_index_pairs, begin=[0, 0, 1], end=[batch_size, max_detections, 2]

medium

Using batch_size in the end parameter of strided_slice can be problematic if the model has a dynamic batch size (where batch_size might be -1). It is safer to use the axes parameter to slice only the specific dimensions you need, leaving the batch dimension untouched.

Suggested change
- top_index_pairs, begin=[0, 0, 1], end=[batch_size, max_detections, 2]
+ top_index_pairs, begin=[0, 1], end=[max_detections, 2], axes=[1, 2]
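
As a plain-Python illustration of the suggested slice (not TVM code), the sketch below restricts only axes 1 and 2 and never names the batch extent. It assumes `top_index_pairs` holds `[batch_id, box_id]` pairs, which is an inference from the surrounding snippet; `slice_box_ids` is a hypothetical helper.

```python
def slice_box_ids(index_pairs, max_detections):
    """Take column 1 of the first max_detections pairs for every batch entry."""
    # Mirrors begin=[0, 1], end=[max_detections, 2], axes=[1, 2] followed by a
    # squeeze of the last axis. Axis 0 (batch) is never given explicit bounds,
    # so the same code works for any batch size.
    return [[pair[1] for pair in batch[:max_detections]] for batch in index_pairs]


# Two batch entries, three [batch_id, box_id] pairs each (made-up values).
pairs = [
    [[0, 7], [0, 3], [0, 9]],
    [[1, 2], [1, 5], [1, 8]],
]
box_ids = slice_box_ids(pairs, max_detections=2)
```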
