diff --git a/tests/hw4/test_nd_backend.py b/tests/hw4/test_nd_backend.py
index c44d97f..4449308 100755
--- a/tests/hw4/test_nd_backend.py
+++ b/tests/hw4/test_nd_backend.py
@@ -169,6 +169,39 @@ def test_stack_backward(shape, axis, l, device):
     for i in range(l):
         np.testing.assert_allclose(A_t[i].grad.numpy(), A[i].grad.numpy(), atol=1e-5, rtol=1e-5)
 
+SPLIT_PARAMETERS = [
+    ((1, 5, 5), 0),
+    ((5, 5, 2), 2),
+    ((1, 5, 5, 7), 2)
+]
+@pytest.mark.parametrize("shape, axis", SPLIT_PARAMETERS)
+@pytest.mark.parametrize("device", _DEVICES, ids=["cpu", "cuda"])
+def test_split(shape, axis, device):
+    _A = np.random.randn(*shape).astype(np.float32)
+    A = ndl.Tensor(nd.array(_A), device=device)
+    A_t = torch.Tensor(_A)
+    out = ndl.split(A, axis=axis)
+    out_t = torch.split(A_t, 1, dim=axis)
+    for i in range(shape[axis]):
+        np.testing.assert_allclose(out_t[i].squeeze(axis).numpy(), out[i].numpy(), atol=1e-5, rtol=1e-5)
+
+
+@pytest.mark.parametrize("shape, axis", SPLIT_PARAMETERS)
+@pytest.mark.parametrize("device", _DEVICES, ids=["cpu", "cuda"])
+def test_split_backward(shape, axis, device):
+    _A = np.random.randn(*shape).astype(np.float32)
+    A = ndl.Tensor(nd.array(_A), device=device)
+    A_t = torch.Tensor(_A)
+
+    A_t.requires_grad = True
+
+    out = ndl.split(A, axis=axis)
+    out_t = torch.split(A_t, 1, dim=axis)
+
+    sum((t.sum() for t in out)).backward()
+    sum((t.sum() for t in out_t)).backward()
+
+    np.testing.assert_allclose(A_t.grad.numpy(), A.grad.numpy(), atol=1e-5, rtol=1e-5)
 
 SUMMATION_PARAMETERS = [((1, 1, 1), None),
     ((5, 3), 0),