Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: CI

on:
push:
pull_request:

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt, clippy

- uses: Swatinem/rust-cache@v2

- name: fmt
run: cargo fmt --all -- --check

- name: clippy
run: cargo clippy --all-targets --all-features -- -D warnings

- name: test
run: cargo test --all
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ blake3 = "1"
sled = "0.34"
bincode = "1"

[dev-dependencies]
tempfile = "3"

[profile.release]
lto = true
codegen-units = 1
110 changes: 101 additions & 9 deletions src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ pub async fn handle_cached(
http::HeaderValue::from_static("shopware=ESI/1.0"),
);

let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap_or_default();
let body_bytes = axum::body::to_bytes(body, usize::MAX)
.await
.unwrap_or_default();

let up = state
.client
Expand All @@ -74,14 +76,22 @@ pub async fn handle_cached(

let status = up.status();
let mut resp_headers = up.headers().clone();
let bytes = up.bytes().await.map_err(|e| format!("upstream body: {e}"))?;
let bytes = up
.bytes()
.await
.map_err(|e| format!("upstream body: {e}"))?;

// Decide TTL
let ttl = ttl_from_headers(&resp_headers).unwrap_or(Duration::from_secs(0));
let cacheable = ttl.as_secs() > 0 && (parts.method == http::Method::GET || parts.method == http::Method::HEAD);
let cacheable = ttl.as_secs() > 0
&& (parts.method == http::Method::GET || parts.method == http::Method::HEAD);

// VCL: sw-dynamic-cache-bypass => hit-for-miss 1s
if resp_headers.get("sw-dynamic-cache-bypass").and_then(|v| v.to_str().ok()) == Some("1") {
if resp_headers
.get("sw-dynamic-cache-bypass")
.and_then(|v| v.to_str().ok())
== Some("1")
{
resp_headers.remove("sw-dynamic-cache-bypass");
return Ok(build_response(status, resp_headers, bytes, &norm_uri));
}
Expand All @@ -99,7 +109,10 @@ pub async fn handle_cached(
fn build_cache_key(uri: &Uri, headers: &HeaderMap) -> String {
let mut key = uri.to_string();

let ctx = headers.get("sw-cache-hash").and_then(|v| v.to_str().ok()).unwrap_or("");
let ctx = headers
.get("sw-cache-hash")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let cur = extract_cookie(headers, "sw-currency").unwrap_or_default();

if !ctx.is_empty() {
Expand All @@ -126,7 +139,12 @@ fn extract_cookie(headers: &HeaderMap, name: &str) -> Option<String> {
None
}

fn lookup(cache: &Cache, key: &str, req_headers: &HeaderMap, uri: &Uri) -> Result<Option<axum::response::Response>, String> {
fn lookup(
cache: &Cache,
key: &str,
req_headers: &HeaderMap,
uri: &Uri,
) -> Result<Option<axum::response::Response>, String> {
let inner = cache.inner.read();
let Some((meta, body)) = inner.disk.get(key)? else {
return Ok(None);
Expand All @@ -142,7 +160,10 @@ fn lookup(cache: &Cache, key: &str, req_headers: &HeaderMap, uri: &Uri) -> Resul
}

// VCL hit logic: pass if client states matches invalidation states
if let (Some(req_states), Some(obj_states)) = (extract_cookie(req_headers, "sw-states"), meta.invalidation_states.as_deref()) {
if let (Some(req_states), Some(obj_states)) = (
extract_cookie(req_headers, "sw-states"),
meta.invalidation_states.as_deref(),
) {
if req_states.contains("logged-in") && obj_states.contains("logged-in") {
return Ok(None);
}
Expand Down Expand Up @@ -179,7 +200,11 @@ fn store(
let tags = headers
.get("xkey")
.and_then(|v| v.to_str().ok())
.map(|s| s.split_whitespace().map(|t| t.to_string()).collect::<Vec<_>>())
.map(|s| {
s.split_whitespace()
.map(|t| t.to_string())
.collect::<Vec<_>>()
})
.unwrap_or_default();

let invalidation_states = headers
Expand Down Expand Up @@ -215,7 +240,12 @@ fn ttl_from_headers(headers: &HeaderMap) -> Option<Duration> {
None
}

fn build_response(status: http::StatusCode, mut headers: HeaderMap, bytes: Bytes, uri: &Uri) -> axum::response::Response {
fn build_response(
status: http::StatusCode,
mut headers: HeaderMap,
bytes: Bytes,
uri: &Uri,
) -> axum::response::Response {
normalize::apply_client_cache_policy(uri, &mut headers);
normalize::strip_internal_headers(&mut headers);

Expand All @@ -226,3 +256,65 @@ fn build_response(status: http::StatusCode, mut headers: HeaderMap, bytes: Bytes
*resp.headers_mut() = headers;
resp
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn cache_key_varies_by_context_hash_when_present() {
let uri: Uri = "/foo?a=1".parse().unwrap();
let mut h1 = HeaderMap::new();
h1.insert("sw-cache-hash", http::HeaderValue::from_static("abc"));
h1.insert(
http::header::COOKIE,
http::HeaderValue::from_static("sw-currency=EUR"),
);

let mut h2 = HeaderMap::new();
h2.insert("sw-cache-hash", http::HeaderValue::from_static("def"));
h2.insert(
http::header::COOKIE,
http::HeaderValue::from_static("sw-currency=EUR"),
);

let k1 = build_cache_key(&uri, &h1);
let k2 = build_cache_key(&uri, &h2);
assert_ne!(k1, k2);
}

#[test]
fn cache_key_falls_back_to_currency_when_no_context_hash() {
let uri: Uri = "/foo?a=1".parse().unwrap();

let mut h1 = HeaderMap::new();
h1.insert(
http::header::COOKIE,
http::HeaderValue::from_static("sw-currency=EUR"),
);

let mut h2 = HeaderMap::new();
h2.insert(
http::header::COOKIE,
http::HeaderValue::from_static("sw-currency=USD"),
);

let k1 = build_cache_key(&uri, &h1);
let k2 = build_cache_key(&uri, &h2);
assert_ne!(k1, k2);
}

#[test]
fn extract_cookie_parses_simple_cookie_header() {
let mut headers = HeaderMap::new();
headers.insert(
http::header::COOKIE,
http::HeaderValue::from_static("a=1; sw-currency=EUR; b=2"),
);
assert_eq!(
extract_cookie(&headers, "sw-currency").as_deref(),
Some("EUR")
);
assert_eq!(extract_cookie(&headers, "missing"), None);
}
}
24 changes: 18 additions & 6 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,31 @@ pub struct Config {

impl Config {
pub fn from_env() -> Result<Self, String> {
let listen = std::env::var("CODYCACHE_LISTEN").unwrap_or_else(|_| "0.0.0.0:8080".to_string());
let origin = std::env::var("CODYCACHE_ORIGIN").map_err(|_| "CODYCACHE_ORIGIN is required".to_string())?;
let listen =
std::env::var("CODYCACHE_LISTEN").unwrap_or_else(|_| "0.0.0.0:8080".to_string());
let origin = std::env::var("CODYCACHE_ORIGIN")
.map_err(|_| "CODYCACHE_ORIGIN is required".to_string())?;

let cache_dir = std::env::var("CODYCACHE_CACHE_DIR").unwrap_or_else(|_| "./cache".to_string());
let cache_dir =
std::env::var("CODYCACHE_CACHE_DIR").unwrap_or_else(|_| "./cache".to_string());

let purgers_raw = std::env::var("CODYCACHE_PURGERS").unwrap_or_else(|_| "127.0.0.1/32,::1/128".to_string());
let purgers_raw = std::env::var("CODYCACHE_PURGERS")
.unwrap_or_else(|_| "127.0.0.1/32,::1/128".to_string());
let purgers = purgers_raw
.split(',')
.map(str::trim)
.filter(|s| !s.is_empty())
.map(|s| s.parse::<IpNet>().map_err(|e| format!("invalid CIDR/IP in CODYCACHE_PURGERS: {s}: {e}")))
.map(|s| {
s.parse::<IpNet>()
.map_err(|e| format!("invalid CIDR/IP in CODYCACHE_PURGERS: {s}: {e}"))
})
.collect::<Result<Vec<_>, _>>()?;

Ok(Self { listen, origin, purgers, cache_dir })
Ok(Self {
listen,
origin,
purgers,
cache_dir,
})
}
}
43 changes: 31 additions & 12 deletions src/disk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@ use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::{
fs,
io,
path::{Path, PathBuf},
time::{Duration, SystemTime, UNIX_EPOCH},
};

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct DiskStore {
root: PathBuf,
db: sled::Db,
Expand All @@ -33,7 +32,11 @@ impl DiskStore {
let root = root.as_ref().to_path_buf();
fs::create_dir_all(root.join("entries")).map_err(|e| format!("create cache dir: {e}"))?;
let db = sled::open(root.join("index")).map_err(|e| format!("open sled: {e}"))?;
Ok(Self { root, db, lock: Mutex::new(()) })
Ok(Self {
root,
db,
lock: Mutex::new(()),
})
}

fn entry_dir(&self, key: &str) -> PathBuf {
Expand All @@ -50,7 +53,8 @@ impl DiskStore {
}

let meta_bytes = fs::read(&meta_path).map_err(|e| format!("read meta: {e}"))?;
let meta: StoredMeta = serde_json::from_slice(&meta_bytes).map_err(|e| format!("parse meta: {e}"))?;
let meta: StoredMeta =
serde_json::from_slice(&meta_bytes).map_err(|e| format!("parse meta: {e}"))?;
let body = fs::read(&body_path).map_err(|e| format!("read body: {e}"))?;
Ok(Some((meta, Bytes::from(body))))
}
Expand All @@ -76,7 +80,9 @@ impl DiskStore {
.unwrap_or_default();
set.insert(key.to_string());
let enc = bincode::serialize(&set).map_err(|e| format!("bincode: {e}"))?;
self.db.insert(k.as_bytes(), enc).map_err(|e| format!("sled insert: {e}"))?;
self.db
.insert(k.as_bytes(), enc)
.map_err(|e| format!("sled insert: {e}"))?;
}

self.db.flush().map_err(|e| format!("sled flush: {e}"))?;
Expand All @@ -96,13 +102,19 @@ impl DiskStore {
for tag in meta.tags {
let k = format!("tag:{tag}");
if let Some(v) = self.db.get(&k).map_err(|e| format!("sled get: {e}"))? {
let mut set: std::collections::BTreeSet<String> = bincode::deserialize(&v).unwrap_or_default();
let mut set: std::collections::BTreeSet<String> =
bincode::deserialize(&v).unwrap_or_default();
set.remove(key);
if set.is_empty() {
self.db.remove(k.as_bytes()).map_err(|e| format!("sled remove: {e}"))?;
self.db
.remove(k.as_bytes())
.map_err(|e| format!("sled remove: {e}"))?;
} else {
let enc = bincode::serialize(&set).map_err(|e| format!("bincode: {e}"))?;
self.db.insert(k.as_bytes(), enc).map_err(|e| format!("sled insert: {e}"))?;
let enc =
bincode::serialize(&set).map_err(|e| format!("bincode: {e}"))?;
self.db
.insert(k.as_bytes(), enc)
.map_err(|e| format!("sled insert: {e}"))?;
}
}
}
Expand All @@ -121,7 +133,8 @@ impl DiskStore {
for tag in tags {
let k = format!("tag:{tag}");
if let Some(v) = self.db.get(&k).map_err(|e| format!("sled get: {e}"))? {
let set: std::collections::BTreeSet<String> = bincode::deserialize(&v).unwrap_or_default();
let set: std::collections::BTreeSet<String> =
bincode::deserialize(&v).unwrap_or_default();
keys.extend(set);
}
}
Expand All @@ -147,15 +160,21 @@ pub fn headers_to_pairs(headers: &HeaderMap) -> Vec<(String, String)> {
pub fn pairs_to_headers(pairs: &[(String, String)]) -> HeaderMap {
let mut out = HeaderMap::new();
for (k, v) in pairs {
if let (Ok(name), Ok(val)) = (http::header::HeaderName::from_bytes(k.as_bytes()), http::HeaderValue::from_str(v)) {
if let (Ok(name), Ok(val)) = (
http::header::HeaderName::from_bytes(k.as_bytes()),
http::HeaderValue::from_str(v),
) {
out.insert(name, val);
}
}
out
}

pub fn now_ms() -> u64 {
SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or(Duration::from_secs(0)).as_millis() as u64
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_millis() as u64
}

fn remove_dir_all_best_effort(path: &Path) {
Expand Down
Loading