tests/python/relax/test_frontend_tflite.py: 194 additions & 33 deletions

@@ -700,22 +700,12 @@ def main(x: R.Tensor((5, 30), dtype="float32")) -> R.Tensor(out_shape, dtype="in
     verify(TfInput, Expected)
 
 
-@pytest.mark.parametrize(
-    "data, kernel, data_format, strides, padding",
-    [
-        # Tf on CI (CPU) support only NHWC
-        ((1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "SAME"),
-        ((1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "VALID"),
-        # ((1, 32, 128, 128), (3, 3, 32, 32), "NCHW", (1, 1, 1, 1), "SAME"),
-        # ((1, 32, 128, 128), (3, 3, 32, 32), "NCHW", (1, 1, 1, 1), "VALID"),
-    ],
-)
-def test_conv2d(data, kernel, data_format, strides, padding):
+def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding):
     class Conv2DModule(tf.Module):
         @tf.function(
             input_signature=[
-                tf.TensorSpec(shape=data, dtype=tf.float32),
-                tf.TensorSpec(shape=kernel, dtype=tf.float32),
+                tf.TensorSpec(shape=data_shape, dtype=tf.float32),
+                tf.TensorSpec(shape=kernel_shape, dtype=tf.float32),
             ]
         )
         def func(self, data, kernel):
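Note: the verify() helper these tests call is defined earlier in test_frontend_tflite.py and is outside this diff. Judging from its call sites, it takes a tf.Module subclass and an optional Expected IRModule, runs the module through the TFLite converter and the Relax frontend, and checks structural equality when an Expected module is given. A plausible sketch of that shape, not the actual helper; the from_tflite entry point below is a placeholder, not the frontend's confirmed API:

    def verify(module_class, expected=None):
        # Trace the tf.function with its declared input signature.
        mod = module_class()
        concrete = mod.func.get_concrete_function()
        # Convert to a TFLite flatbuffer, then import it into Relax.
        tflite_model = tf.lite.TFLiteConverter.from_concrete_functions([concrete], mod).convert()
        relax_mod = from_tflite(tflite_model)  # hypothetical frontend entry point
        if expected is not None:
            tvm.ir.assert_structural_equal(relax_mod, expected)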
@@ -727,39 +717,180 @@ def func(self, data, kernel):
                 padding=padding,
             )
 
-    verify(Conv2DModule)
+    return Conv2DModule
 
 
-@pytest.mark.parametrize(
-    "pool",
-    [tf.nn.avg_pool2d, tf.nn.max_pool2d],
-)
-@pytest.mark.parametrize(
-    "data, kernel, data_format, strides, padding",
-    [
-        # Tf on CI (CPU) support only NHWC
-        ((1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME"),
-        ((1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID"),
-        # ((1, 32, 128, 128), (3, 3), "NCHW", (1, 1, 1, 1), "SAME"),
-        # ((1, 32, 128, 128), (3, 3), "NCHW", (1, 1, 1, 1), "VALID"),
-    ],
-)
-def test_pool_2d(pool, data, kernel, data_format, strides, padding):
+def test_conv2d_same():
+    Conv2DModule = _make_conv2d_module(
+        (1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "SAME"
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+            kernel: R.Tensor((3, 3, 32, 32), dtype="float32"),
+        ) -> R.Tensor((1, 128, 128, 32), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims(
+                    kernel, axes=[3, 0, 1, 2]
+                )
+                lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims(
+                    lv, axes=[1, 2, 3, 0]
+                )
+                lv2: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.conv2d(
+                    data,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="HWIO",
+                    out_layout="NHWC",
+                    out_dtype="void",
+                )
+                gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.add(
+                    lv2, R.const(np.zeros((32,), dtype="float32"))
+                )
+                R.output(gv)
+            return gv
+
+    verify(Conv2DModule, Expected)
+
+
+def test_conv2d_valid():
+    Conv2DModule = _make_conv2d_module(
+        (1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "VALID"
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+            kernel: R.Tensor((3, 3, 32, 32), dtype="float32"),
+        ) -> R.Tensor((1, 126, 126, 32), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims(
+                    kernel, axes=[3, 0, 1, 2]
+                )
+                lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims(
+                    lv, axes=[1, 2, 3, 0]
+                )
+                lv2: R.Tensor((1, 126, 126, 32), dtype="float32") = R.nn.conv2d(
+                    data,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="HWIO",
+                    out_layout="NHWC",
+                    out_dtype="void",
+                )
+                gv: R.Tensor((1, 126, 126, 32), dtype="float32") = R.add(
+                    lv2, R.const(np.zeros((32,), dtype="float32"))
+                )
+                R.output(gv)
+            return gv
+
+    verify(Conv2DModule, Expected)
+
+
+def _make_pool2d_module(pool, data_shape, ksize, data_format, strides, padding):
     class Pool2DModule(tf.Module):
         @tf.function(
             input_signature=[
-                tf.TensorSpec(shape=data, dtype=tf.float32),
+                tf.TensorSpec(shape=data_shape, dtype=tf.float32),
             ]
         )
         def func(self, data):
             return pool(
                 input=data,
-                ksize=kernel,
+                ksize=ksize,
                 data_format=data_format,
                 strides=strides,
                 padding=padding,
             )
 
+    return Pool2DModule
+
+
+def test_avg_pool2d_same():
+    Pool2DModule = _make_pool2d_module(
+        tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME"
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+        ) -> R.Tensor((1, 128, 128, 32), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.avg_pool2d(
+                    data,
+                    pool_size=[2, 2],
+                    strides=[1, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 1, 1],
+                    ceil_mode=False,
+                    count_include_pad=False,
+                    layout="NHWC",
+                    out_layout="NHWC",
+                )
+                R.output(gv)
+            return gv
+
+    verify(Pool2DModule, Expected)
+
+
+def test_avg_pool2d_valid():
+    Pool2DModule = _make_pool2d_module(
+        tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID"
+    )
+    verify(Pool2DModule)
+
+
+def test_max_pool2d_same():
+    Pool2DModule = _make_pool2d_module(
+        tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME"
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+        ) -> R.Tensor((1, 128, 128, 32), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.max_pool2d(
+                    data,
+                    pool_size=[2, 2],
+                    strides=[1, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 1, 1],
+                    ceil_mode=False,
+                    layout="NHWC",
+                    out_layout="NHWC",
+                )
+                R.output(gv)
+            return gv
+
+    verify(Pool2DModule, Expected)
+
+
+def test_max_pool2d_valid():
+    Pool2DModule = _make_pool2d_module(
+        tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID"
+    )
+    verify(Pool2DModule)
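The constants in these Expected modules follow directly from TF's SAME/VALID rules at stride 1: a 3x3 SAME conv over 128x128 needs 2 padding rows/columns in total, split as [1, 1] per spatial dimension; VALID drops padding and shrinks the output to 126x126; and a 2x2 SAME pool needs only 1 row/column, which TF places entirely on the bottom/right, hence padding=[0, 0, 1, 1]. The two permute_dims in the conv cases round-trip the kernel from TF's HWIO layout through TFLite's OHWI storage order back to HWIO. A quick self-check of the arithmetic (helper names are mine, not from the test file):

    def same_total_pad(in_size, k, s=1):
        out = -(-in_size // s)  # ceil(in_size / s)
        return max((out - 1) * s + k - in_size, 0)

    def valid_out(in_size, k, s=1):
        return (in_size - k) // s + 1

    assert same_total_pad(128, 3) == 2  # conv SAME: padding=[1, 1, 1, 1]
    assert valid_out(128, 3) == 126     # conv VALID: 126x126 output
    assert same_total_pad(128, 2) == 1  # pool SAME: bottom/right only, padding=[0, 0, 1, 1]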


@@ -836,7 +967,21 @@ class BatchMatMul(tf.Module):
         def func(self, x, y):
             return tf.matmul(x, y)
 
-    verify(BatchMatMul)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 4), dtype="float32"),
+            y: R.Tensor((2, 4, 5), dtype="float32"),
+        ) -> R.Tensor((2, 3, 5), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 5), dtype="float32") = R.matmul(x, y, out_dtype="void")
+                gv: R.Tensor((2, 3, 5), dtype="float32") = R.reshape(lv, R.shape([2, 3, 5]))
+                R.output(gv)
+            return gv
+
+    verify(BatchMatMul, Expected)
 
 
 def test_batch_matmul_adj():
@@ -850,7 +995,23 @@ class BatchMatMulAdj(tf.Module):
         def func(self, x, y):
             return tf.matmul(x, y, transpose_a=True, transpose_b=True)
 
-    verify(BatchMatMulAdj)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 3), dtype="float32"),
+            y: R.Tensor((2, 5, 4), dtype="float32"),
+        ) -> R.Tensor((2, 3, 5), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 1])
+                lv1: R.Tensor((2, 4, 5), dtype="float32") = R.permute_dims(y, axes=[0, 2, 1])
+                lv2: R.Tensor((2, 3, 5), dtype="float32") = R.matmul(lv, lv1, out_dtype="void")
+                gv: R.Tensor((2, 3, 5), dtype="float32") = R.reshape(lv2, R.shape([2, 3, 5]))
+                R.output(gv)
+            return gv
+
+    verify(BatchMatMulAdj, Expected)
 
 
 if __name__ == "__main__":
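A note on the adjoint case in the last hunk: as the Expected module shows, the frontend does not carry matmul transpose flags through to Relax; tf.matmul(x, y, transpose_a=True, transpose_b=True) is lowered to explicit permute_dims on the last two axes followed by a plain batch matmul. A small numpy check of that equivalence (mine, not part of the test file):

    import numpy as np

    x = np.random.rand(2, 4, 3).astype("float32")
    y = np.random.rand(2, 5, 4).astype("float32")

    # Transposing the last two axes and multiplying matches the Expected IR:
    # (2, 4, 3) -> (2, 3, 4), (2, 5, 4) -> (2, 4, 5), product is (2, 3, 5).
    out = np.transpose(x, (0, 2, 1)) @ np.transpose(y, (0, 2, 1))
    assert out.shape == (2, 3, 5)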