diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index dd43c9030b75..05ca3e42703a 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -40,7 +40,11 @@ def _get_mod_from_cfunc(cfunc): tf.lite.OpsSet.SELECT_TF_OPS, ] - tflite_model = tflite.Model.Model.GetRootAsModel(converter.convert(), 0) + tflite_model_buf = converter.convert() + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) mod = from_tflite(tflite_model) mod["main"] = mod["main"].without_attr("params") return mod