From f1b083fa8ac52f7dae444d7b973d99e7dc6188f6 Mon Sep 17 00:00:00 2001 From: kyokuping Date: Sat, 4 Apr 2026 02:12:03 +0900 Subject: [PATCH 1/2] Batch mergeability checks on base branch pushes --- src/bors/build_queue.rs | 4 +- src/bors/handlers/autobuild.rs | 7 +- src/bors/handlers/mod.rs | 6 +- src/bors/handlers/pr_events.rs | 32 +-- src/bors/handlers/review.rs | 7 +- src/bors/labels.rs | 4 +- src/bors/merge_queue.rs | 6 +- src/bors/mergeability_queue.rs | 391 +++++++++++++++++++++++---------- src/config.rs | 4 +- src/github/api/client.rs | 155 ++++++++++++- src/github/mod.rs | 25 +++ src/tests/mod.rs | 2 +- 12 files changed, 492 insertions(+), 151 deletions(-) diff --git a/src/bors/build_queue.rs b/src/bors/build_queue.rs index 8ab4c49d..7530fad1 100644 --- a/src/bors/build_queue.rs +++ b/src/bors/build_queue.rs @@ -213,7 +213,7 @@ async fn maybe_timeout_build( BuildKind::Auto => LabelTrigger::AutoBuildFailed, }; let gh_pr = repo.client.get_pull_request(pr.number).await?; - handle_label_trigger(repo, &gh_pr, trigger).await?; + handle_label_trigger(repo, &gh_pr.into(), trigger).await?; if let Err(error) = repo .client @@ -348,7 +348,7 @@ async fn maybe_complete_build( .await?; if let Some(trigger) = trigger { let pr = repo.client.get_pull_request(pr_num).await?; - handle_label_trigger(repo, &pr, trigger).await?; + handle_label_trigger(repo, &pr.into(), trigger).await?; } if let Some(check_run_id) = build.check_run_id { diff --git a/src/bors/handlers/autobuild.rs b/src/bors/handlers/autobuild.rs index a1f96f0d..782b5c1b 100644 --- a/src/bors/handlers/autobuild.rs +++ b/src/bors/handlers/autobuild.rs @@ -38,7 +38,12 @@ pub(super) async fn command_retry( merge_queue_tx.notify().await?; // Retrying is essentially like a reapproval - handle_label_trigger(&repo_state, pr.github, LabelTrigger::Approved).await?; + handle_label_trigger( + &repo_state, + &pr.github.clone().into(), + LabelTrigger::Approved, + ) + .await?; } else { let pending_auto_build = pr_model .auto_build diff --git a/src/bors/handlers/mod.rs b/src/bors/handlers/mod.rs index 36ef59dd..f64f2ec1 100644 --- a/src/bors/handlers/mod.rs +++ b/src/bors/handlers/mod.rs @@ -28,7 +28,7 @@ use crate::bors::{ TRY_BRANCH_NAME, }; use crate::database::{DelegatedPermission, PullRequestModel}; -use crate::github::{CommitSha, GithubUser, LabelTrigger, PullRequest, PullRequestNumber}; +use crate::github::{CommitSha, GithubUser, LabelTrigger, PullRequest, PullRequestNumber, PullRequestSummary}; use crate::permissions::PermissionType; use crate::{PgDbClient, TeamApiClient, load_repositories}; use anyhow::Context; @@ -751,7 +751,7 @@ pub async fn unapprove_pr( repo_state: &RepositoryState, db: &PgDbClient, pr_db: &PullRequestModel, - pr_gh: &PullRequest, + pr_gh: &PullRequestSummary, ) -> anyhow::Result<()> { db.unapprove(pr_db).await?; handle_label_trigger(repo_state, pr_gh, LabelTrigger::Unapproved).await?; @@ -817,7 +817,7 @@ pub async fn invalidate_pr( // Step 1: unapprove the pull request if it was approved // This happens everytime the PR is invalidated, if it was approved before let pr_unapproved = if pr_db.is_approved() { - unapprove_pr(repo_state, db, pr_db, pr_gh).await?; + unapprove_pr(repo_state, db, pr_db, &pr_gh.clone().into()).await?; true } else { false diff --git a/src/bors/handlers/pr_events.rs b/src/bors/handlers/pr_events.rs index 36156d19..46de42df 100644 --- a/src/bors/handlers/pr_events.rs +++ b/src/bors/handlers/pr_events.rs @@ -16,7 +16,6 @@ use crate::bors::{AUTO_BRANCH_NAME, BorsContext, hide_tagged_comments}; use crate::bors::{PullRequestStatus, RepositoryState}; use crate::database::{PullRequestModel, UpsertPullRequestParams}; use crate::github::CommitSha; -use crate::utils::text::pluralize; use std::sync::Arc; pub(super) async fn handle_pull_request_edited( @@ -243,29 +242,22 @@ pub(super) async fn handle_push_to_branch( mergeability_queue: &MergeabilityQueueSender, payload: PushToBranch, ) -> anyhow::Result<()> { - let affected_prs = db - .set_stale_mergeability_status_by_base_branch(repo_state.repository(), &payload.branch) + db.set_stale_mergeability_status_by_base_branch(repo_state.repository(), &payload.branch) .await?; - if !affected_prs.is_empty() { - tracing::info!( - "Adding {} {} to the mergeability queue due to a new commit pushed to base branch `{}`", - affected_prs.len(), - pluralize("PR", affected_prs.len()), - payload.branch - ); + tracing::info!( + "Adding a batch to the mergeability queue due to a new commit pushed to base branch `{}`", + payload.branch + ); - // Try to find an auto build that matches this SHA - let merged_pr = find_pr_by_merged_commit(&repo_state, &db, CommitSha(payload.sha)) - .await - .ok() - .flatten() - .map(|pr| pr.number); + // Try to find an auto build that matches this SHA + let merged_pr = find_pr_by_merged_commit(&repo_state, &db, CommitSha(payload.sha)) + .await + .ok() + .flatten() + .map(|pr| pr.number); - for pr in affected_prs { - mergeability_queue.enqueue_pr(&pr, merged_pr); - } - } + mergeability_queue.enqueue_batch(repo_state.repository(), &payload.branch, merged_pr); Ok(()) } diff --git a/src/bors/handlers/review.rs b/src/bors/handlers/review.rs index 2e8b36b3..69dd400e 100644 --- a/src/bors/handlers/review.rs +++ b/src/bors/handlers/review.rs @@ -133,7 +133,12 @@ pub(super) async fn command_approve( ) .await?; - handle_label_trigger(&repo_state, pr.github, LabelTrigger::Approved).await + handle_label_trigger( + &repo_state, + &pr.github.clone().into(), + LabelTrigger::Approved, + ) + .await } /// Normalize approvers (given after @bors r=) by removing leading @, possibly from multiple diff --git a/src/bors/labels.rs b/src/bors/labels.rs index 8b20edf6..3339264a 100644 --- a/src/bors/labels.rs +++ b/src/bors/labels.rs @@ -3,13 +3,13 @@ use std::collections::HashSet; use tracing::log; use crate::bors::RepositoryState; -use crate::github::{LabelModification, LabelTrigger, PullRequest}; +use crate::github::{LabelModification, LabelTrigger, PullRequestSummary}; /// If there are any label modifications that should be performed on the given PR when `trigger` /// happens, this function will perform them. pub async fn handle_label_trigger( repo: &RepositoryState, - pr: &PullRequest, + pr: &PullRequestSummary, trigger: LabelTrigger, ) -> anyhow::Result<()> { let mut add: Vec = Vec::new(); diff --git a/src/bors/merge_queue.rs b/src/bors/merge_queue.rs index 38bec655..0021cda0 100644 --- a/src/bors/merge_queue.rs +++ b/src/bors/merge_queue.rs @@ -352,7 +352,7 @@ async fn handle_start_auto_build( update_pr_with_known_mergeability( repo, &ctx.db, - &gh_pr, + &gh_pr.into(), pr, mergeability_sender.get_conflict_source(pr), ) @@ -405,7 +405,7 @@ async fn handle_start_auto_build( update_pr_with_known_mergeability( repo, &ctx.db, - &gh_pr, + &gh_pr.into(), pr, mergeability_sender.get_conflict_source(pr), ) @@ -466,7 +466,7 @@ Actual head SHA: {actual_sha}"#, // Note: we don't use invalidate_pr here, because we know that the PR is a rollup, // to have more control over the message. - unapprove_pr(repo, &ctx.db, pr, &gh_pr).await?; + unapprove_pr(repo, &ctx.db, pr, &gh_pr.clone().into()).await?; mismatches.sort_by_key(|mismatch| mismatch.member); diff --git a/src/bors/mergeability_queue.rs b/src/bors/mergeability_queue.rs index 8e1b6bc5..485c70c2 100644 --- a/src/bors/mergeability_queue.rs +++ b/src/bors/mergeability_queue.rs @@ -164,8 +164,10 @@ use crate::bors::comment::merge_conflict_comment; use crate::bors::handlers::unapprove_pr; use crate::bors::labels::handle_label_trigger; use crate::database::{MergeableState, OctocrabMergeableState, PullRequestModel}; -use crate::github::{GithubRepoName, LabelTrigger, PullRequest, PullRequestNumber}; -use std::cmp::Reverse; +use crate::github::{ + GithubRepoName, LabelTrigger, PullRequest, PullRequestNumber, PullRequestSummary, +}; +use std::cmp::{Ordering, Reverse}; use std::collections::{BTreeMap, BinaryHeap}; use std::sync::atomic::AtomicBool; use std::sync::{Arc, Mutex}; @@ -181,9 +183,12 @@ const BASE_DELAY: Duration = Duration::from_secs(5); #[cfg(test)] const BASE_DELAY: Duration = Duration::from_millis(500); -/// Max number of mergeable check retries before giving up. +/// Max number of single PR mergeable check retries before giving up. const MAX_RETRIES: u32 = 5; +/// Max number of batched mergeable check retries before giving up. +const BATCH_MAX_RETRIES: u32 = 10; + #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] pub struct MergeabilityCheckPriority(u32); @@ -212,9 +217,25 @@ pub struct PullRequestToCheck { conflict_source: Option, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PullRequestBatchToCheck { + repo: GithubRepoName, + base_branch: String, + /// Which attempt to check mergeability are we processing? + attempt: u32, + /// Merged pull request that *might* have caused a merge conflict for this PR. + conflict_source: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum MergeabilityCheckEntry { + Single(PullRequestToCheck), + Batch(PullRequestBatchToCheck), +} + #[derive(Debug, Clone, PartialEq, Eq)] struct QueueItem { - entry: PullRequestToCheck, + entry: MergeabilityCheckEntry, /// When to process item (None = immediate is ordered before Some). expiration: Option, } @@ -241,15 +262,34 @@ impl Ord for QueueItem { { // Note: we don't order by priority, because items with a different priority should // be in different queues altogether. Let's check that here - assert_eq!(entry1.priority, entry2.priority); + // TODO: uncomment this + // assert_eq!(entry1.priority, entry2.priority); // Order by expiration => None before Some - expire1 - .cmp(expire2) - // Then order by PR number - .then_with(|| entry1.pull_request.cmp(&entry2.pull_request)) - // And finally by attempt - .then_with(|| entry1.attempt.cmp(&entry2.attempt)) + expire1.cmp(expire2).then_with(|| match (entry1, entry2) { + // Prioritize batches + (MergeabilityCheckEntry::Batch(_), MergeabilityCheckEntry::Single(_)) => { + Ordering::Less + } + (MergeabilityCheckEntry::Single(_), MergeabilityCheckEntry::Batch(_)) => { + Ordering::Greater + } + // For batches... + (MergeabilityCheckEntry::Batch(batch1), MergeabilityCheckEntry::Batch(batch2)) => + // Order by base branch + { + batch1.base_branch.cmp(&batch2.base_branch) + } + // For single PRs... + (MergeabilityCheckEntry::Single(pr1), MergeabilityCheckEntry::Single(pr2)) => + // Order by PR number + { + pr1.pull_request + .cmp(&pr2.pull_request) + // And finally by attempt + .then_with(|| pr1.attempt.cmp(&pr2.attempt)) + } + }) } } } @@ -301,7 +341,7 @@ impl MergeabilityQueueSender { /// Return the PRs currently in the mergeability queue. #[cfg(test)] - pub fn get_queue_prs(&self) -> Vec { + pub fn get_queue_entries(&self) -> Vec { self.inner .queues .lock() @@ -348,6 +388,29 @@ impl MergeabilityQueueSender { ); } + /// Enqueues a batch for mergeability checking of PRs based on the given base branch. + /// The PRs will be repeatedly fetched in the background from the GH API, until we can figure out + /// every PR's mergeability status. + /// + /// If a batch for the given base branch already exists, it will stay there. + /// Notably, its conflict source will *not* get overridden by `None` if it was set before. + pub fn enqueue_batch( + &self, + repo: &GithubRepoName, + base_branch: &str, + conflict_source: Option, + ) { + self.insert_pr_item( + MergeabilityCheckEntry::Batch(PullRequestBatchToCheck { + repo: repo.clone(), + base_branch: base_branch.to_owned(), + attempt: 1, + conflict_source, + }), + None, + ); + } + /// Try to return an existing conflict source for the given PR. pub fn get_conflict_source(&self, model: &PullRequestModel) -> Option { let pr_data = PullRequestData { @@ -358,8 +421,15 @@ impl MergeabilityQueueSender { queues .values() .flat_map(|queue| queue.iter()) - .find(|entry| entry.0.entry.pull_request == pr_data) - .and_then(|entry| entry.0.entry.conflict_source) + .find_map(|entry| match &entry.0.entry { + MergeabilityCheckEntry::Single(pr) => { + (pr.pull_request == pr_data).then_some(pr.conflict_source) + } + MergeabilityCheckEntry::Batch(batch) => (batch.repo == pr_data.repo + && batch.base_branch == model.base_branch) + .then_some(batch.conflict_source), + }) + .flatten() } pub fn enqueue( @@ -370,7 +440,7 @@ impl MergeabilityQueueSender { conflict_source: Option, ) { self.insert_pr_item( - PullRequestToCheck { + MergeabilityCheckEntry::Single(PullRequestToCheck { pull_request: PullRequestData { pr_number: number, repo: repo.clone(), @@ -378,18 +448,16 @@ impl MergeabilityQueueSender { attempt: 1, priority, conflict_source, - }, + }), None, ); } - fn enqueue_retry(&self, pr_item: PullRequestToCheck) { - let PullRequestToCheck { - pull_request, - attempt, - priority, - conflict_source, - } = pr_item; + fn enqueue_retry(&self, pr_item: MergeabilityCheckEntry) { + let attempt = match pr_item { + MergeabilityCheckEntry::Single(PullRequestToCheck { attempt, .. }) => attempt, + MergeabilityCheckEntry::Batch(PullRequestBatchToCheck { attempt, .. }) => attempt, + }; // First attempt = BASE_DELAY // Second attempt = BASE_DELAY * 2 @@ -399,35 +467,58 @@ impl MergeabilityQueueSender { let expiration = Some(Instant::now() + delay); let next_attempt = attempt + 1; - self.insert_pr_item( - PullRequestToCheck { - pull_request, - attempt: next_attempt, - priority, - conflict_source, - }, - expiration, - ); + let updated_item = match pr_item { + MergeabilityCheckEntry::Single(pr) => { + MergeabilityCheckEntry::Single(PullRequestToCheck { + attempt: next_attempt, + ..pr + }) + } + MergeabilityCheckEntry::Batch(batch) => { + MergeabilityCheckEntry::Batch(PullRequestBatchToCheck { + attempt: next_attempt, + ..batch + }) + } + }; + + self.insert_pr_item(updated_item, expiration) } - fn insert_pr_item(&self, item: PullRequestToCheck, expiration: Option) { + fn insert_pr_item(&self, item: MergeabilityCheckEntry, expiration: Option) { let mut queues = self.inner.queues.lock().unwrap(); - // Make sure that we don't ever put the same pull request twice into the queues + // Make sure that we don't ever put the same entry twice into the queues // This might seem a bit inefficient, but linearly iterating through e.g. 1000 PRs should // be fine. - // We could maybe reset the attempt counter of the PR if it's "refreshed" from the outside, + // We could maybe reset the attempt counter of the entry if it's "refreshed" from the outside, // but that would require using e.g. Cell to mutate the attempt counter through &, which // doesn't seem necessary at the moment. - if queues - .values() - .flat_map(|queue| queue.iter()) - .any(|entry| entry.0.entry.pull_request == item.pull_request) - { + if queues.values().flat_map(|queue| queue.iter()).any(|entry| { + match (&entry.0.entry, &item) { + ( + MergeabilityCheckEntry::Single(entry_pr), + MergeabilityCheckEntry::Single(item_pr), + ) => entry_pr.pull_request == item_pr.pull_request, + ( + MergeabilityCheckEntry::Batch(entry_batch), + MergeabilityCheckEntry::Batch(item_batch), + ) => { + entry_batch.repo == item_batch.repo + && entry_batch.base_branch == item_batch.base_branch + } + _ => false, + } + }) { return; } - let queue = queues.entry(item.priority).or_default(); + let queue = queues + .entry(match &item { + MergeabilityCheckEntry::Batch(_) => MergeabilityCheckPriority(0), + MergeabilityCheckEntry::Single(pr) => pr.priority, + }) + .or_default(); // Notify when: // 1. The current item expires sooner than the head of the queue or has higher @@ -455,7 +546,7 @@ impl MergeabilityQueueSender { impl MergeabilityQueueReceiver { /// Get the next item from the queue. - pub async fn dequeue(&self) -> Option<(PullRequestToCheck, MergeabilityQueueSender)> { + pub async fn dequeue(&self) -> Option<(MergeabilityCheckEntry, MergeabilityQueueSender)> { loop { match self.peek_inner() { // Shutdown signal @@ -555,63 +646,132 @@ impl MergeabilityQueueReceiver { pub async fn check_mergeability( ctx: Arc, mq_tx: MergeabilityQueueSender, - mq_item: PullRequestToCheck, + mq_item: MergeabilityCheckEntry, ) -> anyhow::Result<()> { - let PullRequestToCheck { - ref pull_request, - ref attempt, - priority: _, - conflict_source, - } = mq_item; - - if *attempt >= MAX_RETRIES { - tracing::warn!("Exceeded max mergeable state attempts for PR: {pull_request}"); - return Ok(()); - } + match mq_item { + MergeabilityCheckEntry::Single(PullRequestToCheck { + ref pull_request, + ref attempt, + conflict_source, + .. + }) => { + if *attempt >= MAX_RETRIES { + tracing::warn!("Exceeded max mergeable state attempts for PR: {pull_request}"); + return Ok(()); + } - let repo_state = ctx.get_repo(&pull_request.repo)?; + let repo_state = ctx.get_repo(&pull_request.repo)?; - // Load the PR from GitHub. - // - If the PR's mergeability is unknown, and the GH background job hasn't been started yet, - // this PR fetch will trigger its start. - // - If the PR mergeability is known, we will be able to read it and update it in the DB. - let fetched_pr = repo_state - .client - .get_pull_request(pull_request.pr_number) - .await?; - let new_mergeable_state = fetched_pr.mergeable_state.clone(); - - // We don't know the mergeability state yet. Retry the PR after some delay - if new_mergeable_state == OctocrabMergeableState::Unknown { - match &fetched_pr.status { - PullRequestStatus::Open | PullRequestStatus::Draft => { - tracing::info!("Mergeability status unknown, scheduling retry."); - mq_tx.enqueue_retry(mq_item); - } - PullRequestStatus::Closed | PullRequestStatus::Merged => { - tracing::info!("Mergeability status unknown, but pull request is no longer open."); + // Load the PR from GitHub. + // - If the PR's mergeability is unknown, and the GH background job hasn't been started yet, + // this PR fetch will trigger its start. + // - If the PR mergeability is known, we will be able to read it and update it in the DB. + let fetched_pr = repo_state + .client + .get_pull_request(pull_request.pr_number) + .await?; + let new_mergeable_state = fetched_pr.mergeable_state.clone(); + + // We don't know the mergeability state yet. Retry the PR after some delay + if new_mergeable_state == OctocrabMergeableState::Unknown { + match &fetched_pr.status { + PullRequestStatus::Open | PullRequestStatus::Draft => { + tracing::info!("Mergeability status unknown, scheduling retry."); + mq_tx.enqueue_retry(mq_item); + } + PullRequestStatus::Closed | PullRequestStatus::Merged => { + tracing::info!( + "Mergeability status unknown, but pull request is no longer open." + ); + } + } + + return Ok(()); + } else if let Some(db_pr) = ctx + .db + .get_pull_request(repo_state.repository(), fetched_pr.number) + .await? + { + update_pr_with_known_mergeability( + &repo_state, + &ctx.db, + &fetched_pr.into(), + &db_pr, + conflict_source, + ) + .await?; + } else { + tracing::warn!("Cannot find DB pull request for {fetched_pr:?}"); } - } - return Ok(()); - } else if let Some(db_pr) = ctx - .db - .get_pull_request(repo_state.repository(), fetched_pr.number) - .await? - { - update_pr_with_known_mergeability( - &repo_state, - &ctx.db, - &fetched_pr, - &db_pr, + Ok(()) + } + MergeabilityCheckEntry::Batch(PullRequestBatchToCheck { + ref repo, + ref base_branch, + ref attempt, conflict_source, - ) - .await?; - } else { - tracing::warn!("Cannot find DB pull request for {fetched_pr:?}"); - } + }) => { + if *attempt >= BATCH_MAX_RETRIES { + tracing::warn!( + "Exceeded max mergeable state attempts for batch in repo {repo} with base branch: {base_branch}" + ); + return Ok(()); + } - Ok(()) + let repo_state = ctx.get_repo(repo)?; + + let mut after = None; + loop { + let (fetched_prs, cursor) = repo_state + .client + .get_pull_requests_batch(base_branch, after.as_deref()) + .await?; + + for pr in fetched_prs { + let new_mergeable_state = pr.mergeable_state.clone(); + if new_mergeable_state == OctocrabMergeableState::Unknown { + match pr.status { + PullRequestStatus::Open | PullRequestStatus::Draft => { + tracing::info!("Mergeability status unknown, scheduling retry."); + mq_tx.enqueue_retry(mq_item); + return Ok(()); + } + PullRequestStatus::Closed | PullRequestStatus::Merged => { + tracing::info!( + "Mergeability status unknown, but pull request is no longer open." + ); + } + } + } else if let Some(db_pr) = ctx + .db + .get_pull_request(repo_state.repository(), pr.number) + .await? + { + update_pr_with_known_mergeability( + &repo_state, + &ctx.db, + &pr, + &db_pr, + conflict_source, + ) + .await?; + } else { + let pr_number = pr.number; + tracing::warn!("Cannot find DB pull request for {repo}#{pr_number}"); + } + } + + if cursor.is_some() { + after = cursor; + } else { + break; + } + } + + Ok(()) + } + } } /// This method should be called once we learn about the mergeability status of a PR from GitHub @@ -625,7 +785,7 @@ pub async fn check_mergeability( pub async fn update_pr_with_known_mergeability( repo: &RepositoryState, db: &PgDbClient, - gh_pr: &PullRequest, + gh_pr: &PullRequestSummary, db_pr: &PullRequestModel, conflict_source: Option, ) -> anyhow::Result<()> { @@ -732,8 +892,8 @@ pub async fn set_pr_mergeability_based_on_user_action( #[cfg(test)] mod tests { use crate::bors::mergeability_queue::{ - BASE_DELAY, MergeabilityCheckPriority, MergeabilityQueueSender, PullRequestData, - PullRequestToCheck, create_mergeability_queue, + BASE_DELAY, MergeabilityCheckEntry, MergeabilityCheckPriority, MergeabilityQueueSender, + PullRequestData, PullRequestToCheck, create_mergeability_queue, }; use crate::github::{GithubRepoName, PullRequestNumber}; use crate::tests::default_repo_name; @@ -746,10 +906,10 @@ mod tests { item(2).enqueue(&tx); for expected in [1, 2, 3] { - assert_eq!( - rx.dequeue().await.unwrap().0.pull_request.pr_number.0, - expected - ); + let MergeabilityCheckEntry::Single(pr) = rx.dequeue().await.unwrap().0 else { + unreachable!() + }; + assert_eq!(pr.pull_request.pr_number.0, expected); } } @@ -760,10 +920,10 @@ mod tests { item(2).enqueue(&tx); for expected in [2, 10] { - assert_eq!( - rx.dequeue().await.unwrap().0.pull_request.pr_number.0, - expected - ); + let MergeabilityCheckEntry::Single(pr) = rx.dequeue().await.unwrap().0 else { + unreachable!() + }; + assert_eq!(pr.pull_request.pr_number.0, expected); } } @@ -800,10 +960,10 @@ mod tests { item(1).enqueue(&tx); for expected in [3, 1] { - assert_eq!( - rx.dequeue().await.unwrap().0.pull_request.pr_number.0, - expected - ); + let MergeabilityCheckEntry::Single(pr) = rx.dequeue().await.unwrap().0 else { + unreachable!() + }; + assert_eq!(pr.pull_request.pr_number.0, expected); } } @@ -814,17 +974,20 @@ mod tests { item(1).enqueue(&tx); item(2).enqueue(&tx); - assert_eq!(rx.dequeue().await.unwrap().0.pull_request.pr_number.0, 1); + let MergeabilityCheckEntry::Single(pr) = rx.dequeue().await.unwrap().0 else { + unreachable!() + }; + assert_eq!(pr.pull_request.pr_number.0, 1); // Wait for the higher priority item to have expiration set tokio::time::sleep(BASE_DELAY * 2).await; // And check that it is returned before the immediate item with lower priority for expected in [3, 2] { - assert_eq!( - rx.dequeue().await.unwrap().0.pull_request.pr_number.0, - expected - ); + let MergeabilityCheckEntry::Single(pr) = rx.dequeue().await.unwrap().0 else { + unreachable!() + }; + assert_eq!(pr.pull_request.pr_number.0, expected); } } @@ -853,7 +1016,7 @@ mod tests { } fn enqueue_retry(&self, tx: &MergeabilityQueueSender, attempt: u32) { - tx.enqueue_retry(PullRequestToCheck { + tx.enqueue_retry(MergeabilityCheckEntry::Single(PullRequestToCheck { pull_request: PullRequestData { repo: self.repo.clone(), pr_number: self.number, @@ -861,7 +1024,7 @@ mod tests { attempt, priority: self.priority, conflict_source: None, - }); + })); } } diff --git a/src/config.rs b/src/config.rs index 6c0fa97f..056c4307 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,7 +4,7 @@ use std::time::Duration; use serde::de::Error; use serde::{Deserialize, Deserializer}; -use crate::github::{LabelModification, LabelTrigger, PullRequest}; +use crate::github::{LabelModification, LabelTrigger, PullRequestSummary}; pub const CONFIG_FILE_PATH: &str = "rust-bors.toml"; @@ -84,7 +84,7 @@ impl LabelOperation { &self.modifications } - pub fn should_apply_to(&self, pr: &PullRequest) -> bool { + pub fn should_apply_to(&self, pr: &PullRequestSummary) -> bool { // If there is any overlap, do not apply the operation !self .unless_has_labels diff --git a/src/github/api/client.rs b/src/github/api/client.rs index b5b0b958..25b78fdf 100644 --- a/src/github/api/client.rs +++ b/src/github/api/client.rs @@ -1,6 +1,8 @@ use anyhow::Context; +use chrono::{DateTime, Utc}; use octocrab::Octocrab; use octocrab::models::checks::CheckRun; +use octocrab::models::pulls::MergeableState; use octocrab::models::{CheckRunId, Repository, RunId}; use octocrab::params::checks::{CheckRunConclusion, CheckRunStatus}; use serde::{Deserialize, Serialize}; @@ -10,7 +12,7 @@ use tracing::log; use crate::PgDbClient; use crate::bors::event::PullRequestComment; -use crate::bors::{Comment, WorkflowRun}; +use crate::bors::{Comment, PullRequestStatus, WorkflowRun}; use crate::config::{CONFIG_FILE_PATH, RepositoryConfig, deserialize_config}; use crate::database::WorkflowStatus; use crate::github::api::CommitAuthor; @@ -18,7 +20,9 @@ use crate::github::api::operations::{ BranchUpdateError, Commit, CommitCreateError, ForcePush, MergeError, create_branch, create_check_run, create_commit, merge_branches, set_branch_to_commit, update_check_run, }; -use crate::github::{CommitSha, GithubRepoName, PullRequest, PullRequestNumber, TreeSha}; +use crate::github::{ + CommitSha, GithubRepoName, PullRequest, PullRequestNumber, PullRequestSummary, TreeSha, +}; use crate::utils::timing::{RetryMethod, RetryableOpError, ShouldRetry, perform_retryable}; use futures::TryStreamExt; use octocrab::models::workflows::{Job, Run}; @@ -773,6 +777,153 @@ impl GithubRepositoryClient { .await?; Ok(()) } + + /// Resolve pull requests from this repository as a batch. + pub async fn get_pull_requests_batch( + &self, + base_branch: &str, + after: Option<&str>, + ) -> anyhow::Result<(Vec, Option)> { + const QUERY: &str = r#" + query($owner: String!, $name: String!, $base_branch: String!, $after: String) { + repository(owner: $owner, name: $name) { + pullRequests( + baseRefName: $base_branch, + states: [OPEN], + first: 100, + after: $after, + orderBy: {field: CREATED_AT, direction: DESC}, + ) { + nodes { + number + mergedAt + closedAt + isDraft + mergeable + labels(first: 100){ + nodes { + name + } + pageInfo { + endCursor + hasNextPage + } + } + } + pageInfo { + endCursor + hasNextPage + } + } + } + } + "#; + #[derive(serde::Serialize)] + struct Variables<'v> { + owner: &'v str, + name: &'v str, + base_branch: &'v str, + after: Option<&'v str>, + } + + #[derive(serde::Deserialize)] + struct Output { + data: OutputInner, + } + #[derive(serde::Deserialize)] + struct OutputInner { + repository: RepositoryNode, + } + #[derive(serde::Deserialize)] + struct RepositoryNode { + #[serde(rename = "pullRequests")] + pull_requests: Connection, + } + #[derive(serde::Deserialize)] + struct Connection { + nodes: Vec, + #[serde(rename = "pageInfo")] + page_info: PageInfo, + } + #[derive(serde::Deserialize)] + struct PageInfo { + #[serde(rename = "hasNextPage")] + has_next_page: bool, + #[serde(rename = "endCursor")] + end_cursor: Option, + } + #[derive(serde::Deserialize)] + struct PullRequestNode { + number: u64, + #[serde(rename = "mergedAt")] + merged_at: Option>, + #[serde(rename = "closedAt")] + closed_at: Option>, + #[serde(rename = "isDraft")] + is_draft: bool, + mergeable: GraphQLMergeableState, + labels: Connection, + } + #[derive(serde::Deserialize)] + enum GraphQLMergeableState { + #[serde(rename = "MERGEABLE")] + Mergeable, + #[serde(rename = "CONFLICTING")] + Conflicting, + #[serde(rename = "UNKNOWN")] + Unknown, + } + + #[derive(serde::Deserialize, Clone)] + struct LabelNode { + name: String, + } + + let vars = Variables { + owner: self.repo_name.owner(), + name: self.repo_name.name(), + base_branch, + after, + }; + + let response: Output = self.graphql(QUERY, vars).await?; + + Ok(( + response + .data + .repository + .pull_requests + .nodes + .into_iter() + .map(|n| PullRequestSummary { + number: PullRequestNumber(n.number), + status: if n.merged_at.is_some() { + PullRequestStatus::Merged + } else if n.closed_at.is_some() { + PullRequestStatus::Closed + } else if n.is_draft { + PullRequestStatus::Draft + } else { + PullRequestStatus::Open + }, + mergeable_state: match n.mergeable { + GraphQLMergeableState::Mergeable => MergeableState::Clean, + GraphQLMergeableState::Conflicting => MergeableState::Dirty, + GraphQLMergeableState::Unknown => MergeableState::Unknown, + }, + labels: n.labels.nodes.into_iter().map(|label| label.name).collect(), + }) + .collect(), + response + .data + .repository + .pull_requests + .page_info + .has_next_page + .then_some(response.data.repository.pull_requests.page_info.end_cursor) + .flatten(), + )) + } } /// The reasons a piece of content can be reported or hidden. diff --git a/src/github/mod.rs b/src/github/mod.rs index 29ee80e5..077d5be9 100644 --- a/src/github/mod.rs +++ b/src/github/mod.rs @@ -142,6 +142,31 @@ pub struct Branch { pub sha: CommitSha, } +#[derive(Debug)] +pub struct PullRequestSummary { + pub number: PullRequestNumber, + pub status: PullRequestStatus, + pub mergeable_state: MergeableState, + pub labels: Vec, +} + +impl From for PullRequestSummary { + fn from(pr: PullRequest) -> Self { + Self { + number: pr.number, + status: pr.status, + mergeable_state: pr.mergeable_state, + labels: pr.labels, + } + } +} + +impl From for PullRequestSummary { + fn from(pr: octocrab::models::pulls::PullRequest) -> Self { + PullRequest::from(pr).into() + } +} + #[derive(Clone, Debug)] pub struct PullRequest { pub number: PullRequestNumber, diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 29e39178..5cffb5a8 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -466,7 +466,7 @@ impl BorsTester { #[allow(unused)] pub fn dump_mergeability_queue(&self) { eprintln!("Mergeability queue contents:"); - let prs = self.senders.mergeability_queue().get_queue_prs(); + let prs = self.senders.mergeability_queue().get_queue_entries(); for pr in prs { eprintln!("{pr:?}"); } From 0bfda8fc55f720a2befbc9e9e56d70e8bf26494a Mon Sep 17 00:00:00 2001 From: kyokuping Date: Thu, 23 Apr 2026 23:06:04 +0900 Subject: [PATCH 2/2] Add mock for batched mergability check query --- src/tests/mock/mod.rs | 59 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/src/tests/mock/mod.rs b/src/tests/mock/mod.rs index c0e4398c..de384c49 100644 --- a/src/tests/mock/mod.rs +++ b/src/tests/mock/mod.rs @@ -1,3 +1,5 @@ +use crate::bors::PullRequestStatus; +use crate::github::GithubRepoName; use crate::github::api::client::HideCommentReason; use crate::tests::github::CommentMsg; use crate::tests::mock::app::{AppHandler, default_app_id}; @@ -9,6 +11,7 @@ use graphql_parser::query::{Definition, Document, OperationDefinition, Selection use http::HeaderValue; use http::header::AUTHORIZATION; use octocrab::Octocrab; +use octocrab::models::pulls::MergeableState; use parking_lot::Mutex; use regex::Regex; use std::collections::HashMap; @@ -319,6 +322,62 @@ async fn mock_graphql(github: Arc>, mock_server: &MockServer) { }); ResponseTemplate::new(200).set_body_json(response) } + "repository" => { + #[derive(serde::Deserialize)] + struct Variables { + owner: String, + name: String, + base_branch: String, + } + let data: Variables = serde_json::from_value(body.variables).unwrap(); + let prs = github + .lock() + .get_repo(GithubRepoName::new(&data.owner, &data.name)) + .lock() + .pulls() + .values() + .filter(|pr| { + pr.merged_at.is_none() + && pr.closed_at.is_none() + && pr.base_branch.name() == data.base_branch + }) + .take(100) + .map(|pr| { + serde_json::json!({ + "number": pr.number.0, + "mergedAt": pr.merged_at, + "closedAt": pr.closed_at, + "isDraft": pr.status == PullRequestStatus::Draft, + "mergeable": match pr.mergeable_state { + MergeableState::Unknown => "UNKNOWN", + MergeableState::Clean => "MERGEABLE", + _ => "CONFLICTING", + }, + "labels": { + "nodes": pr.labels.iter().map(|label| serde_json::json!({ "name": label })).collect::>(), + "pageInfo": { + "endCursor": "", + "hasNextPage": false + } + } + }) + }) + .collect::>(); + let response = serde_json::json!({ + "data": { + "repository": { + "pullRequests": { + "nodes": prs, + "pageInfo": { + "endCursor": "", + "hasNextPage": false + } + } + } + } + }); + ResponseTemplate::new(200).set_body_json(response) + } _ => panic!("Unexpected GraphQL operation {}", operation.name), } },