From 1c7dc5efb3c79150f5010af782de1eb28a7ce055 Mon Sep 17 00:00:00 2001 From: Ana Gelez Date: Wed, 25 Sep 2024 11:45:14 +0200 Subject: [PATCH] Edit PR title if needed --- src/github.rs | 111 +++++++++++++++++++++++++++++++++++----- src/github/api.rs | 5 ++ src/github/api/check.rs | 3 ++ src/github/api/pr.rs | 85 ++++++++++++++++++++++++++++++ 4 files changed, 190 insertions(+), 14 deletions(-) create mode 100644 src/github/api/pr.rs diff --git a/src/github.rs b/src/github.rs index 877c4ec..4369121 100644 --- a/src/github.rs +++ b/src/github.rs @@ -18,7 +18,8 @@ use codespan_reporting::{ use eyre::Context; use hook::CheckRunPayload; use jwt_simple::prelude::*; -use tracing::{debug, info, warn}; +use pr::{AnyPullRequest, MinimalPullRequest, PullRequest, PullRequestUpdate}; +use tracing::{debug, error, info, warn}; use typst::syntax::{package::PackageSpec, FileId}; use crate::{check, world::SystemWorld}; @@ -86,9 +87,24 @@ async fn index() -> &'static str { async fn force( state: State, api_client: GitHub, - axum::extract::Path((install, sha)): axum::extract::Path<(String, String)>, + axum::extract::Path((install, pr)): axum::extract::Path<(String, usize)>, ) -> Result<&'static str, &'static str> { - debug!("Force review for {sha}"); + debug!("Force review for #{pr}"); + let repository = Repository::new("typst/packages").map_err(|e| { + error!("{}", e); + "Invalid repository path" + })?; + + let pr = MinimalPullRequest { number: pr }; + let full_pr = pr + .get_full(&api_client, repository.owner(), repository.name()) + .await + .map_err(|e| { + error!("{}", e); + "Failed to fetch PR context" + })?; + let sha = full_pr.head.sha.clone(); + github_hook( state, api_client, @@ -97,11 +113,11 @@ async fn force( installation: Installation { id: str::parse(&install).map_err(|_| "Invalid installation ID")?, }, - repository: Repository::new("typst/packages").map_err(|e| { - debug!("{}", e); - "Invalid repository path" - })?, - check_suite: CheckSuite { head_sha: sha }, + repository, + check_suite: CheckSuite { + head_sha: sha, + pull_requests: vec![AnyPullRequest::Full(full_pr)], + }, }), ) .await @@ -123,28 +139,50 @@ async fn github_hook( api_client.auth_installation(&payload).await?; debug!("Successfully authenticated application"); - let (head_sha, repository, previous_check_run) = match payload { + let (head_sha, repository, pr, previous_check_run) = match payload { HookPayload::CheckSuite(CheckSuitePayload { action: CheckSuiteAction::Requested | CheckSuiteAction::Rerequested, repository, - check_suite, + mut check_suite, .. - }) => (check_suite.head_sha, repository, None), + }) => ( + check_suite.head_sha, + repository, + check_suite.pull_requests.pop(), + None, + ), HookPayload::CheckRun(CheckRunPayload { action: CheckRunAction::Rerequested, repository, - check_run, + mut check_run, .. }) => ( check_run.check_suite.head_sha.clone(), repository, + check_run.check_suite.pull_requests.pop(), Some(check_run), ), HookPayload::CheckRun(_) => return Ok(()), _ => return Err(WebError::UnexpectedEvent), }; - debug!("Starting checks for {}", head_sha); + let pr = if let Some(pr) = pr { + pr.get_full(&api_client, repository.owner(), repository.name()) + .await + .ok() + } else { + None + }; + + debug!( + "Starting checks for {}{}", + head_sha, + if let Some(ref pr) = pr { + format!(" (#{})", pr.number) + } else { + String::new() + } + ); tokio::spawn(async move { async fn inner( state: AppState, @@ -152,6 +190,7 @@ async fn github_hook( api_client: GitHub, repository: Repository, previous_check_run: Option, + pr: Option, ) -> eyre::Result<()> { let git_repo = GitRepo::open(Path::new(&state.git_dir)); git_repo.pull_main().await?; @@ -180,6 +219,39 @@ async fn github_hook( }) .collect::>(); + if let Some(pr) = pr { + let mut package_names = touched_packages + .iter() + .map(|p| format!("{}:{}", p.name, p.version)) + .collect::>(); + package_names.sort(); + let last_package = package_names.pop(); + let penultimate_package = package_names.pop(); + let expected_pr_title = if let Some((penultimate_package, last_package)) = + penultimate_package.as_ref().zip(last_package.as_ref()) + { + package_names.push(format!("{} and {}", penultimate_package, last_package)); + Some(package_names.join(", ")) + } else { + last_package + }; + if let Some(expected_pr_title) = expected_pr_title { + if pr.title != expected_pr_title { + api_client + .update_pull_request( + repository.owner(), + repository.name(), + pr.number, + PullRequestUpdate { + title: expected_pr_title, + }, + ) + .await + .context("Failed to update pull request")?; + } + } + } + for ref package in touched_packages { let check_run_name = format!( "@{}/{}:{}", @@ -322,7 +394,16 @@ async fn github_hook( Ok(()) } - if let Err(e) = inner(state, head_sha, api_client, repository, previous_check_run).await { + if let Err(e) = inner( + state, + head_sha, + api_client, + repository, + previous_check_run, + pr, + ) + .await + { warn!("Error in hook handler: {:#}", e) } }); @@ -380,6 +461,8 @@ enum WebError { impl IntoResponse for WebError { fn into_response(self) -> axum::response::Response { + debug!("Web error: {:?}", &self); + Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from(format!("{:?}", self))) diff --git a/src/github/api.rs b/src/github/api.rs index 8fc4bd8..f081feb 100644 --- a/src/github/api.rs +++ b/src/github/api.rs @@ -23,6 +23,7 @@ use super::AppState; pub mod check; pub mod hook; +pub mod pr; #[derive(Debug)] pub enum ApiError { @@ -137,6 +138,10 @@ impl GitHub { Ok(()) } + fn get(&self, url: impl AsRef) -> RequestBuilder { + self.with_headers(self.req.get(Self::url(url))) + } + fn patch(&self, url: impl AsRef) -> RequestBuilder { self.with_headers(self.req.patch(Self::url(url))) } diff --git a/src/github/api/check.rs b/src/github/api/check.rs index 322d1e1..888f708 100644 --- a/src/github/api/check.rs +++ b/src/github/api/check.rs @@ -2,6 +2,8 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; +use super::pr::AnyPullRequest; + #[derive(Debug, Deserialize, Clone, Copy)] #[serde(transparent)] pub struct CheckSuiteId(#[allow(dead_code)] u64); @@ -9,6 +11,7 @@ pub struct CheckSuiteId(#[allow(dead_code)] u64); #[derive(Clone, Deserialize)] pub struct CheckSuite { pub head_sha: String, + pub pull_requests: Vec, } #[derive(Clone, Deserialize)] diff --git a/src/github/api/pr.rs b/src/github/api/pr.rs new file mode 100644 index 0000000..4b38a64 --- /dev/null +++ b/src/github/api/pr.rs @@ -0,0 +1,85 @@ +use serde::{Deserialize, Serialize}; + +use super::{ApiError, GitHub, OwnerId, RepoId}; + +#[derive(Clone, Deserialize)] +pub struct MinimalPullRequest { + pub number: usize, +} + +impl MinimalPullRequest { + pub async fn get_full( + &self, + api: &GitHub, + owner: OwnerId, + repo: RepoId, + ) -> Result { + Ok(api + .get(format!( + "repos/{owner}/{repo}/pulls/{pull_number}", + owner = owner, + repo = repo, + pull_number = self.number + )) + .send() + .await? + .json() + .await?) + } +} + +#[derive(Clone, Deserialize)] +pub struct PullRequest { + pub number: usize, + pub head: Commit, + pub title: String, +} + +#[derive(Clone, Deserialize)] +#[serde(untagged)] +pub enum AnyPullRequest { + Full(PullRequest), + Minimal(MinimalPullRequest), +} + +impl AnyPullRequest { + pub async fn get_full( + self, + api: &GitHub, + owner: OwnerId, + repo: RepoId, + ) -> Result { + match self { + AnyPullRequest::Full(pr) => Ok(pr), + AnyPullRequest::Minimal(pr) => pr.get_full(api, owner, repo).await, + } + } +} + +#[derive(Clone, Deserialize)] +pub struct Commit { + pub sha: String, +} + +#[derive(Serialize)] +pub struct PullRequestUpdate { + pub title: String, +} + +impl GitHub { + pub async fn update_pull_request( + &self, + owner: OwnerId, + repo: RepoId, + pr: usize, + update: PullRequestUpdate, + ) -> Result { + Ok(self + .patch(format!("{}/{}/pulls/{}", owner, repo, pr)) + .json(&update) + .send() + .await? + .json() + .await?) + } +}