From 233c87294b707da2784bdf3ad31b263f8292f569 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Thu, 2 Apr 2026 23:44:21 -0400 Subject: [PATCH 1/4] finish1 --- tests/python/relax/test_frontend_tflite.py | 163 ++++++++++++++++++++- 1 file changed, 159 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index e7d81cf5fea0..b57e20d0f99b 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -727,7 +727,77 @@ def func(self, data, kernel): padding=padding, ) - verify(Conv2DModule) + if padding == "SAME": + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 128, 128, 32), dtype="float32"), + kernel: R.Tensor((3, 3, 32, 32), dtype="float32"), + ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims( + kernel, axes=[3, 0, 1, 2] + ) + lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims( + lv, axes=[1, 2, 3, 0] + ) + lv2: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.conv2d( + data, + lv1, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.add( + lv2, R.const(np.zeros((32,), dtype="float32")) + ) + R.output(gv) + return gv + + else: + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 128, 128, 32), dtype="float32"), + kernel: R.Tensor((3, 3, 32, 32), dtype="float32"), + ) -> R.Tensor((1, 126, 126, 32), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims( + kernel, axes=[3, 0, 1, 2] + ) + lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims( + lv, axes=[1, 2, 3, 0] + ) + lv2: R.Tensor((1, 126, 126, 32), dtype="float32") = R.nn.conv2d( + data, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + gv: R.Tensor((1, 126, 126, 32), dtype="float32") = R.add( + lv2, R.const(np.zeros((32,), dtype="float32")) + ) + R.output(gv) + return gv + + verify(Conv2DModule, Expected) @pytest.mark.parametrize( @@ -760,7 +830,62 @@ def func(self, data): padding=padding, ) - verify(Pool2DModule) + is_avg = pool == tf.nn.avg_pool2d + if padding == "SAME": + out_shape = (1, 128, 128, 32) + pad = [0, 0, 1, 1] + else: + out_shape = (1, 127, 127, 32) + pad = [0, 0, 0, 0] + + if is_avg: + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor(out_shape, dtype="float32"), + ) -> R.Tensor(out_shape, dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor(out_shape, dtype="float32") = R.nn.avg_pool2d( + data, + pool_size=[2, 2], + strides=[1, 1], + dilation=[1, 1], + padding=pad, + ceil_mode=False, + count_include_pad=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(gv) + return gv + + else: + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor(out_shape, dtype="float32"), + ) -> R.Tensor(out_shape, dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor(out_shape, dtype="float32") = R.nn.max_pool2d( + data, + pool_size=[2, 2], + strides=[1, 1], + dilation=[1, 1], + padding=pad, + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(gv) + return gv + + verify(Pool2DModule, Expected) @pytest.mark.parametrize( @@ -836,7 +961,21 @@ class BatchMatMul(tf.Module): def func(self, x, y): return tf.matmul(x, y) - verify(BatchMatMul) + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 4), dtype="float32"), + y: R.Tensor((2, 4, 5), dtype="float32"), + ) -> R.Tensor((2, 3, 5), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2, 3, 5), dtype="float32") = R.matmul(x, y, out_dtype="void") + gv: R.Tensor((2, 3, 5), dtype="float32") = R.reshape(lv, R.shape([2, 3, 5])) + R.output(gv) + return gv + + verify(BatchMatMul, Expected) def test_batch_matmul_adj(): @@ -850,7 +989,23 @@ class BatchMatMulAdj(tf.Module): def func(self, x, y): return tf.matmul(x, y, transpose_a=True, transpose_b=True) - verify(BatchMatMulAdj) + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 4, 3), dtype="float32"), + y: R.Tensor((2, 5, 4), dtype="float32"), + ) -> R.Tensor((2, 3, 5), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2, 3, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 1]) + lv1: R.Tensor((2, 4, 5), dtype="float32") = R.permute_dims(y, axes=[0, 2, 1]) + lv2: R.Tensor((2, 3, 5), dtype="float32") = R.matmul(lv, lv1, out_dtype="void") + gv: R.Tensor((2, 3, 5), dtype="float32") = R.reshape(lv2, R.shape([2, 3, 5])) + R.output(gv) + return gv + + verify(BatchMatMulAdj, Expected) if __name__ == "__main__": From 1ce3750ac83e74d073561d707232735a2775c7cd Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 3 Apr 2026 10:30:34 -0400 Subject: [PATCH 2/4] finish2 --- tests/python/relax/test_frontend_tflite.py | 70 +++++++++++++++++----- 1 file changed, 55 insertions(+), 15 deletions(-) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index b57e20d0f99b..e2cf1474c1ba 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -831,29 +831,46 @@ def func(self, data): ) is_avg = pool == tf.nn.avg_pool2d - if padding == "SAME": - out_shape = (1, 128, 128, 32) - pad = [0, 0, 1, 1] - else: - out_shape = (1, 127, 127, 32) - pad = [0, 0, 0, 0] + if is_avg and padding == "SAME": + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 128, 128, 32), dtype="float32"), + ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.avg_pool2d( + data, + pool_size=[2, 2], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 1, 1], + ceil_mode=False, + count_include_pad=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(gv) + return gv - if is_avg: + elif is_avg and padding == "VALID": @I.ir_module class Expected: @R.function def main( - data: R.Tensor(out_shape, dtype="float32"), - ) -> R.Tensor(out_shape, dtype="float32"): + data: R.Tensor((1, 127, 127, 32), dtype="float32"), + ) -> R.Tensor((1, 127, 127, 32), dtype="float32"): R.func_attr({"num_input": 1}) with R.dataflow(): - gv: R.Tensor(out_shape, dtype="float32") = R.nn.avg_pool2d( + gv: R.Tensor((1, 127, 127, 32), dtype="float32") = R.nn.avg_pool2d( data, pool_size=[2, 2], strides=[1, 1], dilation=[1, 1], - padding=pad, + padding=[0, 0, 0, 0], ceil_mode=False, count_include_pad=False, layout="NHWC", @@ -862,22 +879,45 @@ def main( R.output(gv) return gv + elif not is_avg and padding == "SAME": + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 128, 128, 32), dtype="float32"), + ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.max_pool2d( + data, + pool_size=[2, 2], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(gv) + return gv + else: @I.ir_module class Expected: @R.function def main( - data: R.Tensor(out_shape, dtype="float32"), - ) -> R.Tensor(out_shape, dtype="float32"): + data: R.Tensor((1, 127, 127, 32), dtype="float32"), + ) -> R.Tensor((1, 127, 127, 32), dtype="float32"): R.func_attr({"num_input": 1}) with R.dataflow(): - gv: R.Tensor(out_shape, dtype="float32") = R.nn.max_pool2d( + gv: R.Tensor((1, 127, 127, 32), dtype="float32") = R.nn.max_pool2d( data, pool_size=[2, 2], strides=[1, 1], dilation=[1, 1], - padding=pad, + padding=[0, 0, 0, 0], ceil_mode=False, layout="NHWC", out_layout="NHWC", From c5f073d5c2cc81f8f8d4c592dca0e1222bb48ad6 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 3 Apr 2026 13:39:17 -0400 Subject: [PATCH 3/4] finish3 --- tests/python/relax/test_frontend_tflite.py | 24 ++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index e2cf1474c1ba..915d1f4f00f9 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -730,7 +730,7 @@ def func(self, data, kernel): if padding == "SAME": @I.ir_module - class Expected: + class ExpectedSame: @R.function def main( data: R.Tensor((1, 128, 128, 32), dtype="float32"), @@ -762,10 +762,12 @@ def main( R.output(gv) return gv + verify(Conv2DModule, ExpectedSame) + else: @I.ir_module - class Expected: + class ExpectedValid: @R.function def main( data: R.Tensor((1, 128, 128, 32), dtype="float32"), @@ -797,7 +799,7 @@ def main( R.output(gv) return gv - verify(Conv2DModule, Expected) + verify(Conv2DModule, ExpectedValid) @pytest.mark.parametrize( @@ -834,7 +836,7 @@ def func(self, data): if is_avg and padding == "SAME": @I.ir_module - class Expected: + class ExpectedAvgSame: @R.function def main( data: R.Tensor((1, 128, 128, 32), dtype="float32"), @@ -855,10 +857,12 @@ def main( R.output(gv) return gv + verify(Pool2DModule, ExpectedAvgSame) + elif is_avg and padding == "VALID": @I.ir_module - class Expected: + class ExpectedAvgValid: @R.function def main( data: R.Tensor((1, 127, 127, 32), dtype="float32"), @@ -879,10 +883,12 @@ def main( R.output(gv) return gv + verify(Pool2DModule, ExpectedAvgValid) + elif not is_avg and padding == "SAME": @I.ir_module - class Expected: + class ExpectedMaxSame: @R.function def main( data: R.Tensor((1, 128, 128, 32), dtype="float32"), @@ -902,10 +908,12 @@ def main( R.output(gv) return gv + verify(Pool2DModule, ExpectedMaxSame) + else: @I.ir_module - class Expected: + class ExpectedMaxValid: @R.function def main( data: R.Tensor((1, 127, 127, 32), dtype="float32"), @@ -925,7 +933,7 @@ def main( R.output(gv) return gv - verify(Pool2DModule, Expected) + verify(Pool2DModule, ExpectedMaxValid) @pytest.mark.parametrize( From 93b7120fce1ccae2f4243491564b54659dd62221 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 3 Apr 2026 15:45:25 -0400 Subject: [PATCH 4/4] finish4 --- tests/python/relax/test_frontend_tflite.py | 360 +++++++++------------ 1 file changed, 159 insertions(+), 201 deletions(-) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 915d1f4f00f9..f1f91002c4eb 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -700,22 +700,12 @@ def main(x: R.Tensor((5, 30), dtype="float32")) -> R.Tensor(out_shape, dtype="in verify(TfInput, Expected) -@pytest.mark.parametrize( - "data, kernel, data_format, strides, padding", - [ - # Tf on CI (CPU) support only NHWC - ((1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "SAME"), - ((1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "VALID"), - # ((1, 32, 128, 128), (3, 3, 32, 32), "NCHW", (1, 1, 1, 1), "SAME"), - # ((1, 32, 128, 128), (3, 3, 32, 32), "NCHW", (1, 1, 1, 1), "VALID"), - ], -) -def test_conv2d(data, kernel, data_format, strides, padding): +def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding): class Conv2DModule(tf.Module): @tf.function( input_signature=[ - tf.TensorSpec(shape=data, dtype=tf.float32), - tf.TensorSpec(shape=kernel, dtype=tf.float32), + tf.TensorSpec(shape=data_shape, dtype=tf.float32), + tf.TensorSpec(shape=kernel_shape, dtype=tf.float32), ] ) def func(self, data, kernel): @@ -727,213 +717,181 @@ def func(self, data, kernel): padding=padding, ) - if padding == "SAME": - - @I.ir_module - class ExpectedSame: - @R.function - def main( - data: R.Tensor((1, 128, 128, 32), dtype="float32"), - kernel: R.Tensor((3, 3, 32, 32), dtype="float32"), - ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): - R.func_attr({"num_input": 2}) - with R.dataflow(): - lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims( - kernel, axes=[3, 0, 1, 2] - ) - lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims( - lv, axes=[1, 2, 3, 0] - ) - lv2: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.conv2d( - data, - lv1, - strides=[1, 1], - padding=[1, 1, 1, 1], - dilation=[1, 1], - groups=1, - data_layout="NHWC", - kernel_layout="HWIO", - out_layout="NHWC", - out_dtype="void", - ) - gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.add( - lv2, R.const(np.zeros((32,), dtype="float32")) - ) - R.output(gv) - return gv - - verify(Conv2DModule, ExpectedSame) + return Conv2DModule - else: - @I.ir_module - class ExpectedValid: - @R.function - def main( - data: R.Tensor((1, 128, 128, 32), dtype="float32"), - kernel: R.Tensor((3, 3, 32, 32), dtype="float32"), - ) -> R.Tensor((1, 126, 126, 32), dtype="float32"): - R.func_attr({"num_input": 2}) - with R.dataflow(): - lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims( - kernel, axes=[3, 0, 1, 2] - ) - lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims( - lv, axes=[1, 2, 3, 0] - ) - lv2: R.Tensor((1, 126, 126, 32), dtype="float32") = R.nn.conv2d( - data, - lv1, - strides=[1, 1], - padding=[0, 0, 0, 0], - dilation=[1, 1], - groups=1, - data_layout="NHWC", - kernel_layout="HWIO", - out_layout="NHWC", - out_dtype="void", - ) - gv: R.Tensor((1, 126, 126, 32), dtype="float32") = R.add( - lv2, R.const(np.zeros((32,), dtype="float32")) - ) - R.output(gv) - return gv - - verify(Conv2DModule, ExpectedValid) +def test_conv2d_same(): + Conv2DModule = _make_conv2d_module( + (1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "SAME" + ) + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 128, 128, 32), dtype="float32"), + kernel: R.Tensor((3, 3, 32, 32), dtype="float32"), + ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims( + kernel, axes=[3, 0, 1, 2] + ) + lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims( + lv, axes=[1, 2, 3, 0] + ) + lv2: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.conv2d( + data, + lv1, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.add( + lv2, R.const(np.zeros((32,), dtype="float32")) + ) + R.output(gv) + return gv -@pytest.mark.parametrize( - "pool", - [tf.nn.avg_pool2d, tf.nn.max_pool2d], -) -@pytest.mark.parametrize( - "data, kernel, data_format, strides, padding", - [ - # Tf on CI (CPU) support only NHWC - ((1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME"), - ((1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID"), - # ((1, 32, 128, 128), (3, 3), "NCHW", (1, 1, 1, 1), "SAME"), - # ((1, 32, 128, 128), (3, 3), "NCHW", (1, 1, 1, 1), "VALID"), - ], -) -def test_pool_2d(pool, data, kernel, data_format, strides, padding): + verify(Conv2DModule, Expected) + + +def test_conv2d_valid(): + Conv2DModule = _make_conv2d_module( + (1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "VALID" + ) + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 128, 128, 32), dtype="float32"), + kernel: R.Tensor((3, 3, 32, 32), dtype="float32"), + ) -> R.Tensor((1, 126, 126, 32), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims( + kernel, axes=[3, 0, 1, 2] + ) + lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims( + lv, axes=[1, 2, 3, 0] + ) + lv2: R.Tensor((1, 126, 126, 32), dtype="float32") = R.nn.conv2d( + data, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + gv: R.Tensor((1, 126, 126, 32), dtype="float32") = R.add( + lv2, R.const(np.zeros((32,), dtype="float32")) + ) + R.output(gv) + return gv + + verify(Conv2DModule, Expected) + + +def _make_pool2d_module(pool, data_shape, ksize, data_format, strides, padding): class Pool2DModule(tf.Module): @tf.function( input_signature=[ - tf.TensorSpec(shape=data, dtype=tf.float32), + tf.TensorSpec(shape=data_shape, dtype=tf.float32), ] ) def func(self, data): return pool( input=data, - ksize=kernel, + ksize=ksize, data_format=data_format, strides=strides, padding=padding, ) - is_avg = pool == tf.nn.avg_pool2d - if is_avg and padding == "SAME": - - @I.ir_module - class ExpectedAvgSame: - @R.function - def main( - data: R.Tensor((1, 128, 128, 32), dtype="float32"), - ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): - R.func_attr({"num_input": 1}) - with R.dataflow(): - gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.avg_pool2d( - data, - pool_size=[2, 2], - strides=[1, 1], - dilation=[1, 1], - padding=[0, 0, 1, 1], - ceil_mode=False, - count_include_pad=False, - layout="NHWC", - out_layout="NHWC", - ) - R.output(gv) - return gv - - verify(Pool2DModule, ExpectedAvgSame) - - elif is_avg and padding == "VALID": - - @I.ir_module - class ExpectedAvgValid: - @R.function - def main( - data: R.Tensor((1, 127, 127, 32), dtype="float32"), - ) -> R.Tensor((1, 127, 127, 32), dtype="float32"): - R.func_attr({"num_input": 1}) - with R.dataflow(): - gv: R.Tensor((1, 127, 127, 32), dtype="float32") = R.nn.avg_pool2d( - data, - pool_size=[2, 2], - strides=[1, 1], - dilation=[1, 1], - padding=[0, 0, 0, 0], - ceil_mode=False, - count_include_pad=False, - layout="NHWC", - out_layout="NHWC", - ) - R.output(gv) - return gv - - verify(Pool2DModule, ExpectedAvgValid) - - elif not is_avg and padding == "SAME": - - @I.ir_module - class ExpectedMaxSame: - @R.function - def main( - data: R.Tensor((1, 128, 128, 32), dtype="float32"), - ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): - R.func_attr({"num_input": 1}) - with R.dataflow(): - gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.max_pool2d( - data, - pool_size=[2, 2], - strides=[1, 1], - dilation=[1, 1], - padding=[0, 0, 1, 1], - ceil_mode=False, - layout="NHWC", - out_layout="NHWC", - ) - R.output(gv) - return gv - - verify(Pool2DModule, ExpectedMaxSame) + return Pool2DModule + + +def test_avg_pool2d_same(): + Pool2DModule = _make_pool2d_module( + tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME" + ) + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 128, 128, 32), dtype="float32"), + ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.avg_pool2d( + data, + pool_size=[2, 2], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 1, 1], + ceil_mode=False, + count_include_pad=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(gv) + return gv + + verify(Pool2DModule, Expected) + + +def test_avg_pool2d_valid(): + Pool2DModule = _make_pool2d_module( + tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID" + ) + verify(Pool2DModule) + + +def test_max_pool2d_same(): + Pool2DModule = _make_pool2d_module( + tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME" + ) + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 128, 128, 32), dtype="float32"), + ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.max_pool2d( + data, + pool_size=[2, 2], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(gv) + return gv + + verify(Pool2DModule, Expected) - else: - @I.ir_module - class ExpectedMaxValid: - @R.function - def main( - data: R.Tensor((1, 127, 127, 32), dtype="float32"), - ) -> R.Tensor((1, 127, 127, 32), dtype="float32"): - R.func_attr({"num_input": 1}) - with R.dataflow(): - gv: R.Tensor((1, 127, 127, 32), dtype="float32") = R.nn.max_pool2d( - data, - pool_size=[2, 2], - strides=[1, 1], - dilation=[1, 1], - padding=[0, 0, 0, 0], - ceil_mode=False, - layout="NHWC", - out_layout="NHWC", - ) - R.output(gv) - return gv - - verify(Pool2DModule, ExpectedMaxValid) +def test_max_pool2d_valid(): + Pool2DModule = _make_pool2d_module( + tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID" + ) + verify(Pool2DModule) @pytest.mark.parametrize(