improve webhook parsing
This commit is contained in:
+60
-84
@@ -1,34 +1,28 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use axum::body::to_bytes;
|
||||
use axum::extract::{FromRef, FromRequest, FromRequestParts, State};
|
||||
use axum::body::{Bytes, to_bytes};
|
||||
use axum::extract::{FromRef, FromRequest};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, RequestExt, Router};
|
||||
use axum::{Json, Router};
|
||||
use hmac::{Hmac, KeyInit, Mac};
|
||||
use serde_json::Value;
|
||||
use sha2::Sha256;
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
use crate::consts::{GITEA_EVENT_TYPE_HEADER_NAME, GITEA_SIG_HEADER_NAME, MAX_WEBHOOK_BODY_SIZE};
|
||||
use crate::errors::AppError;
|
||||
use crate::gitea::WebhookType;
|
||||
use crate::state::AppState;
|
||||
|
||||
const MAX_WEBHOOK_BODY_SIZE: usize = 1024 * 1024; // 1 Mo
|
||||
|
||||
pub async fn start(app_state: AppState) -> anyhow::Result<()> {
|
||||
let http_port = app_state.config.http_port;
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(root))
|
||||
.route("/webhook", post(webhook))
|
||||
.with_state(app_state);
|
||||
let listerner = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", http_port)).await?;
|
||||
|
||||
axum::serve(listerner, app)
|
||||
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", http_port)).await?;
|
||||
axum::serve(listener, app)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!(e))
|
||||
.map_err(anyhow::Error::from)
|
||||
}
|
||||
|
||||
async fn root() -> &'static str {
|
||||
@@ -36,11 +30,7 @@ async fn root() -> &'static str {
|
||||
}
|
||||
|
||||
async fn webhook(WebhookExtract(wb): WebhookExtract) -> Result<Response, AppError> {
|
||||
Ok(match wb {
|
||||
WebhookType::Review(id, _) => format!("Received {} pr id", id),
|
||||
_ => String::from("Nothing to see :/"),
|
||||
}
|
||||
.into_response())
|
||||
Ok("lol".into_response())
|
||||
}
|
||||
|
||||
pub struct WebhookExtract(pub WebhookType);
|
||||
@@ -52,84 +42,70 @@ where
|
||||
{
|
||||
type Rejection = AppError;
|
||||
|
||||
async fn from_request(
|
||||
mut req: axum::extract::Request,
|
||||
state: &S,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let State(state) = req
|
||||
.extract_parts_with_state::<State<AppState>, _>(state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let secret_key = state.config.webhook_secret.as_bytes();
|
||||
|
||||
async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let app_state = AppState::from_ref(state);
|
||||
let headers = req.headers();
|
||||
let sig_header = headers
|
||||
.get("x-gitea-signature")
|
||||
.ok_or(AppError::WebHookSigHeaderNotFoundErr)?
|
||||
.to_str()
|
||||
.map_err(|err| anyhow!(err))?
|
||||
.to_string();
|
||||
|
||||
let body = req.into_body();
|
||||
let body_bytes = to_bytes(body, MAX_WEBHOOK_BODY_SIZE)
|
||||
.await
|
||||
.map_err(|err| anyhow!(err))?;
|
||||
let sig_header = extract_header(GITEA_SIG_HEADER_NAME, headers)?;
|
||||
let type_header = extract_header(GITEA_EVENT_TYPE_HEADER_NAME, headers)?;
|
||||
let body_bytes = read_body(req.into_body()).await?;
|
||||
|
||||
check_sig_header(secret_key, sig_header.as_bytes(), &body_bytes)?;
|
||||
verify_signature(
|
||||
app_state.config.webhook_secret.as_bytes(),
|
||||
&sig_header,
|
||||
&body_bytes,
|
||||
)?;
|
||||
|
||||
let Json(value) =
|
||||
Json::<Value>::from_bytes(&body_bytes).map_err(|_| AppError::MalformedJsonErr)?;
|
||||
let webhook = WebhookType::try_from(value)?;
|
||||
let webhook = parse_webhook(&type_header, &body_bytes)?;
|
||||
reject_bot_user(&app_state, &webhook)?;
|
||||
|
||||
Ok(WebhookExtract(webhook))
|
||||
}
|
||||
}
|
||||
|
||||
fn check_sig_header(secret_key: &[u8], sig_header: &[u8], body: &[u8]) -> Result<(), AppError> {
|
||||
let sig_header_decoded = hex::decode(sig_header).map_err(|_| AppError::WebHookSigHeaderInvalidErr)?;
|
||||
fn extract_header(key: &str, headers: &axum::http::HeaderMap) -> Result<String, AppError> {
|
||||
let value = headers
|
||||
.get(key)
|
||||
.ok_or(AppError::WebHookMissingHeaderErr(key.into()))?
|
||||
.to_str()
|
||||
.map_err(anyhow::Error::from)?;
|
||||
Ok(value.to_owned())
|
||||
}
|
||||
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(secret_key).map_err(|err| anyhow!(err))?;
|
||||
async fn read_body(body: axum::body::Body) -> Result<Bytes, AppError> {
|
||||
to_bytes(body, MAX_WEBHOOK_BODY_SIZE)
|
||||
.await
|
||||
.map_err(anyhow::Error::from)
|
||||
.map_err(AppError::from)
|
||||
}
|
||||
|
||||
fn parse_webhook(header: &str, body_bytes: &[u8]) -> Result<WebhookType, AppError> {
|
||||
let Json(value) =
|
||||
Json::<Value>::from_bytes(body_bytes).map_err(|_| AppError::MalformedJsonErr)?;
|
||||
|
||||
WebhookType::from_event(header, value)
|
||||
}
|
||||
|
||||
fn reject_bot_user(state: &AppState, webhook: &WebhookType) -> Result<(), AppError> {
|
||||
let user_id = match webhook {
|
||||
WebhookType::Review(review_payload) => review_payload.comment.user.id,
|
||||
};
|
||||
|
||||
match user_id != state.config.bot_user_id {
|
||||
true => Ok(()),
|
||||
false => Err(AppError::UnauthorizedUserIdErr),
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_signature(secret_key: &[u8], sig_header: &str, body: &[u8]) -> Result<(), AppError> {
|
||||
let sig_header_decoded =
|
||||
hex::decode(sig_header).map_err(|_| AppError::WebHookSigHeaderInvalidErr)?;
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(secret_key).map_err(anyhow::Error::from)?;
|
||||
|
||||
mac.update(body);
|
||||
|
||||
let generated_hmac = mac.finalize().into_bytes();
|
||||
let check_result: bool = generated_hmac.ct_eq(&sig_header_decoded).into();
|
||||
|
||||
match check_result {
|
||||
true => Ok(()),
|
||||
false => Err(AppError::WebHookSigHeaderInvalidErr),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn valid_json_bytes_parse_to_value() {
|
||||
let body = serde_json::to_vec(
|
||||
&json!({"action": "created", "pull_request": {"id": 1}, "comment": {"body": "hi"}}),
|
||||
)
|
||||
.unwrap();
|
||||
let Json(value) = Json::<Value>::from_bytes(&body).unwrap();
|
||||
assert_eq!(value["action"], "created");
|
||||
assert_eq!(value["pull_request"]["id"], 1);
|
||||
assert_eq!(value["comment"]["body"], "hi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn malformed_json_bytes_return_malformed_error() {
|
||||
let body = b"not valid json";
|
||||
let result = Json::<Value>::from_bytes(body);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_body_returns_malformed_error() {
|
||||
let body = b"";
|
||||
let result = Json::<Value>::from_bytes(body);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
bool::from(generated_hmac.ct_eq(&sig_header_decoded))
|
||||
.then_some(())
|
||||
.ok_or(AppError::WebHookSigHeaderInvalidErr)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user