From 14751f8db1d6f358272b2e9ae1b201f5b0b2a092 Mon Sep 17 00:00:00 2001 From: qpismont Date: Tue, 2 Jun 2026 18:42:59 +0000 Subject: [PATCH] add check for action and bot_name --- src/api.rs | 22 +++------- src/env.rs | 8 ++-- src/errors.rs | 22 ++++++---- src/gitea.rs | 113 ++++++++++++++++++++++++++++++++++++++++++++------ 4 files changed, 126 insertions(+), 39 deletions(-) diff --git a/src/api.rs b/src/api.rs index 6e0e66f..a62d234 100644 --- a/src/api.rs +++ b/src/api.rs @@ -15,10 +15,12 @@ use crate::state::AppState; pub async fn start(app_state: AppState) -> anyhow::Result<()> { let http_port = app_state.config.http_port; + let app = Router::new() .route("/", get(root)) .route("/webhook", post(webhook)) .with_state(app_state); + let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", http_port)).await?; axum::serve(listener, app) .await @@ -56,9 +58,7 @@ where &body_bytes, )?; - let webhook = parse_webhook(&type_header, &body_bytes)?; - reject_bot_user(&app_state, &webhook)?; - + let webhook = parse_webhook(&type_header, &app_state.config.bot_name, &body_bytes)?; Ok(WebhookExtract(webhook)) } } @@ -69,6 +69,7 @@ fn extract_header(key: &str, headers: &axum::http::HeaderMap) -> Result Result { .map_err(AppError::from) } -fn parse_webhook(header: &str, body_bytes: &[u8]) -> Result { +fn parse_webhook(header: &str, bot_name: &str, body_bytes: &[u8]) -> Result { let Json(value) = Json::::from_bytes(body_bytes).map_err(|_| AppError::MalformedJsonErr)?; - WebhookType::from_event(header, value) -} - -fn reject_bot_user(state: &AppState, webhook: &WebhookType) -> Result<(), AppError> { - let user_id = match webhook { - WebhookType::Review(review_payload) => review_payload.comment.user.id, - }; - - match user_id != state.config.bot_user_id { - true => Ok(()), - false => Err(AppError::UnauthorizedUserIdErr), - } + WebhookType::from_event(header, bot_name, value) } fn verify_signature(secret_key: &[u8], sig_header: &str, body: &[u8]) -> Result<(), AppError> { diff --git a/src/env.rs b/src/env.rs index 6b8ef2a..e4a8d1a 100644 --- a/src/env.rs +++ b/src/env.rs @@ -6,21 +6,21 @@ pub struct EnvConfig { pub http_port: u16, pub webhook_secret: String, pub open_router_api_key: String, - pub bot_user_id: u64, + pub bot_name: String, } pub fn load_config() -> anyhow::Result { dotenv().ok(); let http_port = try_get_env("HTTP_PORT")?.parse()?; - let bot_user_id = try_get_env("BOT_USER_ID")?.parse()?; + let bot_name = try_get_env("BOT_NAME")?; let webhook_secret = try_get_env("WEBHOOK_SIG_HEADER_SECRET")?; let open_router_api_key = try_get_env("OPEN_ROUTER_API_KEY")?; Ok(EnvConfig { http_port, webhook_secret, - bot_user_id, + bot_name, open_router_api_key, }) } @@ -29,7 +29,7 @@ fn try_get_env(key: &str) -> anyhow::Result { let env = std::env::var(key)?; if env.trim().is_empty() { - return Err(anyhow!(format!("env var {} is empty", env))); + return Err(anyhow!(format!("env var {} is empty", key))); } Ok(env) diff --git a/src/errors.rs b/src/errors.rs index 459eb42..95b70ae 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -4,7 +4,7 @@ use reqwest::StatusCode; #[derive(thiserror::Error, Debug)] pub enum AppError { #[error("Unauthorized user id")] - UnauthorizedUserIdErr, + UnauthorizedUserErr, #[error("Unknow gitea event")] UnknownEventErr, @@ -12,15 +12,18 @@ pub enum AppError { #[error("Malformed Json")] MalformedJsonErr, - #[error(transparent)] - BadJsonStructErr(#[from] serde_json::Error), - #[error("WebHook header not found")] WebHookMissingHeaderErr(String), #[error("WebHook sig header is invalid")] WebHookSigHeaderInvalidErr, + #[error("WebHook have bad action")] + InvalidActionErr, + + #[error(transparent)] + BadJsonStructErr(#[from] serde_json::Error), + #[error(transparent)] Other(#[from] anyhow::Error), } @@ -28,12 +31,17 @@ pub enum AppError { impl IntoResponse for AppError { fn into_response(self) -> axum::response::Response { match self { + AppError::InvalidActionErr => ( + StatusCode::UNPROCESSABLE_ENTITY, + "WebHook have bad action".to_string(), + ), AppError::UnknownEventErr => { (StatusCode::BAD_REQUEST, "Unknow gitea event".to_string()) } - AppError::UnauthorizedUserIdErr => { - (StatusCode::BAD_REQUEST, "Unauthorized user id".to_string()) - } + AppError::UnauthorizedUserErr => ( + StatusCode::UNAUTHORIZED, + "Unauthorized user name".to_string(), + ), AppError::MalformedJsonErr => (StatusCode::BAD_REQUEST, "Malformed Json".to_string()), AppError::BadJsonStructErr(err) => ( StatusCode::BAD_REQUEST, diff --git a/src/gitea.rs b/src/gitea.rs index b82af15..b5ab824 100644 --- a/src/gitea.rs +++ b/src/gitea.rs @@ -18,6 +18,7 @@ pub struct ReviewPayload { #[derive(Deserialize, Debug)] pub struct PullRequest { pub id: u64, + pub diff_url: String, } #[derive(Deserialize, Debug)] @@ -33,11 +34,29 @@ pub struct User { } impl WebhookType { - pub fn from_event(event: &str, json: Value) -> Result { - match event { + pub fn from_event(event: &str, bot_name: &str, json: Value) -> Result { + let wb = match event { "pull_request_comment" => Ok(WebhookType::Review(serde_json::from_value(json)?)), _ => Err(AppError::UnknownEventErr), + }?; + + let pr_body = match &wb { + WebhookType::Review(review_payload) => &review_payload.comment.body, + }; + + if !pr_body.starts_with(&format!("@{}", bot_name)) { + return Err(AppError::UnauthorizedUserErr); } + + let action = match &wb { + WebhookType::Review(review_payload) => &review_payload.action, + }; + + if action != "created" { + return Err(AppError::InvalidActionErr); + } + + Ok(wb) } } @@ -51,18 +70,19 @@ mod tests { let json = json!({ "action": "created", "pull_request": { - "id": 42 + "id": 42, + "diff_url": "https://mydiff.fr" }, "comment": { "id": 7, - "body": "LGTM", + "body": "@test_bot LGTM", "user": { "id": 100 } } }); - let result = WebhookType::from_event("pull_request_comment", json); + let result = WebhookType::from_event("pull_request_comment", "test_bot", json); assert!(result.is_ok()); match result.unwrap() { @@ -70,7 +90,7 @@ mod tests { assert_eq!(payload.action, "created"); assert_eq!(payload.pull_request.id, 42); assert_eq!(payload.comment.id, 7); - assert_eq!(payload.comment.body, "LGTM"); + assert_eq!(payload.comment.body, "@test_bot LGTM"); assert_eq!(payload.comment.user.id, 100); } } @@ -79,7 +99,7 @@ mod tests { #[test] fn test_from_event_unknown_event() { let json = json!({}); - let result = WebhookType::from_event("push", json); + let result = WebhookType::from_event("push", "test_bot", json); assert!(result.is_err()); match result.unwrap_err() { @@ -95,7 +115,7 @@ mod tests { // pull_request and comment are missing }); - let result = WebhookType::from_event("pull_request_comment", json); + let result = WebhookType::from_event("pull_request_comment", "test_bot", json); assert!(result.is_err()); match result.unwrap_err() { @@ -105,11 +125,38 @@ mod tests { } #[test] - fn test_deserialize_review_payload() { + fn test_from_event_rejects_non_created_action() { let json = json!({ "action": "edited", "pull_request": { - "id": 99 + "id": 1, + "diff_url": "https://mydiff.fr" + }, + "comment": { + "id": 1, + "body": "@test_bot body", + "user": { + "id": 1 + } + } + }); + + let result = WebhookType::from_event("pull_request_comment", "test_bot", json); + assert!(result.is_err()); + + match result.unwrap_err() { + AppError::InvalidActionErr => {} + _ => panic!("expected InvalidActionErr"), + } + } + + #[test] + fn test_deserialize_review_payload() { + let json = json!({ + "action": "created", + "pull_request": { + "id": 99, + "diff_url": "https://mydiff.fr" }, "comment": { "id": 12, @@ -121,7 +168,7 @@ mod tests { }); let payload: ReviewPayload = serde_json::from_value(json).unwrap(); - assert_eq!(payload.action, "edited"); + assert_eq!(payload.action, "created"); assert_eq!(payload.pull_request.id, 99); assert_eq!(payload.comment.id, 12); assert_eq!(payload.comment.body, "Needs work"); @@ -130,8 +177,50 @@ mod tests { #[test] fn test_from_event_empty_json() { - let result = WebhookType::from_event("pull_request_comment", json!({})); + let result = WebhookType::from_event("pull_request_comment", "test_bot", json!({})); assert!(result.is_err()); assert!(matches!(result.unwrap_err(), AppError::BadJsonStructErr(_))); } + + #[test] + fn test_from_event_rejects_wrong_bot_name() { + let json = json!({ + "action": "created", + "pull_request": { + "id": 1, + "diff_url": "https://mydiff.fr" + }, + "comment": { + "id": 1, + "body": "@other_bot do something", + "user": { + "id": 1 + } + } + }); + + let result = WebhookType::from_event("pull_request_comment", "test_bot", json); + assert!(matches!(result.unwrap_err(), AppError::UnauthorizedUserErr)); + } + + #[test] + fn test_from_event_rejects_no_bot_prefix() { + let json = json!({ + "action": "created", + "pull_request": { + "id": 1, + "diff_url": "https://mydiff.fr" + }, + "comment": { + "id": 1, + "body": "just a comment without bot mention", + "user": { + "id": 1 + } + } + }); + + let result = WebhookType::from_event("pull_request_comment", "test_bot", json); + assert!(matches!(result.unwrap_err(), AppError::UnauthorizedUserErr)); + } }