def remove_branch(self, b: Qube) -> Qube:
    """Remove the branch described by ``b`` from this qube.

    The key of ``b``'s first child selects the level at which the removal
    is applied. Children of ``self`` are visited in turn: a child whose key
    equals that key is wrapped in a temporary root and replaced by the
    children of ``child DIFFERENCE b``; every other child is searched
    recursively, so ``b`` may start at a deeper level than ``self`` does.

    Args:
        b: Rooted qube to subtract. Only the key of its first child is
            used to decide where the difference is taken.

    Returns:
        ``self`` unchanged when ``b`` has no children, otherwise
        ``self.replace(children=...)`` holding the surviving children in
        sorted order.
    """
    # Nothing to remove; also avoids an IndexError on b.children[0].
    if not b.children:
        return self

    b_key = b.children[0].key

    new_children = []
    for c in self.children:
        if c.key == b_key:
            # Wrap the child in a root so it can be compared with the
            # rooted `b`. The two may sit at different depths, so the
            # depth assertion inside set_operation is switched off.
            update_c = type(self).make_root(children=(c,), update_depth=False)
            new_c = set_operations.set_operation(
                update_c,
                b,
                set_operations.SetOperation.DIFFERENCE,
                type(self),
                check_depth=False,
            )
            # set_operation is annotated `Qube | None`; treat None as
            # "nothing survived the difference".
            if new_c is not None and new_c.children:
                new_children.extend(new_c.children)
        else:
            pruned = c.remove_branch(b)
            # Drop a node only when the removal emptied it. A node that was
            # already a leaf has nothing below it to remove and must be
            # kept, otherwise a non-matching `b` would erase the whole tree.
            if pruned.children or not c.children:
                new_children.append(pruned)
    return self.replace(children=tuple(sorted(new_children)))
assert A.depth == B.depth new_children: list[Qube] = [] @@ -347,7 +348,9 @@ def set_operation( # For every node group, perform the set operation for A_nodes, B_nodes in nodes_by_key.values(): output = list( - _set_operation(A_nodes, B_nodes, operation_type, node_type, depth + 1) + _set_operation( + A_nodes, B_nodes, operation_type, node_type, depth + 1, check_depth + ) ) new_children.extend(output) @@ -398,6 +401,7 @@ def _set_operation( operation_type: SetOperation, node_type, depth: int, + check_depth, ) -> Iterable[Qube]: """ This operation get called from `operation` when we've found two nodes that match and now need @@ -461,6 +465,7 @@ def make_new_node(source: Qube, values_indices: ValuesIndices): operation_type, node_type, depth=depth + 1, + check_depth=check_depth, ) if result is not None: # If we're doing a difference or xor we might want to throw away the intersection @@ -482,7 +487,7 @@ def make_new_node(source: Qube, values_indices: ValuesIndices): continue else: raise ValueError( - f"Only one of set_ops_result.intersection_A and set_ops_result.intersection_B is None, I didn't think that could happen! {set_ops_result = }" + f"Only one of set_ops_result.intersection_A and set_ops_result.intersection_B is None, I didn't think that could happen! 
from qubed import Qube


def test_remove_branch():
    """Removing a top-level ``class=od`` branch leaves only ``class=rd``."""
    a = Qube.from_tree("""
    root
    ├── class=od, expver=0001/0002, param=1/2
    └── class=rd
        ├── expver=0001, param=1/2/3
        └── expver=0002, param=1/2
    """)

    # Branch to remove: starts at the same level (class) as `a`.
    b = Qube.from_tree("""
    root
    ├── class=od, expver=0001/0002, param=1/2
    """)

    # Expected result: the class=rd subtree, untouched.
    c = Qube.from_tree("""
    root
    └── class=rd
        ├── expver=0001, param=1/2/3
        └── expver=0002, param=1/2
    """)

    assert a.remove_branch(b) == c


def test_2():
    """Remove a branch that starts one level down (at ``expver``).

    The branch is subtracted under every ``class`` value, so only
    ``class=rd, expver=0001, param=3`` is expected to remain.
    """
    a = Qube.from_tree("""
    root
    ├── class=od, expver=0001/0002, param=1/2
    └── class=rd
        ├── expver=0001, param=1/2/3
        └── expver=0002, param=1/2
    """)

    # Branch to remove: has no class level, its first key is expver.
    b = Qube.from_tree("""
    root
    └── expver=0001/0002, param=1/2
    """)

    c = Qube.from_tree("""
    root
    └── class=rd
        ├── expver=0001, param=3
    """)

    assert a.remove_branch(b) == c


def test_3():
    """Remove a single ``expver`` value below every ``class``.

    Only ``expver=0001, param=1/2`` is subtracted; ``expver=0002`` and
    ``param=3`` are expected to survive.
    """
    a = Qube.from_tree("""
    root
    ├── class=od, expver=0001/0002, param=1/2
    └── class=rd
        ├── expver=0001, param=1/2/3
        └── expver=0002, param=1/2
    """)

    # Branch to remove: a subset of the expver values present in `a`.
    b = Qube.from_tree("""
    root
    └── expver=0001, param=1/2
    """)

    c = Qube.from_tree("""
    root
    ├── class=od, expver=0002, param=1/2
    └── class=rd
        ├── expver=0001, param=3
        └── expver=0002, param=1/2
    """)

    assert a.remove_branch(b) == c