Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions control_plane/contracts/merge_train_stack_collapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _validate_plan(self) -> "MergeTrainStackCollapsePlan":
self.entries = _normalize_entries(self.entries)
self.mutations = _normalize_mutations(self.entries, self.mutations)
self.child_dispositions = _normalize_child_dispositions(
self.entries, self.child_dispositions
self.entries, self.mutations, self.child_dispositions
)
return self

Expand Down Expand Up @@ -209,9 +209,7 @@ def execute_merge_train_stack_collapse_plan(
if plan.status not in {"planned", "collapsing"}:
raise ValueError("merge train stack collapse plan is not executable")
updated_mutations: list[MergeTrainStackCollapseMutation] = []
current_head_shas = {
entry.pull_request_number: entry.head_sha for entry in plan.entries
}
current_head_shas = {entry.pull_request_number: entry.head_sha for entry in plan.entries}
current_status: MergeTrainStackCollapseStatus = "waiting_for_root_checks"
for mutation in plan.mutations:
child_head_sha = current_head_shas[mutation.child_pull_request_number]
Expand Down Expand Up @@ -253,9 +251,7 @@ def execute_merge_train_stack_collapse_plan(
updated_mutations.extend(plan.mutations[len(updated_mutations) :])
updated_child_dispositions = tuple(
disposition.model_copy(
update={
"expected_head_sha": current_head_shas[disposition.pull_request_number]
}
update={"expected_head_sha": current_head_shas[disposition.pull_request_number]}
)
for disposition in plan.child_dispositions
)
Expand Down Expand Up @@ -442,7 +438,9 @@ def build_merge_train_stack_collapse_id(
json.dumps(
{
"entry_head_shas": [
_normalize_required_value(sha, "merge train stack collapse id requires head sha")
_normalize_required_value(
sha, "merge train stack collapse id requires head sha"
)
for sha in entry_head_shas
],
"root_pull_request_number": root_pull_request_number,
Expand All @@ -451,10 +449,7 @@ def build_merge_train_stack_collapse_id(
separators=(",", ":"),
).encode("utf-8")
).hexdigest()[:16]
return (
f"merge-train-stack-collapse-{normalized_repository}-"
f"{normalized_base_branch}-{digest}"
)
return f"merge-train-stack-collapse-{normalized_repository}-{normalized_base_branch}-{digest}"


def _normalize_entries(
Expand Down Expand Up @@ -492,30 +487,34 @@ def _normalize_mutations(
for mutation in mutations
)
if actual_pairs != expected_pairs:
raise ValueError("merge train stack collapse mutations must be leaf-to-root ordered")
root_to_leaf_pairs = tuple(reversed(expected_pairs))
if actual_pairs == root_to_leaf_pairs:
return tuple(reversed(mutations))
raise ValueError("merge train stack collapse mutations must connect adjacent stack layers")
return mutations


def _normalize_child_dispositions(
entries: tuple[MergeTrainStackCollapseEntry, ...],
mutations: tuple[MergeTrainStackCollapseMutation, ...],
child_dispositions: tuple[MergeTrainStackChildDisposition, ...],
) -> tuple[MergeTrainStackChildDisposition, ...]:
if not child_dispositions:
current_head_shas = {entry.pull_request_number: entry.head_sha for entry in entries}
for mutation in mutations:
if mutation.status == "mutated" and mutation.merge_commit_sha:
current_head_shas[mutation.parent_pull_request_number] = mutation.merge_commit_sha
child_dispositions = tuple(
MergeTrainStackChildDisposition(
pull_request_number=entry.pull_request_number,
expected_head_sha=entry.head_sha,
expected_head_sha=current_head_shas[entry.pull_request_number],
)
for entry in entries[1:]
)
expected_numbers = tuple(entry.pull_request_number for entry in entries[1:])
actual_numbers = tuple(
disposition.pull_request_number for disposition in child_dispositions
)
actual_numbers = tuple(disposition.pull_request_number for disposition in child_dispositions)
if actual_numbers != expected_numbers:
raise ValueError(
"merge train stack child dispositions must match stack child order"
)
raise ValueError("merge train stack child dispositions must match stack child order")
return child_dispositions


Expand All @@ -532,7 +531,9 @@ def _child_disposition_comment_body(


def _normalize_repository(repository: str) -> str:
normalized = _normalize_required_value(repository, "merge train stack collapse requires repository")
normalized = _normalize_required_value(
repository, "merge train stack collapse requires repository"
)
if "/" not in normalized:
raise ValueError("merge train stack collapse repository must be owner/name")
return normalized
Expand Down
53 changes: 39 additions & 14 deletions tests/test_merge_train_stack_collapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def test_collapse_id_is_deterministic(self) -> None:
)

self.assertEqual(first, second)
self.assertTrue(first.startswith("merge-train-stack-collapse-example-merge-train-repo-main-"))
self.assertTrue(
first.startswith("merge-train-stack-collapse-example-merge-train-repo-main-")
)

def test_plan_uses_root_ready_to_merge_intent_and_leaf_to_root_mutations(self) -> None:
discovery_result = discover_merge_train_stack(
Expand Down Expand Up @@ -74,9 +76,7 @@ def test_plan_requires_a_ready_stack_discovery_result(self) -> None:
snapshot=MergeTrainDryRunSnapshot(
repository="example/merge-train-repo",
base_branch="main",
pull_requests=(
_pull_request(20, head_ref="feature/root", base_ref="main"),
),
pull_requests=(_pull_request(20, head_ref="feature/root", base_ref="main"),),
),
root_pull_request_number=20,
)
Expand Down Expand Up @@ -153,9 +153,7 @@ def test_plan_record_id_canonicalizes_equivalent_utc_timestamps(self) -> None:

self.assertEqual(zulu_record.record_id, offset_record.record_id)
self.assertTrue(
zulu_record.record_id.startswith(
"merge-train-stack-collapse-plan-20260514T133100Z-"
)
zulu_record.record_id.startswith("merge-train-stack-collapse-plan-20260514T133100Z-")
)

def test_execute_plan_mutates_each_child_into_its_parent_branch(self) -> None:
Expand All @@ -180,10 +178,7 @@ def test_execute_plan_mutates_each_child_into_its_parent_branch(self) -> None:
["merge-32-31", "merge-31-30"],
)
self.assertEqual(
[
disposition.expected_head_sha
for disposition in executed_plan.child_dispositions
],
[disposition.expected_head_sha for disposition in executed_plan.child_dispositions],
["merge-32-31", "head-32"],
)
self.assertEqual(
Expand All @@ -210,6 +205,38 @@ def test_execute_plan_mutates_each_child_into_its_parent_branch(self) -> None:
],
)

def test_model_load_accepts_previous_root_to_leaf_mutation_order(self) -> None:
plan_payload = _collapse_plan().model_dump(mode="json")
plan_payload["mutations"] = tuple(reversed(plan_payload["mutations"]))

loaded_plan = MergeTrainStackCollapsePlan.model_validate(plan_payload)

self.assertEqual(
[
(mutation.child_pull_request_number, mutation.parent_pull_request_number)
for mutation in loaded_plan.mutations
],
[(32, 31), (31, 30)],
)

def test_model_load_rebuilds_missing_child_dispositions_from_mutations(self) -> None:
executed_plan = execute_merge_train_stack_collapse_plan(
plan=_collapse_plan(),
branch_client=_RecordingStackCollapseBranchClient(
merge_commit_shas=("merge-32-31", "merge-31-30")
),
updated_at="2026-05-14T13:45:00Z",
)
plan_payload = executed_plan.model_dump(mode="json")
plan_payload.pop("child_dispositions")

loaded_plan = MergeTrainStackCollapsePlan.model_validate(plan_payload)

self.assertEqual(
[disposition.expected_head_sha for disposition in loaded_plan.child_dispositions],
["merge-32-31", "head-32"],
)

def test_execute_plan_blocks_at_first_failed_mutation(self) -> None:
plan = _collapse_plan()
branch_client = _RecordingStackCollapseBranchClient(
Expand Down Expand Up @@ -383,9 +410,7 @@ def __init__(self, *, failure: Exception | None = None) -> None:
self.labels: list[tuple[int, str]] = []
self.closed: list[tuple[int, str]] = []

def comment_pull_request(
self, *, repository: str, pull_request_number: int, body: str
) -> str:
def comment_pull_request(self, *, repository: str, pull_request_number: int, body: str) -> str:
self.comments.append((pull_request_number, body))
return (
f"https://github.com/{repository}/pull/{pull_request_number}"
Expand Down