diff --git a/src/api.rs b/src/api.rs index 41ea0ab..6e0e66f 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,34 +1,28 @@ -use std::fmt::Debug; - -use anyhow::anyhow; -use axum::body::to_bytes; -use axum::extract::{FromRef, FromRequest, FromRequestParts, State}; +use axum::body::{Bytes, to_bytes}; +use axum::extract::{FromRef, FromRequest}; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; -use axum::{Json, RequestExt, Router}; +use axum::{Json, Router}; use hmac::{Hmac, KeyInit, Mac}; use serde_json::Value; use sha2::Sha256; use subtle::ConstantTimeEq; +use crate::consts::{GITEA_EVENT_TYPE_HEADER_NAME, GITEA_SIG_HEADER_NAME, MAX_WEBHOOK_BODY_SIZE}; use crate::errors::AppError; use crate::gitea::WebhookType; use crate::state::AppState; -const MAX_WEBHOOK_BODY_SIZE: usize = 1024 * 1024; // 1 Mo - pub async fn start(app_state: AppState) -> anyhow::Result<()> { let http_port = app_state.config.http_port; - let app = Router::new() .route("/", get(root)) .route("/webhook", post(webhook)) .with_state(app_state); - let listerner = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", http_port)).await?; - - axum::serve(listerner, app) + let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", http_port)).await?; + axum::serve(listener, app) .await - .map_err(|e| anyhow::anyhow!(e)) + .map_err(anyhow::Error::from) } async fn root() -> &'static str { @@ -36,11 +30,7 @@ async fn root() -> &'static str { } async fn webhook(WebhookExtract(wb): WebhookExtract) -> Result { - Ok(match wb { - WebhookType::Review(id, _) => format!("Received {} pr id", id), - _ => String::from("Nothing to see :/"), - } - .into_response()) + Ok("lol".into_response()) } pub struct WebhookExtract(pub WebhookType); @@ -52,84 +42,70 @@ where { type Rejection = AppError; - async fn from_request( - mut req: axum::extract::Request, - state: &S, - ) -> Result { - let State(state) = req - .extract_parts_with_state::, _>(state) - .await - .unwrap(); - - let secret_key = state.config.webhook_secret.as_bytes(); - + async fn from_request(req: axum::extract::Request, state: &S) -> Result { + let app_state = AppState::from_ref(state); let headers = req.headers(); - let sig_header = headers - .get("x-gitea-signature") - .ok_or(AppError::WebHookSigHeaderNotFoundErr)? - .to_str() - .map_err(|err| anyhow!(err))? - .to_string(); - let body = req.into_body(); - let body_bytes = to_bytes(body, MAX_WEBHOOK_BODY_SIZE) - .await - .map_err(|err| anyhow!(err))?; + let sig_header = extract_header(GITEA_SIG_HEADER_NAME, headers)?; + let type_header = extract_header(GITEA_EVENT_TYPE_HEADER_NAME, headers)?; + let body_bytes = read_body(req.into_body()).await?; - check_sig_header(secret_key, sig_header.as_bytes(), &body_bytes)?; + verify_signature( + app_state.config.webhook_secret.as_bytes(), + &sig_header, + &body_bytes, + )?; - let Json(value) = - Json::::from_bytes(&body_bytes).map_err(|_| AppError::MalformedJsonErr)?; - let webhook = WebhookType::try_from(value)?; + let webhook = parse_webhook(&type_header, &body_bytes)?; + reject_bot_user(&app_state, &webhook)?; Ok(WebhookExtract(webhook)) } } -fn check_sig_header(secret_key: &[u8], sig_header: &[u8], body: &[u8]) -> Result<(), AppError> { - let sig_header_decoded = hex::decode(sig_header).map_err(|_| AppError::WebHookSigHeaderInvalidErr)?; +fn extract_header(key: &str, headers: &axum::http::HeaderMap) -> Result { + let value = headers + .get(key) + .ok_or(AppError::WebHookMissingHeaderErr(key.into()))? + .to_str() + .map_err(anyhow::Error::from)?; + Ok(value.to_owned()) +} - let mut mac = Hmac::::new_from_slice(secret_key).map_err(|err| anyhow!(err))?; +async fn read_body(body: axum::body::Body) -> Result { + to_bytes(body, MAX_WEBHOOK_BODY_SIZE) + .await + .map_err(anyhow::Error::from) + .map_err(AppError::from) +} + +fn parse_webhook(header: &str, body_bytes: &[u8]) -> Result { + let Json(value) = + Json::::from_bytes(body_bytes).map_err(|_| AppError::MalformedJsonErr)?; + + WebhookType::from_event(header, value) +} + +fn reject_bot_user(state: &AppState, webhook: &WebhookType) -> Result<(), AppError> { + let user_id = match webhook { + WebhookType::Review(review_payload) => review_payload.comment.user.id, + }; + + match user_id != state.config.bot_user_id { + true => Ok(()), + false => Err(AppError::UnauthorizedUserIdErr), + } +} + +fn verify_signature(secret_key: &[u8], sig_header: &str, body: &[u8]) -> Result<(), AppError> { + let sig_header_decoded = + hex::decode(sig_header).map_err(|_| AppError::WebHookSigHeaderInvalidErr)?; + let mut mac = Hmac::::new_from_slice(secret_key).map_err(anyhow::Error::from)?; mac.update(body); let generated_hmac = mac.finalize().into_bytes(); - let check_result: bool = generated_hmac.ct_eq(&sig_header_decoded).into(); - - match check_result { - true => Ok(()), - false => Err(AppError::WebHookSigHeaderInvalidErr), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn valid_json_bytes_parse_to_value() { - let body = serde_json::to_vec( - &json!({"action": "created", "pull_request": {"id": 1}, "comment": {"body": "hi"}}), - ) - .unwrap(); - let Json(value) = Json::::from_bytes(&body).unwrap(); - assert_eq!(value["action"], "created"); - assert_eq!(value["pull_request"]["id"], 1); - assert_eq!(value["comment"]["body"], "hi"); - } - - #[test] - fn malformed_json_bytes_return_malformed_error() { - let body = b"not valid json"; - let result = Json::::from_bytes(body); - assert!(result.is_err()); - } - - #[test] - fn empty_body_returns_malformed_error() { - let body = b""; - let result = Json::::from_bytes(body); - assert!(result.is_err()); - } + bool::from(generated_hmac.ct_eq(&sig_header_decoded)) + .then_some(()) + .ok_or(AppError::WebHookSigHeaderInvalidErr) } diff --git a/src/consts.rs b/src/consts.rs new file mode 100644 index 0000000..75ebdab --- /dev/null +++ b/src/consts.rs @@ -0,0 +1,3 @@ +pub const GITEA_SIG_HEADER_NAME: &str = "x-gitea-signature"; +pub const GITEA_EVENT_TYPE_HEADER_NAME: &str = "x-gitea-event-type"; +pub const MAX_WEBHOOK_BODY_SIZE: usize = 1024 * 1024; // 1 MiB diff --git a/src/env.rs b/src/env.rs index 6298de6..6b8ef2a 100644 --- a/src/env.rs +++ b/src/env.rs @@ -6,21 +6,21 @@ pub struct EnvConfig { pub http_port: u16, pub webhook_secret: String, pub open_router_api_key: String, - pub bot_name: String, + pub bot_user_id: u64, } pub fn load_config() -> anyhow::Result { dotenv().ok(); let http_port = try_get_env("HTTP_PORT")?.parse()?; - let bot_name = try_get_env("BOT_NAME")?; + let bot_user_id = try_get_env("BOT_USER_ID")?.parse()?; let webhook_secret = try_get_env("WEBHOOK_SIG_HEADER_SECRET")?; let open_router_api_key = try_get_env("OPEN_ROUTER_API_KEY")?; Ok(EnvConfig { http_port, webhook_secret, - bot_name, + bot_user_id, open_router_api_key, }) } @@ -28,7 +28,7 @@ pub fn load_config() -> anyhow::Result { fn try_get_env(key: &str) -> anyhow::Result { let env = std::env::var(key)?; - if env.trim().len() == 0 { + if env.trim().is_empty() { return Err(anyhow!(format!("env var {} is empty", env))); } diff --git a/src/errors.rs b/src/errors.rs index 16f4804..459eb42 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -3,24 +3,24 @@ use reqwest::StatusCode; #[derive(thiserror::Error, Debug)] pub enum AppError { + #[error("Unauthorized user id")] + UnauthorizedUserIdErr, + + #[error("Unknow gitea event")] + UnknownEventErr, + #[error("Malformed Json")] MalformedJsonErr, - #[error("Json not contains mandatory fields")] - BadJsonStructErr, + #[error(transparent)] + BadJsonStructErr(#[from] serde_json::Error), - #[error("WebHook sig header not found")] - WebHookSigHeaderNotFoundErr, + #[error("WebHook header not found")] + WebHookMissingHeaderErr(String), #[error("WebHook sig header is invalid")] WebHookSigHeaderInvalidErr, - #[error("Missing required field: {0}")] - MissingField(String), - - #[error("Wrong type for field: {0}")] - WrongFieldType(String), - #[error(transparent)] Other(#[from] anyhow::Error), } @@ -28,39 +28,29 @@ pub enum AppError { impl IntoResponse for AppError { fn into_response(self) -> axum::response::Response { match self { - AppError::MalformedJsonErr => { - (StatusCode::BAD_REQUEST, "Malformed Json".to_string()).into_response() + AppError::UnknownEventErr => { + (StatusCode::BAD_REQUEST, "Unknow gitea event".to_string()) } - AppError::BadJsonStructErr => ( + AppError::UnauthorizedUserIdErr => { + (StatusCode::BAD_REQUEST, "Unauthorized user id".to_string()) + } + AppError::MalformedJsonErr => (StatusCode::BAD_REQUEST, "Malformed Json".to_string()), + AppError::BadJsonStructErr(err) => ( StatusCode::BAD_REQUEST, - "Json not contains mandatory fields".to_string(), - ) - .into_response(), - AppError::WebHookSigHeaderNotFoundErr => ( - StatusCode::BAD_REQUEST, - "WebHook sig header not found".to_string(), - ) - .into_response(), + format!("Json not contains mandatory fields: {}", err), + ), + AppError::WebHookMissingHeaderErr(h) => { + (StatusCode::BAD_REQUEST, format!("header {} is missing", h)) + } AppError::WebHookSigHeaderInvalidErr => ( StatusCode::UNAUTHORIZED, "WebHook sig header is invalid".to_string(), - ) - .into_response(), - AppError::MissingField(ref field) => ( - StatusCode::BAD_REQUEST, - format!("Missing required field: {}", field), - ) - .into_response(), - AppError::WrongFieldType(ref field) => ( - StatusCode::BAD_REQUEST, - format!("Wrong type for field: {}", field), - ) - .into_response(), + ), AppError::Other(_) => ( StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string(), - ) - .into_response(), + ), } + .into_response() } } diff --git a/src/gitea.rs b/src/gitea.rs index 5d1d1b8..b82af15 100644 --- a/src/gitea.rs +++ b/src/gitea.rs @@ -1,46 +1,43 @@ +use serde::Deserialize; use serde_json::Value; use crate::errors::AppError; -#[derive(Debug, PartialEq)] +#[derive(Debug)] pub enum WebhookType { - Review(u64, String), + Review(ReviewPayload), } -impl TryFrom for WebhookType { - type Error = AppError; +#[derive(Deserialize, Debug)] +pub struct ReviewPayload { + pub action: String, + pub pull_request: PullRequest, + pub comment: Comment, +} - fn try_from(json: Value) -> Result { - let pull_request = json.get("pull_request"); - let comment = json.get("comment"); - let action = json - .get("action") - .ok_or(AppError::MissingField("action".into()))? - .as_str() - .ok_or(AppError::WrongFieldType("action".into()))?; +#[derive(Deserialize, Debug)] +pub struct PullRequest { + pub id: u64, +} - if action != "created" { - return Err(AppError::BadJsonStructErr); +#[derive(Deserialize, Debug)] +pub struct Comment { + pub id: u64, + pub body: String, + pub user: User, +} + +#[derive(Deserialize, Debug)] +pub struct User { + pub id: u64, +} + +impl WebhookType { + pub fn from_event(event: &str, json: Value) -> Result { + match event { + "pull_request_comment" => Ok(WebhookType::Review(serde_json::from_value(json)?)), + _ => Err(AppError::UnknownEventErr), } - - if let (Some(pull_request), Some(comment)) = (pull_request, comment) { - let comment_body = comment - .get("body") - .ok_or(AppError::MissingField("comment.body".into()))? - .as_str() - .ok_or(AppError::WrongFieldType("comment.body".into()))? - .to_string(); - - let pr_id = pull_request - .get("id") - .ok_or(AppError::MissingField("pull_request.id".into()))? - .as_u64() - .ok_or(AppError::WrongFieldType("pull_request.id".into()))?; - - return Ok(WebhookType::Review(pr_id, comment_body)); - } - - Err(AppError::BadJsonStructErr) } } @@ -50,149 +47,91 @@ mod tests { use serde_json::json; #[test] - fn valid_webhook_parses_review() { - let payload = json!({ + fn test_from_event_valid_pull_request_comment() { + let json = json!({ "action": "created", - "pull_request": { "id": 42 }, - "comment": { "body": "LGTM" } + "pull_request": { + "id": 42 + }, + "comment": { + "id": 7, + "body": "LGTM", + "user": { + "id": 100 + } + } }); - let result = WebhookType::try_from(payload).unwrap(); - assert_eq!(result, WebhookType::Review(42, "LGTM".into())); + + let result = WebhookType::from_event("pull_request_comment", json); + assert!(result.is_ok()); + + match result.unwrap() { + WebhookType::Review(payload) => { + assert_eq!(payload.action, "created"); + assert_eq!(payload.pull_request.id, 42); + assert_eq!(payload.comment.id, 7); + assert_eq!(payload.comment.body, "LGTM"); + assert_eq!(payload.comment.user.id, 100); + } + } } #[test] - fn missing_action_returns_error() { - let payload = json!({ - "pull_request": { "id": 1 }, - "comment": { "body": "ok" } - }); - let err = WebhookType::try_from(payload).unwrap_err(); - assert!(matches!(err, AppError::MissingField(ref f) if f == "action")); + fn test_from_event_unknown_event() { + let json = json!({}); + let result = WebhookType::from_event("push", json); + assert!(result.is_err()); + + match result.unwrap_err() { + AppError::UnknownEventErr => {} + _ => panic!("expected UnknownEventErr"), + } } #[test] - fn action_not_created_returns_bad_json_struct() { - let payload = json!({ - "action": "updated", - "pull_request": { "id": 1 }, - "comment": { "body": "ok" } + fn test_from_event_malformed_json() { + let json = json!({ + "action": "created" + // pull_request and comment are missing }); - let err = WebhookType::try_from(payload).unwrap_err(); - assert!(matches!(err, AppError::BadJsonStructErr)); + + let result = WebhookType::from_event("pull_request_comment", json); + assert!(result.is_err()); + + match result.unwrap_err() { + AppError::BadJsonStructErr(_) => {} + _ => panic!("expected BadJsonStructErr"), + } } #[test] - fn action_not_a_string_returns_error() { - let payload = json!({ - "action": 123, - "pull_request": { "id": 1 }, - "comment": { "body": "ok" } + fn test_deserialize_review_payload() { + let json = json!({ + "action": "edited", + "pull_request": { + "id": 99 + }, + "comment": { + "id": 12, + "body": "Needs work", + "user": { + "id": 200 + } + } }); - let err = WebhookType::try_from(payload).unwrap_err(); - assert!(matches!(err, AppError::WrongFieldType(ref f) if f == "action")); + + let payload: ReviewPayload = serde_json::from_value(json).unwrap(); + assert_eq!(payload.action, "edited"); + assert_eq!(payload.pull_request.id, 99); + assert_eq!(payload.comment.id, 12); + assert_eq!(payload.comment.body, "Needs work"); + assert_eq!(payload.comment.user.id, 200); } #[test] - fn missing_pull_request_returns_bad_json_struct() { - let payload = json!({ - "action": "created", - "comment": { "body": "ok" } - }); - let err = WebhookType::try_from(payload).unwrap_err(); - assert!(matches!(err, AppError::BadJsonStructErr)); - } - - #[test] - fn missing_comment_returns_bad_json_struct() { - let payload = json!({ - "action": "created", - "pull_request": { "id": 1 } - }); - let err = WebhookType::try_from(payload).unwrap_err(); - assert!(matches!(err, AppError::BadJsonStructErr)); - } - - #[test] - fn missing_pr_id_returns_error() { - let payload = json!({ - "action": "created", - "pull_request": { "number": 1 }, - "comment": { "body": "ok" } - }); - let err = WebhookType::try_from(payload).unwrap_err(); - assert!(matches!(err, AppError::MissingField(ref f) if f == "pull_request.id")); - } - - #[test] - fn pr_id_not_a_number_returns_error() { - let payload = json!({ - "action": "created", - "pull_request": { "id": "not-a-number" }, - "comment": { "body": "ok" } - }); - let err = WebhookType::try_from(payload).unwrap_err(); - assert!(matches!(err, AppError::WrongFieldType(ref f) if f == "pull_request.id")); - } - - #[test] - fn missing_comment_body_returns_error() { - let payload = json!({ - "action": "created", - "pull_request": { "id": 1 }, - "comment": { "text": "no body" } - }); - let err = WebhookType::try_from(payload).unwrap_err(); - assert!(matches!(err, AppError::MissingField(ref f) if f == "comment.body")); - } - - #[test] - fn comment_body_not_a_string_returns_error() { - let payload = json!({ - "action": "created", - "pull_request": { "id": 1 }, - "comment": { "body": 999 } - }); - let err = WebhookType::try_from(payload).unwrap_err(); - assert!(matches!(err, AppError::WrongFieldType(ref f) if f == "comment.body")); - } - - #[test] - fn null_pull_request_returns_error() { - let payload = json!({ - "action": "created", - "pull_request": null, - "comment": { "body": "ok" } - }); - let err = WebhookType::try_from(payload).unwrap_err(); - assert!(matches!(err, AppError::MissingField(ref f) if f == "pull_request.id")); - } - - #[test] - fn null_comment_returns_error() { - let payload = json!({ - "action": "created", - "pull_request": { "id": 1 }, - "comment": null - }); - let err = WebhookType::try_from(payload).unwrap_err(); - assert!(matches!(err, AppError::MissingField(ref f) if f == "comment.body")); - } - - #[test] - fn large_pr_id_parses_correctly() { - let payload = json!({ - "action": "created", - "pull_request": { "id": 18446744073709551615u64 }, - "comment": { "body": "max u64" } - }); - let result = WebhookType::try_from(payload).unwrap(); - assert_eq!(result, WebhookType::Review(18446744073709551615, "max u64".into())); - } - - #[test] - fn full_webhook_payload_parses() { - let payload: Value = serde_json::from_str(include_str!("../docs/webhook_pr_body.json")).unwrap(); - let result = WebhookType::try_from(payload).unwrap(); - assert_eq!(result, WebhookType::Review(1, "Test comment".into())); + fn test_from_event_empty_json() { + let result = WebhookType::from_event("pull_request_comment", json!({})); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), AppError::BadJsonStructErr(_))); } } diff --git a/src/main.rs b/src/main.rs index 6c3db06..83830ed 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ use crate::{bot::Bot, state::AppState}; mod api; mod bot; +mod consts; mod env; mod errors; mod gitea;