diff --git a/src/bot.rs b/src/bot.rs index 305ac6f..413de1e 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -9,17 +9,18 @@ use crate::{ }; #[derive(Deserialize, Debug)] -struct ReviewResult { - reviews: Vec, - comment: String, +pub struct ReviewResult { + pub reviews: Vec, + pub comment: String, + pub cost: Option, } #[derive(Deserialize, Debug)] -struct ReviewItem { - filename: String, - line: Option, - code: String, - message: String, +pub struct ReviewItem { + pub filename: String, + pub line: Option, + pub code: String, + pub message: String, } /// Map a filename to a markdown language identifier for syntax highlighting. @@ -69,7 +70,7 @@ pub struct Bot { impl Bot { pub fn new(config: EnvConfig) -> anyhow::Result { Ok(Self { - gitea_api: GiteaAPI::new(&config.gitea_url, &config.gitea_token, config.gitea_timeout), + gitea_api: GiteaAPI::new(&config.gitea_url, &config.gitea_token, config.gitea_timeout)?, open_router_client: OpenRouterClient::new( &config.open_router_api_key, &config.open_router_model, @@ -112,7 +113,7 @@ impl Bot { ) .await?; - let bot_result: Result = async { + let bot_result: Result = async { let git_diff = self .download_git_diff(&review_payload.pull_request.diff_url) .await?; @@ -122,7 +123,12 @@ impl Bot { .replace("{comment}", &review_payload.comment.body) .replace("{diff}", &git_diff); - self.open_router_client.chat(&bot_request).await + let chat_result = self.open_router_client.chat(&bot_request).await?; + let mut review_result = serde_json::from_str::(&chat_result.message)?; + + review_result.cost = chat_result.cost; + + Ok(review_result) } .await; @@ -142,17 +148,7 @@ impl Bot { Ok(()) } - fn review_result_to_markdown(&self, result: &str) -> String { - let review_result: ReviewResult = match serde_json::from_str(result) { - Ok(review_result) => review_result, - Err(_) => { - return format!( - "Failed to parse review result. Raw output:\n\n```json\n{}\n```", - result - ); - } - }; - + fn review_result_to_markdown(&self, review_result: &ReviewResult) -> String { if review_result.reviews.is_empty() { return String::from("No issues found. ✅"); } @@ -181,6 +177,12 @@ impl Bot { md.push('\n'); } + if let Some(cost) = review_result.cost { + md.push_str("\n---\n\n"); + md.push_str(&format!("### Cost: ${}", cost)); + md.push('\n'); + } + md } diff --git a/src/consts.rs b/src/consts.rs index 4eb1b41..ded647d 100644 --- a/src/consts.rs +++ b/src/consts.rs @@ -30,6 +30,7 @@ pub const REVIEW_PROMPT: &str = " The line number increments by 1 for each context or added line. Return your feedback, in french, with only this json format, reviews must contain each review + All fields are mandatory. (filename field must contain the full path with extension) and comment must contain a final summary: { diff --git a/src/gitea.rs b/src/gitea.rs index c8bff4a..4f81913 100644 --- a/src/gitea.rs +++ b/src/gitea.rs @@ -12,15 +12,22 @@ pub struct GiteaAPI { } impl GiteaAPI { - pub fn new(base_url: &str, token: &str, timeout: u64) -> Self { - Self { + pub fn new(base_url: &str, token: &str, timeout: u64) -> anyhow::Result { + let mut default_headers = reqwest::header::HeaderMap::new(); + default_headers.insert( + reqwest::header::HeaderName::from_static("authorization"), + reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))?, + ); + + Ok(Self { base_url: String::from(base_url), client: reqwest::Client::builder() .timeout(Duration::from_secs(timeout)) + .default_headers(default_headers) .build() .unwrap(), token: String::from(token), - } + }) } pub async fn comment( @@ -30,8 +37,8 @@ impl GiteaAPI { index: u64, ) -> anyhow::Result { let url = format!( - "{}/api/v1/repos/{}/issues/{}/comments?access_token={}", - self.base_url, full_name, index, self.token + "{}/api/v1/repos/{}/issues/{}/comments", + self.base_url, full_name, index ); let res = self @@ -53,8 +60,8 @@ impl GiteaAPI { comment_id: u64, ) -> anyhow::Result<()> { let url = format!( - "{}/api/v1/repos/{}/issues/comments/{}?access_token={}", - self.base_url, full_name, comment_id, self.token + "{}/api/v1/repos/{}/issues/comments/{}", + self.base_url, full_name, comment_id ); self.client diff --git a/src/open_router.rs b/src/open_router.rs index cb8dbfa..b5602d5 100644 --- a/src/open_router.rs +++ b/src/open_router.rs @@ -2,6 +2,11 @@ use std::time::Duration; use openrouter_rs::{Message, api::chat::ChatCompletionRequest}; +pub struct ChatResult { + pub message: String, + pub cost: Option, +} + pub struct OpenRouterClient { client: openrouter_rs::OpenRouterClient, model: String, @@ -22,21 +27,21 @@ impl OpenRouterClient { }) } - pub async fn chat(&self, msg: &str) -> anyhow::Result { + pub async fn chat(&self, msg: &str) -> anyhow::Result { let request = ChatCompletionRequest::builder() .model(&self.model) .enable_reasoning() - .messages(vec![Message::new( - openrouter_rs::types::Role::Developer, - msg, - )]) + .messages(vec![Message::new(openrouter_rs::types::Role::User, msg)]) .build()?; let response = self.client.chat().create(&request).await?; - response.choices[0] - .content() - .map(|msg| String::from(msg)) - .ok_or(anyhow::anyhow!("No content")) + Ok(ChatResult { + message: response.choices[0] + .content() + .map(|msg| String::from(msg)) + .ok_or(anyhow::anyhow!("No content"))?, + cost: response.usage.and_then(|u| u.cost), + }) } }