From 78276311f77e442eb28b5abc2ade995c8c2e82c3 Mon Sep 17 00:00:00 2001 From: mathleur Date: Mon, 29 Sep 2025 09:48:26 +0200 Subject: [PATCH 1/3] remove_branch function --- src/python/qubed/Qube.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py index de3e826..4c17407 100644 --- a/src/python/qubed/Qube.py +++ b/src/python/qubed/Qube.py @@ -512,6 +512,25 @@ def hash_node(node: Qube) -> int: return hash_node(self) + def remove_branch(self, b: Qube) -> Qube: + for c in self.children: + if c.key != b.key: + c.remove_branch(b) + + # We have c.key = b.key, so we take the difference of the two Qubes + new_children = [] + for c in self.children: + if c.key == b.key: + new_c = set_operations.set_operation( + c, b, set_operations.SetOperation.DIFFERENCE, type(self) + ) + new_children.append(new_c) + else: + new_children.append(c) + + # and replace the children with the resulting Qube + return self.replace(children=tuple(sorted(new_children))) + def compress(self) -> Qube: """ This method is quite computationally heavy because of trees like this: @@ -559,7 +578,7 @@ def compare_metadata(self, B: Qube) -> bool: return False for k in self.metadata.keys(): if k not in B.metadata: - print(f"'{k}' not in {B.metadata.keys() = }") + print(f"'{k}' not in {B.metadata.keys()=}") return False if not np.array_equal(self.metadata[k], B.metadata[k]): print(f"self.metadata[{k}] != B.metadata.[{k}]") From 35eb6664b0ae9fe5b44361cee6374f100903a1fe Mon Sep 17 00:00:00 2001 From: mathleur Date: Mon, 29 Sep 2025 16:28:30 +0200 Subject: [PATCH 2/3] WIP: working remove_branch function --- src/python/qubed/Qube.py | 19 ++++--- src/python/qubed/set_operations.py | 12 ++--- tests/test_remove_branch.py | 83 ++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 16 deletions(-) create mode 100644 tests/test_remove_branch.py diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py index 4c17407..c5703f8 100644 --- a/src/python/qubed/Qube.py +++ b/src/python/qubed/Qube.py @@ -513,22 +513,21 @@ def hash_node(node: Qube) -> int: return hash_node(self) def remove_branch(self, b: Qube) -> Qube: - for c in self.children: - if c.key != b.key: - c.remove_branch(b) + b_key = b.children[0].key - # We have c.key = b.key, so we take the difference of the two Qubes new_children = [] for c in self.children: - if c.key == b.key: + if c.key == b_key: + update_c = type(self).make_root(children=(c,), update_depth=False) new_c = set_operations.set_operation( - c, b, set_operations.SetOperation.DIFFERENCE, type(self) + update_c, b, set_operations.SetOperation.DIFFERENCE, type(self) ) - new_children.append(new_c) + if len(new_c.children) != 0: + new_children.extend(new_c.children) else: - new_children.append(c) - - # and replace the children with the resulting Qube + c = c.remove_branch(b) + if len(c.children) != 0: + new_children.append(c) return self.replace(children=tuple(sorted(new_children))) def compress(self) -> Qube: diff --git a/src/python/qubed/set_operations.py b/src/python/qubed/set_operations.py index 2a4e7e6..d06f37e 100644 --- a/src/python/qubed/set_operations.py +++ b/src/python/qubed/set_operations.py @@ -334,7 +334,7 @@ def set_operation( assert A.key == B.key assert A.type == B.type assert A.values == B.values - assert A.depth == B.depth + # assert A.depth == B.depth new_children: list[Qube] = [] @@ -482,7 +482,7 @@ def make_new_node(source: Qube, values_indices: ValuesIndices): continue else: raise ValueError( - f"Only one of set_ops_result.intersection_A and set_ops_result.intersection_B is None, I didn't think that could happen! {set_ops_result = }" + f"Only one of set_ops_result.intersection_A and set_ops_result.intersection_B is None, I didn't think that could happen! {set_ops_result=}" ) if keep_only_A: @@ -558,7 +558,7 @@ def merge_values(qubes: list[Qube]) -> Qube: axis = example.depth if DEBUG: - print(f"{pad()}merge_values --- {axis = }") + print(f"{pad()}merge_values --- {axis=}") for i, qube in enumerate(qubes): qube.display(f"{pad()}in_{i}") @@ -694,7 +694,7 @@ def concat_metadata( example = qubes[0] if DEBUG: - print(f"concat_metadata --- {axis = }, qubes:") + print(f"concat_metadata --- {axis=}, qubes:") for qube in qubes: qube.display() @@ -747,8 +747,8 @@ def shallow_concat_metadata( if DEBUG: print("shallow_concat_metadata") - print(f"{concatenation_axis = }") - print(f"{sorting_indices = }") + print(f"{concatenation_axis=}") + print(f"{sorting_indices=}") for k, metadata_group in metadata_groups.items(): print(k, [m.shape for m in metadata_group]) diff --git a/tests/test_remove_branch.py b/tests/test_remove_branch.py new file mode 100644 index 0000000..6d4d19d --- /dev/null +++ b/tests/test_remove_branch.py @@ -0,0 +1,83 @@ +from qubed import Qube + + +def test_remove_branch(): + a = Qube.from_tree(""" + root + ├── class=od, expver=0001/0002, param=1/2 + └── class=rd + ├── expver=0001, param=1/2/3 + └── expver=0002, param=1/2 + """) + + b = Qube.from_tree(""" + root + ├── class=od, expver=0001/0002, param=1/2 + """) + + c = Qube.from_tree(""" + root + └── class=rd + ├── expver=0001, param=1/2/3 + └── expver=0002, param=1/2 + """) + + print(a.remove_branch(b)) + print("AND") + print(c) + + assert a.remove_branch(b) == c + + +def test_2(): + a = Qube.from_tree(""" + root + ├── class=od, expver=0001/0002, param=1/2 + └── class=rd + ├── expver=0001, param=1/2/3 + └── expver=0002, param=1/2 + """) + + b = Qube.from_tree(""" + root + └── expver=0001/0002, param=1/2 + """) + + c = Qube.from_tree(""" + root + └── class=rd + ├── expver=0001, param=3 + """) + + print(a.remove_branch(b)) + print(c) + + assert a.remove_branch(b) == c + + +def test_3(): + a = Qube.from_tree(""" + root + ├── class=od, expver=0001/0002, param=1/2 + └── class=rd + ├── expver=0001, param=1/2/3 + └── expver=0002, param=1/2 + """) + + b = Qube.from_tree(""" + root + └── expver=0001, param=1/2 + """) + + c = Qube.from_tree(""" + root + ├── class=od, expver=0002, param=1/2 + └── class=rd + ├── expver=0001, param=3 + └── expver=0002, param=1/2 + """) + + print(a.remove_branch(b)) + print(c) + + assert a.remove_branch(b) == c From 2fee3f3b6c3fb33c756fb926091fc6bd35c121a0 Mon Sep 17 00:00:00 2001 From: mathleur Date: Mon, 29 Sep 2025 16:41:43 +0200 Subject: [PATCH 3/3] add option to check depth in tree ops --- src/python/qubed/Qube.py | 6 +++++- src/python/qubed/set_operations.py | 11 ++++++++--- tests/test_remove_branch.py | 10 ---------- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py index c5703f8..b998275 100644 --- a/src/python/qubed/Qube.py +++ b/src/python/qubed/Qube.py @@ -520,7 +520,11 @@ def remove_branch(self, b: Qube) -> Qube: if c.key == b_key: update_c = type(self).make_root(children=(c,), update_depth=False) new_c = set_operations.set_operation( - update_c, b, set_operations.SetOperation.DIFFERENCE, type(self) + update_c, + b, + set_operations.SetOperation.DIFFERENCE, + type(self), + check_depth=False, ) if len(new_c.children) != 0: new_children.extend(new_c.children) diff --git a/src/python/qubed/set_operations.py b/src/python/qubed/set_operations.py index d06f37e..40b38ba 100644 --- a/src/python/qubed/set_operations.py +++ b/src/python/qubed/set_operations.py @@ -324,7 +324,7 @@ def added_axis(size: int, metadata: dict[str, np.ndarray]) -> dict[str, np.ndarr # @line_profiler.profile def set_operation( - A: Qube, B: Qube, operation_type: SetOperation, node_type, depth=0 + A: Qube, B: Qube, operation_type: SetOperation, node_type, depth=0, check_depth=True ) -> Qube | None: if DEBUG: print(f"{pad()}operation({operation_type.name}, depth={depth})") @@ -334,7 +334,8 @@ def set_operation( assert A.key == B.key assert A.type == B.type assert A.values == B.values - # assert A.depth == B.depth + if check_depth: + assert A.depth == B.depth new_children: list[Qube] = [] @@ -347,7 +348,9 @@ def set_operation( # For every node group, perform the set operation for A_nodes, B_nodes in nodes_by_key.values(): output = list( - _set_operation(A_nodes, B_nodes, operation_type, node_type, depth + 1) + _set_operation( + A_nodes, B_nodes, operation_type, node_type, depth + 1, check_depth + ) ) new_children.extend(output) @@ -398,6 +401,7 @@ def _set_operation( operation_type: SetOperation, node_type, depth: int, + check_depth, ) -> Iterable[Qube]: """ This operation get called from `operation` when we've found two nodes that match and now need @@ -461,6 +465,7 @@ def make_new_node(source: Qube, values_indices: ValuesIndices): operation_type, node_type, depth=depth + 1, + check_depth=check_depth, ) if result is not None: # If we're doing a difference or xor we might want to throw away the intersection diff --git a/tests/test_remove_branch.py b/tests/test_remove_branch.py index 6d4d19d..4100508 100644 --- a/tests/test_remove_branch.py +++ b/tests/test_remove_branch.py @@ -22,10 +22,6 @@ def test_remove_branch(): └── expver=0002, param=1/2 """) - print(a.remove_branch(b)) - print("AND") - print(c) - assert a.remove_branch(b) == c @@ -49,9 +45,6 @@ def test_2(): ├── expver=0001, param=3 """) - print(a.remove_branch(b)) - print(c) - assert a.remove_branch(b) == c @@ -77,7 +70,4 @@ def test_3(): └── expver=0002, param=1/2 """) - print(a.remove_branch(b)) - print(c) - assert a.remove_branch(b) == c