Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 256 additions & 27 deletions crates/zeph-channels/src/discord/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@

//! Discord REST API client for message operations.

use std::time::Duration;

use serde::{Deserialize, Serialize};

const BASE_URL: &str = "https://discord.com/api/v10";
const MAX_RETRY_SECS: f64 = 60.0;
const MAX_RETRIES: u32 = 3;
/// Per-request HTTP timeout applied to every Discord REST call.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

#[derive(Deserialize)]
struct CurrentApplication {
Expand Down Expand Up @@ -68,6 +74,72 @@ struct EditMessage<'a> {
content: &'a str,
}

/// Body returned by Discord on HTTP 429.
#[derive(Deserialize, Default)]
struct RateLimitBody {
retry_after: Option<f64>,
}

/// Executes a request with automatic 429 retry-after backoff.
///
/// Builds a fresh request each attempt via `make_req`. On HTTP 429 the function
/// reads `Retry-After` header (falling back to the JSON body `retry_after` field),
/// clamps to [`MAX_RETRY_SECS`], sleeps, and retries up to [`MAX_RETRIES`] times.
/// When retries are exhausted a final request is issued to obtain a `reqwest::Error`
/// with the original HTTP status — `reqwest::Error` cannot be constructed directly.
///
/// # Errors
///
/// Returns a [`reqwest::Error`] when all retries are exhausted, a non-429 HTTP error
/// is received, or the per-request timeout ([`REQUEST_TIMEOUT`]) is exceeded.
async fn send_with_retry<F>(make_req: F) -> Result<reqwest::Response, reqwest::Error>
where
F: Fn() -> reqwest::RequestBuilder,
{
let mut attempts = 0u32;
loop {
let resp = make_req().send().await?;

if resp.status() != reqwest::StatusCode::TOO_MANY_REQUESTS {
return resp.error_for_status();
}

// Parse retry delay: header wins, then body field, then default 1 s.
let header_secs = resp
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<f64>().ok());

let body_secs = resp
.json::<RateLimitBody>()
.await
.unwrap_or_default()
.retry_after;

let delay_secs = header_secs.or(body_secs).unwrap_or(1.0).min(MAX_RETRY_SECS);

attempts += 1;
if attempts > MAX_RETRIES {
tracing::warn!(
delay_secs,
attempts,
"discord: rate-limited and retries exhausted"
);
// Surface as an error by issuing the request once more without retry.
return make_req().send().await?.error_for_status();
}

tracing::warn!(
delay_secs,
attempt = attempts,
max = MAX_RETRIES,
"discord: rate-limited (429), backing off"
);
tokio::time::sleep(Duration::from_secs_f64(delay_secs)).await;
}
}

impl RestClient {
#[must_use]
pub fn new(token: String) -> Self {
Expand All @@ -81,41 +153,44 @@ impl RestClient {

/// # Errors
///
/// Returns an error if the HTTP request fails.
/// Returns an error if the HTTP request fails or rate-limit retries are exhausted.
pub async fn send_message(
&self,
channel_id: &str,
content: &str,
) -> Result<DiscordMessage, reqwest::Error> {
self.client
.post(format!("{BASE_URL}/channels/{channel_id}/messages"))
.header("Authorization", self.auth_header())
.json(&CreateMessage { content })
.send()
.await?
.error_for_status()?
.json()
.await
let url = format!("{BASE_URL}/channels/{channel_id}/messages");
let auth = self.auth_header();
let resp = send_with_retry(|| {
self.client
.post(&url)
.header("Authorization", &auth)
.timeout(REQUEST_TIMEOUT)
.json(&CreateMessage { content })
})
.await?;
resp.json().await
}

/// # Errors
///
/// Returns an error if the HTTP request fails.
/// Returns an error if the HTTP request fails or rate-limit retries are exhausted.
pub async fn edit_message(
&self,
channel_id: &str,
message_id: &str,
content: &str,
) -> Result<(), reqwest::Error> {
self.client
.patch(format!(
"{BASE_URL}/channels/{channel_id}/messages/{message_id}"
))
.header("Authorization", self.auth_header())
.json(&EditMessage { content })
.send()
.await?
.error_for_status()?;
let url = format!("{BASE_URL}/channels/{channel_id}/messages/{message_id}");
let auth = self.auth_header();
send_with_retry(|| {
self.client
.patch(&url)
.header("Authorization", &auth)
.timeout(REQUEST_TIMEOUT)
.json(&EditMessage { content })
})
.await?;
Ok(())
}

Expand Down Expand Up @@ -162,14 +237,168 @@ impl RestClient {

/// # Errors
///
/// Returns an error if the HTTP request fails.
/// Returns an error if the HTTP request fails or rate-limit retries are exhausted.
pub async fn trigger_typing(&self, channel_id: &str) -> Result<(), reqwest::Error> {
self.client
.post(format!("{BASE_URL}/channels/{channel_id}/typing"))
.header("Authorization", self.auth_header())
.send()
.await?
.error_for_status()?;
let url = format!("{BASE_URL}/channels/{channel_id}/typing");
let auth = self.auth_header();
send_with_retry(|| {
self.client
.post(&url)
.header("Authorization", &auth)
.timeout(REQUEST_TIMEOUT)
})
.await?;
Ok(())
}
}

#[cfg(test)]
mod tests {
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};

use super::*;

#[tokio::test]
async fn send_with_retry_succeeds_on_200() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/channels/ch1/messages"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({"id": "1"})))
.mount(&server)
.await;

let client = reqwest::Client::new();
let url = format!("{}/channels/ch1/messages", server.uri());
let resp = send_with_retry(|| client.post(&url)).await.unwrap();
assert_eq!(resp.status(), 200);
}

#[tokio::test]
async fn send_with_retry_retries_on_429_then_succeeds() {
let server = MockServer::start().await;

// First call → 429 with Retry-After header.
Mock::given(method("POST"))
.and(path("/channels/ch1/messages"))
.respond_with(
ResponseTemplate::new(429)
.append_header("Retry-After", "0")
.set_body_json(serde_json::json!({"retry_after": 0.0})),
)
.up_to_n_times(1)
.mount(&server)
.await;

// Second call → 200.
Mock::given(method("POST"))
.and(path("/channels/ch1/messages"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({"id": "2"})))
.mount(&server)
.await;

let client = reqwest::Client::new();
let url = format!("{}/channels/ch1/messages", server.uri());
let resp = send_with_retry(|| client.post(&url)).await.unwrap();
assert_eq!(resp.status(), 200);
}

#[tokio::test]
async fn send_with_retry_uses_body_retry_after_when_no_header() {
let server = MockServer::start().await;

// Three 429 responses without Retry-After header but with body field.
Mock::given(method("POST"))
.and(path("/channels/ch1/messages"))
.respond_with(
ResponseTemplate::new(429).set_body_json(serde_json::json!({"retry_after": 0.0})),
)
.up_to_n_times(3)
.mount(&server)
.await;

// Fourth → 200.
Mock::given(method("POST"))
.and(path("/channels/ch1/messages"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({"id": "3"})))
.mount(&server)
.await;

let client = reqwest::Client::new();
let url = format!("{}/channels/ch1/messages", server.uri());
let resp = send_with_retry(|| client.post(&url)).await.unwrap();
assert_eq!(resp.status(), 200);
}

#[tokio::test]
async fn send_with_retry_propagates_non_429_errors() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/channels/ch1/messages"))
.respond_with(ResponseTemplate::new(403))
.mount(&server)
.await;

let client = reqwest::Client::new();
let url = format!("{}/channels/ch1/messages", server.uri());
let result = send_with_retry(|| client.post(&url)).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.status(), Some(reqwest::StatusCode::FORBIDDEN));
}

#[tokio::test]
async fn send_with_retry_errors_when_retries_exhausted() {
let server = MockServer::start().await;

// Return 429 for all requests — exhausts MAX_RETRIES (3) then the final attempt.
Mock::given(method("POST"))
.and(path("/channels/ch1/messages"))
.respond_with(
ResponseTemplate::new(429)
.append_header("Retry-After", "0")
.set_body_json(serde_json::json!({"retry_after": 0.0})),
)
.mount(&server)
.await;

let client = reqwest::Client::new();
let url = format!("{}/channels/ch1/messages", server.uri());
let result = send_with_retry(|| client.post(&url)).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.status(), Some(reqwest::StatusCode::TOO_MANY_REQUESTS));
}

#[test]
fn rate_limit_body_defaults_to_none() {
let body: RateLimitBody = serde_json::from_str("{}").unwrap();
assert!(body.retry_after.is_none());
}

#[test]
fn rate_limit_body_parses_float() {
let body: RateLimitBody = serde_json::from_str(r#"{"retry_after": 1.5}"#).unwrap();
assert!((body.retry_after.unwrap() - 1.5).abs() < f64::EPSILON);
}

#[test]
fn max_retry_secs_clamps() {
let unclamped: f64 = 120.0;
assert_eq!(
unclamped.min(MAX_RETRY_SECS).to_bits(),
MAX_RETRY_SECS.to_bits()
);
}

#[test]
fn rest_client_debug_redacts_token() {
let rc = RestClient {
client: reqwest::Client::new(),
token: "secret-token".into(),
};
let debug = format!("{rc:?}");
assert!(!debug.contains("secret-token"));
assert!(debug.contains("REDACTED"));
}
}
Loading