diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 435180dfee01..ef16311ef9f8 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -2826,9 +2826,7 @@ def convert_batch_matmul(self, op): new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in shape_b] max_rank = max(rank_a, rank_b) - batch_shape = [ - max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2) - ] + batch_shape = [max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)] a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])] b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])] @@ -3204,16 +3202,49 @@ def convert_dequantize(self, op): def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" - raise NotImplementedError( - "DETECTION_POSTPROCESS is not wired in this frontend yet: it still needs " - "Relax NMS / get_valid_counts / related vision helpers (see dead code below). " - "relax.vision.multibox_transform_loc exists; tracking: " - "https://github.com/apache/tvm/issues/18928" - ) flexbuffer = op.CustomOptionsAsNumpy().tobytes() custom_options = FlexBufferDecoder(flexbuffer).decode() - use_regular_nms = "use_regular_nms" in custom_options and custom_options["use_regular_nms"] + use_regular_nms = bool(custom_options.get("use_regular_nms", False)) + + required_attrs = [ + "num_classes", + "max_detections", + "detections_per_class", + "nms_iou_threshold", + "nms_score_threshold", + "x_scale", + "y_scale", + "w_scale", + "h_scale", + ] + missing_attrs = [key for key in required_attrs if key not in custom_options] + if missing_attrs: + raise ValueError( + "DETECTION_POSTPROCESS custom options miss required attributes: " + + ", ".join(missing_attrs) + ) + + num_classes = int(custom_options["num_classes"]) + max_detections = int(custom_options["max_detections"]) + detections_per_class = int(custom_options["detections_per_class"]) + iou_threshold = float(custom_options["nms_iou_threshold"]) + score_threshold = float(custom_options["nms_score_threshold"]) + x_scale = float(custom_options["x_scale"]) + y_scale = float(custom_options["y_scale"]) + w_scale = float(custom_options["w_scale"]) + h_scale = float(custom_options["h_scale"]) + + if num_classes <= 0: + raise ValueError("DETECTION_POSTPROCESS requires num_classes > 0.") + if max_detections <= 0: + raise ValueError("DETECTION_POSTPROCESS requires max_detections > 0.") + if detections_per_class <= 0: + raise ValueError("DETECTION_POSTPROCESS requires detections_per_class > 0.") + if not 0.0 <= iou_threshold <= 1.0: + raise ValueError("DETECTION_POSTPROCESS requires nms_iou_threshold in [0, 1].") + if x_scale <= 0.0 or y_scale <= 0.0 or w_scale <= 0.0 or h_scale <= 0.0: + raise ValueError("DETECTION_POSTPROCESS requires x/y/w/h_scale to be > 0.") inputs = self.get_input_tensors(op) assert len(inputs) == 3, "inputs length should be 3" @@ -3275,18 +3306,16 @@ def convert_detection_postprocess(self, op): # attributes for multibox_transform_loc multibox_transform_loc_attrs = {} multibox_transform_loc_attrs["clip"] = False - multibox_transform_loc_attrs["threshold"] = ( - 0.0 if use_regular_nms else custom_options["nms_score_threshold"] - ) + multibox_transform_loc_attrs["threshold"] = 0.0 if use_regular_nms else score_threshold multibox_transform_loc_attrs["variances"] = ( - 1 / custom_options["x_scale"], - 1 / custom_options["y_scale"], - 1 / custom_options["w_scale"], - 1 / custom_options["h_scale"], + 1 / x_scale, + 1 / y_scale, + 1 / w_scale, + 1 / h_scale, ) multibox_transform_loc_attrs["keep_background"] = use_regular_nms - ret = relax.op.vision.multibox_transform_loc( + transformed_boxes, transformed_scores = relax.op.vision.multibox_transform_loc( # reshape cls_pred so it can be consumed by # multibox_transform_loc relax.op.permute_dims(cls_pred, [0, 2, 1]), @@ -3296,46 +3325,104 @@ def convert_detection_postprocess(self, op): ) if use_regular_nms: - # box coordinates need to be converted from ltrb to (ymin, xmin, ymax, xmax) - _, transformed_boxes = relax.op.split(ret[0], (2,), axis=2) - box_l, box_t, box_r, box_b = relax.op.split(transformed_boxes, 4, axis=2) - transformed_boxes = relax.op.concat([box_t, box_l, box_b, box_r], axis=2) - - return relax.op.vision.regular_non_max_suppression( - boxes=transformed_boxes, - scores=cls_pred, - max_detections_per_class=custom_options["detections_per_class"], - max_detections=custom_options["max_detections"], - num_classes=custom_options["num_classes"], - iou_threshold=custom_options["nms_iou_threshold"], - score_threshold=custom_options["nms_score_threshold"], - ) - - # attributes for non_max_suppression - non_max_suppression_attrs = {} - non_max_suppression_attrs["return_indices"] = False - non_max_suppression_attrs["iou_threshold"] = custom_options["nms_iou_threshold"] - non_max_suppression_attrs["force_suppress"] = True - non_max_suppression_attrs["top_k"] = anchor_boxes - non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"] - non_max_suppression_attrs["invalid_to_bottom"] = False - - ret = relax.op.vision.non_max_suppression( - ret[0], ret[1], ret[1], **non_max_suppression_attrs + nms_out = relax.op.vision.all_class_non_max_suppression( + transformed_boxes, + transformed_scores, + relax.const(detections_per_class, "int64"), + relax.const(iou_threshold, "float32"), + relax.const(score_threshold, "float32"), + output_format="tensorflow", + ) + selected_indices = nms_out[0] + selected_scores = nms_out[1] + num_detections = nms_out[2] + class_id_from_score = None + else: + max_scores, class_id_from_score = relax.op.topk( + transformed_scores, k=1, axis=1, ret_type="both", largest=True + ) + nms_out = relax.op.vision.all_class_non_max_suppression( + transformed_boxes, + max_scores, + relax.const(max_detections, "int64"), + relax.const(iou_threshold, "float32"), + relax.const(score_threshold, "float32"), + output_format="tensorflow", + ) + selected_indices = nms_out[0] + selected_scores = nms_out[1] + num_detections = nms_out[2] + class_id_from_score = relax.op.squeeze(class_id_from_score, axis=[1]) + + num_detections = relax.op.minimum( + num_detections, relax.const(np.array([max_detections], dtype="int64")) + ) + detection_positions = relax.op.expand_dims( + relax.op.arange(max_detections, dtype="int64"), axis=0 + ) + valid_detection_mask = relax.op.less( + detection_positions, relax.op.expand_dims(num_detections, axis=1) + ) + masked_selected_scores = relax.op.where( + valid_detection_mask, + selected_scores, + relax.const(-1.0, selected_scores.struct_info.dtype), + ) + detection_scores, top_positions = relax.op.topk( + masked_selected_scores, k=max_detections, axis=1, ret_type="both", largest=True + ) + top_positions_expanded = relax.op.expand_dims(top_positions, axis=2) + top_positions_for_pairs = relax.op.repeat(top_positions_expanded, 2, axis=2) + top_index_pairs = relax.op.gather_elements( + selected_indices, top_positions_for_pairs, axis=1 + ) + top_box_ids = relax.op.squeeze( + relax.op.strided_slice( + top_index_pairs, begin=[0, 0, 1], end=[batch_size, max_detections, 2] + ), + axis=[2], + ) + top_box_ids_for_gather = relax.op.expand_dims(relax.op.astype(top_box_ids, "int64"), axis=2) + detection_boxes = relax.op.gather_nd( + transformed_boxes, top_box_ids_for_gather, batch_dims=1 + ) + + if use_regular_nms: + detection_classes = relax.op.squeeze( + relax.op.strided_slice( + top_index_pairs, begin=[0, 0, 0], end=[batch_size, max_detections, 1] + ), + axis=[2], + ) + else: + top_box_ids_for_class = relax.op.expand_dims( + relax.op.astype(top_box_ids, "int64"), axis=2 + ) + detection_classes = relax.op.gather_nd( + class_id_from_score, top_box_ids_for_class, batch_dims=1 + ) + + detection_mask = relax.op.expand_dims(valid_detection_mask, axis=2) + detection_boxes = relax.op.where( + detection_mask, + detection_boxes, + relax.op.zeros( + (batch_size, max_detections, 4), dtype=detection_boxes.struct_info.dtype + ), + ) + detection_classes = relax.op.where( + valid_detection_mask, + detection_classes, + relax.op.zeros((batch_size, max_detections), dtype=detection_classes.struct_info.dtype), ) - ret = relax.op.vision.get_valid_counts(ret, 0) - valid_count = ret[0] - # keep only the top 'max_detections' rows - ret = relax.op.strided_slice( - ret[1], [0, 0, 0], [batch_size, custom_options["max_detections"], 6] + detection_scores = relax.op.where( + valid_detection_mask, + detection_scores, + relax.op.zeros((batch_size, max_detections), dtype=detection_scores.struct_info.dtype), ) - # the output needs some reshaping to match tflite - ret = relax.op.split(ret, 6, axis=2) - cls_ids = relax.op.reshape(ret[0], [batch_size, -1]) - scores = relax.op.reshape(ret[1], [batch_size, -1]) - boxes = relax.op.concat([ret[3], ret[2], ret[5], ret[4]], axis=2) - ret = relax.Tuple(relax.Tuple([boxes, cls_ids, scores, valid_count]), size=4) - return ret + detection_classes = relax.op.astype(detection_classes, "float32") + num_detections = relax.op.astype(num_detections, "float32") + return relax.Tuple([detection_boxes, detection_classes, detection_scores, num_detections]) def convert_nms_v5(self, op): """Convert TFLite NonMaxSuppressionV5"""