add check for action and bot_name

This commit is contained in:
2026-06-02 18:42:59 +00:00
parent 0a22be252c
commit 14751f8db1
4 changed files with 126 additions and 39 deletions
+6 -16
View File
@@ -15,10 +15,12 @@ use crate::state::AppState;
pub async fn start(app_state: AppState) -> anyhow::Result<()> { pub async fn start(app_state: AppState) -> anyhow::Result<()> {
let http_port = app_state.config.http_port; let http_port = app_state.config.http_port;
let app = Router::new() let app = Router::new()
.route("/", get(root)) .route("/", get(root))
.route("/webhook", post(webhook)) .route("/webhook", post(webhook))
.with_state(app_state); .with_state(app_state);
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", http_port)).await?; let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", http_port)).await?;
axum::serve(listener, app) axum::serve(listener, app)
.await .await
@@ -56,9 +58,7 @@ where
&body_bytes, &body_bytes,
)?; )?;
let webhook = parse_webhook(&type_header, &body_bytes)?; let webhook = parse_webhook(&type_header, &app_state.config.bot_name, &body_bytes)?;
reject_bot_user(&app_state, &webhook)?;
Ok(WebhookExtract(webhook)) Ok(WebhookExtract(webhook))
} }
} }
@@ -69,6 +69,7 @@ fn extract_header(key: &str, headers: &axum::http::HeaderMap) -> Result<String,
.ok_or(AppError::WebHookMissingHeaderErr(key.into()))? .ok_or(AppError::WebHookMissingHeaderErr(key.into()))?
.to_str() .to_str()
.map_err(anyhow::Error::from)?; .map_err(anyhow::Error::from)?;
Ok(value.to_owned()) Ok(value.to_owned())
} }
@@ -79,22 +80,11 @@ async fn read_body(body: axum::body::Body) -> Result<Bytes, AppError> {
.map_err(AppError::from) .map_err(AppError::from)
} }
fn parse_webhook(header: &str, body_bytes: &[u8]) -> Result<WebhookType, AppError> { fn parse_webhook(header: &str, bot_name: &str, body_bytes: &[u8]) -> Result<WebhookType, AppError> {
let Json(value) = let Json(value) =
Json::<Value>::from_bytes(body_bytes).map_err(|_| AppError::MalformedJsonErr)?; Json::<Value>::from_bytes(body_bytes).map_err(|_| AppError::MalformedJsonErr)?;
WebhookType::from_event(header, value) WebhookType::from_event(header, bot_name, value)
}
fn reject_bot_user(state: &AppState, webhook: &WebhookType) -> Result<(), AppError> {
let user_id = match webhook {
WebhookType::Review(review_payload) => review_payload.comment.user.id,
};
match user_id != state.config.bot_user_id {
true => Ok(()),
false => Err(AppError::UnauthorizedUserIdErr),
}
} }
fn verify_signature(secret_key: &[u8], sig_header: &str, body: &[u8]) -> Result<(), AppError> { fn verify_signature(secret_key: &[u8], sig_header: &str, body: &[u8]) -> Result<(), AppError> {
+4 -4
View File
@@ -6,21 +6,21 @@ pub struct EnvConfig {
pub http_port: u16, pub http_port: u16,
pub webhook_secret: String, pub webhook_secret: String,
pub open_router_api_key: String, pub open_router_api_key: String,
pub bot_user_id: u64, pub bot_name: String,
} }
pub fn load_config() -> anyhow::Result<EnvConfig> { pub fn load_config() -> anyhow::Result<EnvConfig> {
dotenv().ok(); dotenv().ok();
let http_port = try_get_env("HTTP_PORT")?.parse()?; let http_port = try_get_env("HTTP_PORT")?.parse()?;
let bot_user_id = try_get_env("BOT_USER_ID")?.parse()?; let bot_name = try_get_env("BOT_NAME")?;
let webhook_secret = try_get_env("WEBHOOK_SIG_HEADER_SECRET")?; let webhook_secret = try_get_env("WEBHOOK_SIG_HEADER_SECRET")?;
let open_router_api_key = try_get_env("OPEN_ROUTER_API_KEY")?; let open_router_api_key = try_get_env("OPEN_ROUTER_API_KEY")?;
Ok(EnvConfig { Ok(EnvConfig {
http_port, http_port,
webhook_secret, webhook_secret,
bot_user_id, bot_name,
open_router_api_key, open_router_api_key,
}) })
} }
@@ -29,7 +29,7 @@ fn try_get_env(key: &str) -> anyhow::Result<String> {
let env = std::env::var(key)?; let env = std::env::var(key)?;
if env.trim().is_empty() { if env.trim().is_empty() {
return Err(anyhow!(format!("env var {} is empty", env))); return Err(anyhow!(format!("env var {} is empty", key)));
} }
Ok(env) Ok(env)
+15 -7
View File
@@ -4,7 +4,7 @@ use reqwest::StatusCode;
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum AppError { pub enum AppError {
#[error("Unauthorized user id")] #[error("Unauthorized user id")]
UnauthorizedUserIdErr, UnauthorizedUserErr,
#[error("Unknow gitea event")] #[error("Unknow gitea event")]
UnknownEventErr, UnknownEventErr,
@@ -12,15 +12,18 @@ pub enum AppError {
#[error("Malformed Json")] #[error("Malformed Json")]
MalformedJsonErr, MalformedJsonErr,
#[error(transparent)]
BadJsonStructErr(#[from] serde_json::Error),
#[error("WebHook header not found")] #[error("WebHook header not found")]
WebHookMissingHeaderErr(String), WebHookMissingHeaderErr(String),
#[error("WebHook sig header is invalid")] #[error("WebHook sig header is invalid")]
WebHookSigHeaderInvalidErr, WebHookSigHeaderInvalidErr,
#[error("WebHook have bad action")]
InvalidActionErr,
#[error(transparent)]
BadJsonStructErr(#[from] serde_json::Error),
#[error(transparent)] #[error(transparent)]
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
} }
@@ -28,12 +31,17 @@ pub enum AppError {
impl IntoResponse for AppError { impl IntoResponse for AppError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
match self { match self {
AppError::InvalidActionErr => (
StatusCode::UNPROCESSABLE_ENTITY,
"WebHook have bad action".to_string(),
),
AppError::UnknownEventErr => { AppError::UnknownEventErr => {
(StatusCode::BAD_REQUEST, "Unknow gitea event".to_string()) (StatusCode::BAD_REQUEST, "Unknow gitea event".to_string())
} }
AppError::UnauthorizedUserIdErr => { AppError::UnauthorizedUserErr => (
(StatusCode::BAD_REQUEST, "Unauthorized user id".to_string()) StatusCode::UNAUTHORIZED,
} "Unauthorized user name".to_string(),
),
AppError::MalformedJsonErr => (StatusCode::BAD_REQUEST, "Malformed Json".to_string()), AppError::MalformedJsonErr => (StatusCode::BAD_REQUEST, "Malformed Json".to_string()),
AppError::BadJsonStructErr(err) => ( AppError::BadJsonStructErr(err) => (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
+101 -12
View File
@@ -18,6 +18,7 @@ pub struct ReviewPayload {
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct PullRequest { pub struct PullRequest {
pub id: u64, pub id: u64,
pub diff_url: String,
} }
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
@@ -33,11 +34,29 @@ pub struct User {
} }
impl WebhookType { impl WebhookType {
pub fn from_event(event: &str, json: Value) -> Result<Self, AppError> { pub fn from_event(event: &str, bot_name: &str, json: Value) -> Result<Self, AppError> {
match event { let wb = match event {
"pull_request_comment" => Ok(WebhookType::Review(serde_json::from_value(json)?)), "pull_request_comment" => Ok(WebhookType::Review(serde_json::from_value(json)?)),
_ => Err(AppError::UnknownEventErr), _ => Err(AppError::UnknownEventErr),
}?;
let pr_body = match &wb {
WebhookType::Review(review_payload) => &review_payload.comment.body,
};
if !pr_body.starts_with(&format!("@{}", bot_name)) {
return Err(AppError::UnauthorizedUserErr);
} }
let action = match &wb {
WebhookType::Review(review_payload) => &review_payload.action,
};
if action != "created" {
return Err(AppError::InvalidActionErr);
}
Ok(wb)
} }
} }
@@ -51,18 +70,19 @@ mod tests {
let json = json!({ let json = json!({
"action": "created", "action": "created",
"pull_request": { "pull_request": {
"id": 42 "id": 42,
"diff_url": "https://mydiff.fr"
}, },
"comment": { "comment": {
"id": 7, "id": 7,
"body": "LGTM", "body": "@test_bot LGTM",
"user": { "user": {
"id": 100 "id": 100
} }
} }
}); });
let result = WebhookType::from_event("pull_request_comment", json); let result = WebhookType::from_event("pull_request_comment", "test_bot", json);
assert!(result.is_ok()); assert!(result.is_ok());
match result.unwrap() { match result.unwrap() {
@@ -70,7 +90,7 @@ mod tests {
assert_eq!(payload.action, "created"); assert_eq!(payload.action, "created");
assert_eq!(payload.pull_request.id, 42); assert_eq!(payload.pull_request.id, 42);
assert_eq!(payload.comment.id, 7); assert_eq!(payload.comment.id, 7);
assert_eq!(payload.comment.body, "LGTM"); assert_eq!(payload.comment.body, "@test_bot LGTM");
assert_eq!(payload.comment.user.id, 100); assert_eq!(payload.comment.user.id, 100);
} }
} }
@@ -79,7 +99,7 @@ mod tests {
#[test] #[test]
fn test_from_event_unknown_event() { fn test_from_event_unknown_event() {
let json = json!({}); let json = json!({});
let result = WebhookType::from_event("push", json); let result = WebhookType::from_event("push", "test_bot", json);
assert!(result.is_err()); assert!(result.is_err());
match result.unwrap_err() { match result.unwrap_err() {
@@ -95,7 +115,7 @@ mod tests {
// pull_request and comment are missing // pull_request and comment are missing
}); });
let result = WebhookType::from_event("pull_request_comment", json); let result = WebhookType::from_event("pull_request_comment", "test_bot", json);
assert!(result.is_err()); assert!(result.is_err());
match result.unwrap_err() { match result.unwrap_err() {
@@ -105,11 +125,38 @@ mod tests {
} }
#[test] #[test]
fn test_deserialize_review_payload() { fn test_from_event_rejects_non_created_action() {
let json = json!({ let json = json!({
"action": "edited", "action": "edited",
"pull_request": { "pull_request": {
"id": 99 "id": 1,
"diff_url": "https://mydiff.fr"
},
"comment": {
"id": 1,
"body": "@test_bot body",
"user": {
"id": 1
}
}
});
let result = WebhookType::from_event("pull_request_comment", "test_bot", json);
assert!(result.is_err());
match result.unwrap_err() {
AppError::InvalidActionErr => {}
_ => panic!("expected InvalidActionErr"),
}
}
#[test]
fn test_deserialize_review_payload() {
let json = json!({
"action": "created",
"pull_request": {
"id": 99,
"diff_url": "https://mydiff.fr"
}, },
"comment": { "comment": {
"id": 12, "id": 12,
@@ -121,7 +168,7 @@ mod tests {
}); });
let payload: ReviewPayload = serde_json::from_value(json).unwrap(); let payload: ReviewPayload = serde_json::from_value(json).unwrap();
assert_eq!(payload.action, "edited"); assert_eq!(payload.action, "created");
assert_eq!(payload.pull_request.id, 99); assert_eq!(payload.pull_request.id, 99);
assert_eq!(payload.comment.id, 12); assert_eq!(payload.comment.id, 12);
assert_eq!(payload.comment.body, "Needs work"); assert_eq!(payload.comment.body, "Needs work");
@@ -130,8 +177,50 @@ mod tests {
#[test] #[test]
fn test_from_event_empty_json() { fn test_from_event_empty_json() {
let result = WebhookType::from_event("pull_request_comment", json!({})); let result = WebhookType::from_event("pull_request_comment", "test_bot", json!({}));
assert!(result.is_err()); assert!(result.is_err());
assert!(matches!(result.unwrap_err(), AppError::BadJsonStructErr(_))); assert!(matches!(result.unwrap_err(), AppError::BadJsonStructErr(_)));
} }
#[test]
fn test_from_event_rejects_wrong_bot_name() {
let json = json!({
"action": "created",
"pull_request": {
"id": 1,
"diff_url": "https://mydiff.fr"
},
"comment": {
"id": 1,
"body": "@other_bot do something",
"user": {
"id": 1
}
}
});
let result = WebhookType::from_event("pull_request_comment", "test_bot", json);
assert!(matches!(result.unwrap_err(), AppError::UnauthorizedUserErr));
}
#[test]
fn test_from_event_rejects_no_bot_prefix() {
let json = json!({
"action": "created",
"pull_request": {
"id": 1,
"diff_url": "https://mydiff.fr"
},
"comment": {
"id": 1,
"body": "just a comment without bot mention",
"user": {
"id": 1
}
}
});
let result = WebhookType::from_event("pull_request_comment", "test_bot", json);
assert!(matches!(result.unwrap_err(), AppError::UnauthorizedUserErr));
}
} }