diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index e7d81cf5fea0..f1f91002c4eb 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -700,22 +700,12 @@ def main(x: R.Tensor((5, 30), dtype="float32")) -> R.Tensor(out_shape, dtype="in verify(TfInput, Expected) -@pytest.mark.parametrize( - "data, kernel, data_format, strides, padding", - [ - # Tf on CI (CPU) support only NHWC - ((1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "SAME"), - ((1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "VALID"), - # ((1, 32, 128, 128), (3, 3, 32, 32), "NCHW", (1, 1, 1, 1), "SAME"), - # ((1, 32, 128, 128), (3, 3, 32, 32), "NCHW", (1, 1, 1, 1), "VALID"), - ], -) -def test_conv2d(data, kernel, data_format, strides, padding): +def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding): class Conv2DModule(tf.Module): @tf.function( input_signature=[ - tf.TensorSpec(shape=data, dtype=tf.float32), - tf.TensorSpec(shape=kernel, dtype=tf.float32), + tf.TensorSpec(shape=data_shape, dtype=tf.float32), + tf.TensorSpec(shape=kernel_shape, dtype=tf.float32), ] ) def func(self, data, kernel): @@ -727,39 +717,180 @@ def func(self, data, kernel): padding=padding, ) - verify(Conv2DModule) + return Conv2DModule -@pytest.mark.parametrize( - "pool", - [tf.nn.avg_pool2d, tf.nn.max_pool2d], -) -@pytest.mark.parametrize( - "data, kernel, data_format, strides, padding", - [ - # Tf on CI (CPU) support only NHWC - ((1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME"), - ((1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID"), - # ((1, 32, 128, 128), (3, 3), "NCHW", (1, 1, 1, 1), "SAME"), - # ((1, 32, 128, 128), (3, 3), "NCHW", (1, 1, 1, 1), "VALID"), - ], -) -def test_pool_2d(pool, data, kernel, data_format, strides, padding): +def test_conv2d_same(): + Conv2DModule = _make_conv2d_module( + (1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "SAME" + ) + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 128, 128, 32), dtype="float32"), + kernel: R.Tensor((3, 3, 32, 32), dtype="float32"), + ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims( + kernel, axes=[3, 0, 1, 2] + ) + lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims( + lv, axes=[1, 2, 3, 0] + ) + lv2: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.conv2d( + data, + lv1, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.add( + lv2, R.const(np.zeros((32,), dtype="float32")) + ) + R.output(gv) + return gv + + verify(Conv2DModule, Expected) + + +def test_conv2d_valid(): + Conv2DModule = _make_conv2d_module( + (1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "VALID" + ) + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 128, 128, 32), dtype="float32"), + kernel: R.Tensor((3, 3, 32, 32), dtype="float32"), + ) -> R.Tensor((1, 126, 126, 32), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims( + kernel, axes=[3, 0, 1, 2] + ) + lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims( + lv, axes=[1, 2, 3, 0] + ) + lv2: R.Tensor((1, 126, 126, 32), dtype="float32") = R.nn.conv2d( + data, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + gv: R.Tensor((1, 126, 126, 32), dtype="float32") = R.add( + lv2, R.const(np.zeros((32,), dtype="float32")) + ) + R.output(gv) + return gv + + verify(Conv2DModule, Expected) + + +def _make_pool2d_module(pool, data_shape, ksize, data_format, strides, padding): class Pool2DModule(tf.Module): @tf.function( input_signature=[ - tf.TensorSpec(shape=data, dtype=tf.float32), + tf.TensorSpec(shape=data_shape, dtype=tf.float32), ] ) def func(self, data): return pool( input=data, - ksize=kernel, + ksize=ksize, data_format=data_format, strides=strides, padding=padding, ) + return Pool2DModule + + +def test_avg_pool2d_same(): + Pool2DModule = _make_pool2d_module( + tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME" + ) + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 128, 128, 32), dtype="float32"), + ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.avg_pool2d( + data, + pool_size=[2, 2], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 1, 1], + ceil_mode=False, + count_include_pad=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(gv) + return gv + + verify(Pool2DModule, Expected) + + +def test_avg_pool2d_valid(): + Pool2DModule = _make_pool2d_module( + tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID" + ) + verify(Pool2DModule) + + +def test_max_pool2d_same(): + Pool2DModule = _make_pool2d_module( + tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME" + ) + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 128, 128, 32), dtype="float32"), + ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.max_pool2d( + data, + pool_size=[2, 2], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(gv) + return gv + + verify(Pool2DModule, Expected) + + +def test_max_pool2d_valid(): + Pool2DModule = _make_pool2d_module( + tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID" + ) verify(Pool2DModule) @@ -836,7 +967,21 @@ class BatchMatMul(tf.Module): def func(self, x, y): return tf.matmul(x, y) - verify(BatchMatMul) + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 4), dtype="float32"), + y: R.Tensor((2, 4, 5), dtype="float32"), + ) -> R.Tensor((2, 3, 5), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2, 3, 5), dtype="float32") = R.matmul(x, y, out_dtype="void") + gv: R.Tensor((2, 3, 5), dtype="float32") = R.reshape(lv, R.shape([2, 3, 5])) + R.output(gv) + return gv + + verify(BatchMatMul, Expected) def test_batch_matmul_adj(): @@ -850,7 +995,23 @@ class BatchMatMulAdj(tf.Module): def func(self, x, y): return tf.matmul(x, y, transpose_a=True, transpose_b=True) - verify(BatchMatMulAdj) + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 4, 3), dtype="float32"), + y: R.Tensor((2, 5, 4), dtype="float32"), + ) -> R.Tensor((2, 3, 5), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2, 3, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 1]) + lv1: R.Tensor((2, 4, 5), dtype="float32") = R.permute_dims(y, axes=[0, 2, 1]) + lv2: R.Tensor((2, 3, 5), dtype="float32") = R.matmul(lv, lv1, out_dtype="void") + gv: R.Tensor((2, 3, 5), dtype="float32") = R.reshape(lv2, R.shape([2, 3, 5])) + R.output(gv) + return gv + + verify(BatchMatMulAdj, Expected) if __name__ == "__main__":