Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 96 additions & 33 deletions src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,45 +67,73 @@ pub async fn handle_cached(
req: Request<axum::body::Body>,
) -> Result<axum::response::Response, String> {
let (parts, body) = req.into_parts();
// Buffering for now (later: streaming). For GET/HEAD this should be empty.
let body_bytes = axum::body::to_bytes(body, usize::MAX)
.await
.unwrap_or_default();

let norm_uri = normalize::normalize_uri(&parts.uri);

let cache_key = build_cache_key(&norm_uri, &parts.headers);

// BAN logic: if URL is banned, skip cache.
if state.cache.is_banned(&norm_uri) {
// fall through to upstream
} else if let Some(resp) = lookup(&state.cache, &cache_key, &parts.headers, &norm_uri)? {
return Ok(resp);
}

// miss: fetch upstream
let upstream_url = normalize::build_upstream_url(&state.cfg.origin, &norm_uri);
let banned = state.cache.is_banned(&norm_uri);

let mut headers = parts.headers.clone();
headers.insert(
http::header::HeaderName::from_static("surrogate-capability"),
http::HeaderValue::from_static("shopware=ESI/1.0"),
);

let body_bytes = axum::body::to_bytes(body, usize::MAX)
.await
.unwrap_or_default();
if !banned {
if let Some((resp, stale)) = lookup(&state.cache, &cache_key, &parts.headers, &norm_uri)? {
if !stale {
return Ok(resp);
}

let up = state
.client
.request(parts.method.clone(), upstream_url)
.headers(headers)
.body(body_bytes)
.send()
.await
.map_err(|e| format!("upstream: {e}"))?;
// Stale within grace: try to refresh from origin; serve stale only if origin fails.
match fetch_upstream_raw(&state, &parts, &norm_uri, body_bytes.clone()).await {
Ok((status, mut resp_headers, bytes)) => {
// Origin errors => serve stale
if status.is_server_error() {
let mut resp = resp;
resp.headers_mut().insert(
http::header::HeaderName::from_static("x-codycache"),
http::HeaderValue::from_static("STALE"),
);
return Ok(resp);
}

let ttl = ttl_from_headers(&resp_headers).unwrap_or(Duration::from_secs(0));
let cacheable = ttl.as_secs() > 0
&& (parts.method == http::Method::GET
|| parts.method == http::Method::HEAD);

if cacheable {
resp_headers.remove(http::header::SET_COOKIE);
store(
&state.cache,
&cache_key,
&norm_uri,
status,
&resp_headers,
&bytes,
ttl,
)?;
}

return Ok(build_response(status, resp_headers, bytes, &norm_uri));
}
Err(_e) => {
// Serve stale
let mut resp = resp;
resp.headers_mut().insert(
http::header::HeaderName::from_static("x-codycache"),
http::HeaderValue::from_static("STALE"),
);
return Ok(resp);
}
}
}
}

let status = up.status();
let mut resp_headers = up.headers().clone();
let bytes = up
.bytes()
.await
.map_err(|e| format!("upstream body: {e}"))?;
// miss: fetch upstream
let (status, mut resp_headers, bytes) =
fetch_upstream_raw(&state, &parts, &norm_uri, body_bytes.clone()).await?;

// Decide TTL
let ttl = ttl_from_headers(&resp_headers).unwrap_or(Duration::from_secs(0));
Expand Down Expand Up @@ -178,7 +206,7 @@ fn lookup(
key: &str,
req_headers: &HeaderMap,
uri: &Uri,
) -> Result<Option<axum::response::Response>, String> {
) -> Result<Option<(axum::response::Response, bool)>, String> {
let inner = cache.inner.read();
let Some((meta, body)) = inner.disk.get(key)? else {
return Ok(None);
Expand All @@ -193,6 +221,8 @@ fn lookup(
return Ok(None);
}

let stale = !fresh;

// VCL hit logic: pass if client states matches invalidation states
if let (Some(req_states), Some(obj_states)) = (
extract_cookie(req_headers, "sw-states"),
Expand All @@ -218,7 +248,40 @@ fn lookup(
let mut resp = resp;
*resp.headers_mut() = headers;

Ok(Some(resp))
Ok(Some((resp, stale)))
}

async fn fetch_upstream_raw(
state: &AppState,
parts: &http::request::Parts,
uri: &Uri,
body_bytes: bytes::Bytes,
) -> Result<(http::StatusCode, HeaderMap, Bytes), String> {
let upstream_url = normalize::build_upstream_url(&state.cfg.origin, uri);

let mut headers = parts.headers.clone();
headers.insert(
http::header::HeaderName::from_static("surrogate-capability"),
http::HeaderValue::from_static("shopware=ESI/1.0"),
);

let up = state
.client
.request(parts.method.clone(), upstream_url)
.headers(headers)
.body(body_bytes)
.send()
.await
.map_err(|e| format!("upstream: {e}"))?;

let status = up.status();
let resp_headers = up.headers().clone();
let bytes = up
.bytes()
.await
.map_err(|e| format!("upstream body: {e}"))?;

Ok((status, resp_headers, bytes))
}

fn store(
Expand Down