Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 143 additions & 56 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2826,9 +2826,7 @@ def convert_batch_matmul(self, op):
new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in shape_b]
max_rank = max(rank_a, rank_b)

batch_shape = [
max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)
]
batch_shape = [max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)]

a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])]
b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])]
Expand Down Expand Up @@ -3204,16 +3202,49 @@ def convert_dequantize(self, op):

def convert_detection_postprocess(self, op):
"""Convert TFLite_Detection_PostProcess"""
raise NotImplementedError(
"DETECTION_POSTPROCESS is not wired in this frontend yet: it still needs "
"Relax NMS / get_valid_counts / related vision helpers (see dead code below). "
"relax.vision.multibox_transform_loc exists; tracking: "
"https://github.com/apache/tvm/issues/18928"
)
flexbuffer = op.CustomOptionsAsNumpy().tobytes()
custom_options = FlexBufferDecoder(flexbuffer).decode()

use_regular_nms = "use_regular_nms" in custom_options and custom_options["use_regular_nms"]
use_regular_nms = bool(custom_options.get("use_regular_nms", False))

required_attrs = [
"num_classes",
"max_detections",
"detections_per_class",
"nms_iou_threshold",
"nms_score_threshold",
"x_scale",
"y_scale",
"w_scale",
"h_scale",
]
missing_attrs = [key for key in required_attrs if key not in custom_options]
if missing_attrs:
raise ValueError(
"DETECTION_POSTPROCESS custom options miss required attributes: "
+ ", ".join(missing_attrs)
)

num_classes = int(custom_options["num_classes"])
max_detections = int(custom_options["max_detections"])
detections_per_class = int(custom_options["detections_per_class"])
iou_threshold = float(custom_options["nms_iou_threshold"])
score_threshold = float(custom_options["nms_score_threshold"])
x_scale = float(custom_options["x_scale"])
y_scale = float(custom_options["y_scale"])
w_scale = float(custom_options["w_scale"])
h_scale = float(custom_options["h_scale"])

if num_classes <= 0:
raise ValueError("DETECTION_POSTPROCESS requires num_classes > 0.")
if max_detections <= 0:
raise ValueError("DETECTION_POSTPROCESS requires max_detections > 0.")
if detections_per_class <= 0:
raise ValueError("DETECTION_POSTPROCESS requires detections_per_class > 0.")
if not 0.0 <= iou_threshold <= 1.0:
raise ValueError("DETECTION_POSTPROCESS requires nms_iou_threshold in [0, 1].")
if x_scale <= 0.0 or y_scale <= 0.0 or w_scale <= 0.0 or h_scale <= 0.0:
raise ValueError("DETECTION_POSTPROCESS requires x/y/w/h_scale to be > 0.")

inputs = self.get_input_tensors(op)
assert len(inputs) == 3, "inputs length should be 3"
Expand Down Expand Up @@ -3275,18 +3306,16 @@ def convert_detection_postprocess(self, op):
# attributes for multibox_transform_loc
multibox_transform_loc_attrs = {}
multibox_transform_loc_attrs["clip"] = False
multibox_transform_loc_attrs["threshold"] = (
0.0 if use_regular_nms else custom_options["nms_score_threshold"]
)
multibox_transform_loc_attrs["threshold"] = 0.0 if use_regular_nms else score_threshold
multibox_transform_loc_attrs["variances"] = (
1 / custom_options["x_scale"],
1 / custom_options["y_scale"],
1 / custom_options["w_scale"],
1 / custom_options["h_scale"],
1 / x_scale,
1 / y_scale,
1 / w_scale,
1 / h_scale,
)
multibox_transform_loc_attrs["keep_background"] = use_regular_nms

ret = relax.op.vision.multibox_transform_loc(
transformed_boxes, transformed_scores = relax.op.vision.multibox_transform_loc(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The relax.op.vision.multibox_transform_loc operator returns a relax.Call object, which is not iterable in Python. Attempting to unpack it into transformed_boxes, transformed_scores will raise a TypeError. You should assign the result to a single variable and then index it to retrieve the individual outputs.

        res = relax.op.vision.multibox_transform_loc(
            # reshape cls_pred so it can be consumed by
            # multibox_transform_loc
            relax.op.permute_dims(cls_pred, [0, 2, 1]),
            loc_prob,
            anchor_expr,
            **multibox_transform_loc_attrs,
        )
        transformed_boxes = res[0]
        transformed_scores = res[1]

# reshape cls_pred so it can be consumed by
# multibox_transform_loc
relax.op.permute_dims(cls_pred, [0, 2, 1]),
Expand All @@ -3296,46 +3325,104 @@ def convert_detection_postprocess(self, op):
)

if use_regular_nms:
# box coordinates need to be converted from ltrb to (ymin, xmin, ymax, xmax)
_, transformed_boxes = relax.op.split(ret[0], (2,), axis=2)
box_l, box_t, box_r, box_b = relax.op.split(transformed_boxes, 4, axis=2)
transformed_boxes = relax.op.concat([box_t, box_l, box_b, box_r], axis=2)

return relax.op.vision.regular_non_max_suppression(
boxes=transformed_boxes,
scores=cls_pred,
max_detections_per_class=custom_options["detections_per_class"],
max_detections=custom_options["max_detections"],
num_classes=custom_options["num_classes"],
iou_threshold=custom_options["nms_iou_threshold"],
score_threshold=custom_options["nms_score_threshold"],
)

# attributes for non_max_suppression
non_max_suppression_attrs = {}
non_max_suppression_attrs["return_indices"] = False
non_max_suppression_attrs["iou_threshold"] = custom_options["nms_iou_threshold"]
non_max_suppression_attrs["force_suppress"] = True
non_max_suppression_attrs["top_k"] = anchor_boxes
non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"]
non_max_suppression_attrs["invalid_to_bottom"] = False

ret = relax.op.vision.non_max_suppression(
ret[0], ret[1], ret[1], **non_max_suppression_attrs
nms_out = relax.op.vision.all_class_non_max_suppression(
transformed_boxes,
transformed_scores,
relax.const(detections_per_class, "int64"),
relax.const(iou_threshold, "float32"),
relax.const(score_threshold, "float32"),
output_format="tensorflow",
)
selected_indices = nms_out[0]
selected_scores = nms_out[1]
num_detections = nms_out[2]
class_id_from_score = None
else:
max_scores, class_id_from_score = relax.op.topk(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The relax.op.topk operator returns a relax.Call object, which is not iterable. Unpacking it directly into max_scores, class_id_from_score will fail. Please assign the result to a variable and index it.

            res_topk = relax.op.topk(
                transformed_scores, k=1, axis=1, ret_type="both", largest=True
            )
            max_scores = res_topk[0]
            class_id_from_score = res_topk[1]

transformed_scores, k=1, axis=1, ret_type="both", largest=True
)
nms_out = relax.op.vision.all_class_non_max_suppression(
transformed_boxes,
max_scores,
relax.const(max_detections, "int64"),
relax.const(iou_threshold, "float32"),
relax.const(score_threshold, "float32"),
output_format="tensorflow",
)
selected_indices = nms_out[0]
selected_scores = nms_out[1]
num_detections = nms_out[2]
class_id_from_score = relax.op.squeeze(class_id_from_score, axis=[1])

num_detections = relax.op.minimum(
num_detections, relax.const(np.array([max_detections], dtype="int64"))
)
detection_positions = relax.op.expand_dims(
relax.op.arange(max_detections, dtype="int64"), axis=0
Comment on lines +3360 to +3361
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The detection_positions tensor used to create the valid_detection_mask must match the size of the second dimension of selected_scores (which is num_total_boxes_per_batch) to be used in relax.op.where. Using max_detections here will cause a shape mismatch error if max_detections differs from the total number of boxes returned by NMS.

        num_total_boxes_per_batch = anchor_boxes * (num_classes if use_regular_nms else 1)
        detection_positions = relax.op.expand_dims(
            relax.op.arange(num_total_boxes_per_batch, dtype="int64"), axis=0
        )

)
valid_detection_mask = relax.op.less(
detection_positions, relax.op.expand_dims(num_detections, axis=1)
)
masked_selected_scores = relax.op.where(
valid_detection_mask,
selected_scores,
relax.const(-1.0, selected_scores.struct_info.dtype),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Accessing selected_scores.struct_info.dtype will raise an AttributeError because struct_info is not populated on relax.Expr objects until they are normalized by the BlockBuilder. Since this code is inside the conversion function and normalization happens later, you should use a known dtype string like "float32" or retrieve it from the TFLite tensor wrapper.

Suggested change
relax.const(-1.0, selected_scores.struct_info.dtype),
relax.const(-1.0, "float32"),

)
detection_scores, top_positions = relax.op.topk(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The relax.op.topk operator returns a relax.Call object, which is not iterable. Unpacking it directly will fail.

Suggested change
detection_scores, top_positions = relax.op.topk(
res_topk_scores = relax.op.topk(
masked_selected_scores, k=max_detections, axis=1, ret_type="both", largest=True
)
detection_scores = res_topk_scores[0]
top_positions = res_topk_scores[1]

masked_selected_scores, k=max_detections, axis=1, ret_type="both", largest=True
)
top_positions_expanded = relax.op.expand_dims(top_positions, axis=2)
top_positions_for_pairs = relax.op.repeat(top_positions_expanded, 2, axis=2)
top_index_pairs = relax.op.gather_elements(
selected_indices, top_positions_for_pairs, axis=1
)
top_box_ids = relax.op.squeeze(
relax.op.strided_slice(
top_index_pairs, begin=[0, 0, 1], end=[batch_size, max_detections, 2]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using batch_size in the end parameter of strided_slice can be problematic if the model has a dynamic batch size (where batch_size might be -1). It is safer to use the axes parameter to slice only the specific dimensions you need, leaving the batch dimension untouched.

Suggested change
top_index_pairs, begin=[0, 0, 1], end=[batch_size, max_detections, 2]
top_index_pairs, begin=[0, 1], end=[max_detections, 2], axes=[1, 2]

),
axis=[2],
)
top_box_ids_for_gather = relax.op.expand_dims(relax.op.astype(top_box_ids, "int64"), axis=2)
detection_boxes = relax.op.gather_nd(
transformed_boxes, top_box_ids_for_gather, batch_dims=1
)

if use_regular_nms:
detection_classes = relax.op.squeeze(
relax.op.strided_slice(
top_index_pairs, begin=[0, 0, 0], end=[batch_size, max_detections, 1]
),
axis=[2],
)
else:
top_box_ids_for_class = relax.op.expand_dims(
relax.op.astype(top_box_ids, "int64"), axis=2
)
detection_classes = relax.op.gather_nd(
class_id_from_score, top_box_ids_for_class, batch_dims=1
)

detection_mask = relax.op.expand_dims(valid_detection_mask, axis=2)
detection_boxes = relax.op.where(
detection_mask,
detection_boxes,
relax.op.zeros(
(batch_size, max_detections, 4), dtype=detection_boxes.struct_info.dtype
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Accessing detection_boxes.struct_info.dtype will fail as struct_info is None at this stage. Use "float32" instead.

Suggested change
(batch_size, max_detections, 4), dtype=detection_boxes.struct_info.dtype
(batch_size, max_detections, 4), dtype="float32"

),
)
detection_classes = relax.op.where(
valid_detection_mask,
detection_classes,
relax.op.zeros((batch_size, max_detections), dtype=detection_classes.struct_info.dtype),
)
ret = relax.op.vision.get_valid_counts(ret, 0)
valid_count = ret[0]
# keep only the top 'max_detections' rows
ret = relax.op.strided_slice(
ret[1], [0, 0, 0], [batch_size, custom_options["max_detections"], 6]
detection_scores = relax.op.where(
valid_detection_mask,
detection_scores,
relax.op.zeros((batch_size, max_detections), dtype=detection_scores.struct_info.dtype),
)
# the output needs some reshaping to match tflite
ret = relax.op.split(ret, 6, axis=2)
cls_ids = relax.op.reshape(ret[0], [batch_size, -1])
scores = relax.op.reshape(ret[1], [batch_size, -1])
boxes = relax.op.concat([ret[3], ret[2], ret[5], ret[4]], axis=2)
ret = relax.Tuple(relax.Tuple([boxes, cls_ids, scores, valid_count]), size=4)
return ret
detection_classes = relax.op.astype(detection_classes, "float32")
num_detections = relax.op.astype(num_detections, "float32")
return relax.Tuple([detection_boxes, detection_classes, detection_scores, num_detections])

def convert_nms_v5(self, op):
"""Convert TFLite NonMaxSuppressionV5"""
Expand Down
Loading