From a5636bf3a3960cd6835dba946db6251bada5ca24 Mon Sep 17 00:00:00 2001 From: Chi Zhang Date: Tue, 24 Mar 2026 18:42:34 +0800 Subject: [PATCH 1/2] update Signed-off-by: Chi Zhang --- tests/e2e/test_kv_interface_e2e.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index 0a248ea7..f0e84b04 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -432,14 +432,20 @@ def test_kv_batch_get_partial_keys(self, controller): keys = ["get_multi_3", "get_multi_4", "get_multi_5"] partial_keys = ["get_multi_3", "get_multi_5"] input_data = torch.tensor([[1, 2], [3, 4], [5, 6]]) + nested_data = torch.nested.nested_tensor([[10, 11, 12], [20], [30, 31]]) expected_data = torch.tensor([[1, 2], [5, 6]]) + expected_nested_data = [torch.tensor([10, 11, 12]), torch.tensor([30, 31])] - fields = TensorDict({"data": input_data}, batch_size=3) + fields = TensorDict({"data": input_data, "nested_data": nested_data}, batch_size=3) tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}]) retrieved = tq.kv_batch_get(keys=partial_keys, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) + print(retrieved["nested_data"]) + for actual, expected in zip(retrieved["nested_data"], expected_nested_data, strict=True): + assert_tensor_equal(actual, expected) + tq.kv_clear(keys=keys, partition_id=partition_id) def test_kv_batch_get_partial_fields(self, controller): From ea93a451178668538eaccb55ef54f090b33bd4ac Mon Sep 17 00:00:00 2001 From: Chi Zhang Date: Wed, 25 Mar 2026 14:30:43 +0800 Subject: [PATCH 2/2] add nested tensor test Signed-off-by: Chi Zhang --- tests/e2e/test_kv_interface_e2e.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index f0e84b04..8760c321 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -433,19 +433,27 @@ def test_kv_batch_get_partial_keys(self, controller): partial_keys = ["get_multi_3", "get_multi_5"] input_data = torch.tensor([[1, 2], [3, 4], [5, 6]]) nested_data = torch.nested.nested_tensor([[10, 11, 12], [20], [30, 31]]) + three_d_nested_data = torch.nested.nested_tensor( + [[[10, 11], [12, 13]], [[20, 21], [22, 23]], [[30, 31], [32, 33]]] + ) + expected_three_d_nested_data = [torch.tensor([[10, 11], [12, 13]]), torch.tensor([[30, 31], [32, 33]])] expected_data = torch.tensor([[1, 2], [5, 6]]) expected_nested_data = [torch.tensor([10, 11, 12]), torch.tensor([30, 31])] - fields = TensorDict({"data": input_data, "nested_data": nested_data}, batch_size=3) + fields = TensorDict( + {"data": input_data, "nested_data": nested_data, "three_d_nested_data": three_d_nested_data}, batch_size=3 + ) tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}]) retrieved = tq.kv_batch_get(keys=partial_keys, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) - print(retrieved["nested_data"]) for actual, expected in zip(retrieved["nested_data"], expected_nested_data, strict=True): assert_tensor_equal(actual, expected) + for actual, expected in zip(retrieved["three_d_nested_data"], expected_three_d_nested_data, strict=True): + assert_tensor_equal(actual, expected) + tq.kv_clear(keys=keys, partition_id=partition_id) def test_kv_batch_get_partial_fields(self, controller):