diff --git a/.env.example b/.env.example index 8829a0c..17067dd 100644 --- a/.env.example +++ b/.env.example @@ -12,4 +12,9 @@ WEBHOOK_PINGS="<@&role_id> <@user_id>" ADMIN_PASSWORD=supersecret # used for kept video access -PUBLIC_URL=https://vertd.your-domain.here \ No newline at end of file +PUBLIC_URL=https://vertd.your-domain.here + +# CORS origins setup +# Can either be "*" or comma separated origins: "https://origin1.com,https://origin2.com" +# If not defined, fall back automatically to * +CORS_ORIGINS=* \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 82f2067..871414c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,6 +8,7 @@ services: - WEBHOOK_PINGS=${WEBHOOK_PINGS} - ADMIN_PASSWORD=${ADMIN_PASSWORD} - PUBLIC_URL=${PUBLIC_URL} + - CORS_ORIGINS=${CORS_ORIGINS:-*} ports: - "${PORT:-24153}:24153" diff --git a/src/http/mod.rs b/src/http/mod.rs index 86928c5..1b27d58 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -8,16 +8,68 @@ use crate::http::services::keep::keep; mod response; mod services; +#[derive(Clone)] +enum CorsConfig { + Any, + Specific(Vec), +} + +fn parse_cors(origins_raw: &str) -> CorsConfig { + let raw = origins_raw.trim(); + + if raw.is_empty() || raw == "*" { + return CorsConfig::Any; + } + + let origins = raw + .split(',') + .map(str::trim) + .filter(|s| !s.is_empty()) + .map(String::from) + .collect::>(); + + CorsConfig::Specific(origins) +} + +fn build_cors(config: &CorsConfig) -> Cors { + match config { + CorsConfig::Any => Cors::default() + .allow_any_origin() + .allow_any_method() + .allow_any_header(), + + CorsConfig::Specific(origins) => { + let mut cors = Cors::default().allow_any_method().allow_any_header(); + + for origin in origins { + cors = cors.allowed_origin(origin); + } + + cors + } + } +} + pub async fn start_http() -> anyhow::Result<()> { - let server = HttpServer::new(|| { - App::new() - .wrap( - Cors::default() - .allow_any_origin() - .allow_any_method() - .allow_any_header(), - ) - .service( + let cors_origins = std::env::var("CORS_ORIGINS").unwrap_or_else(|_| "*".to_string()); + let cors_config = parse_cors(&cors_origins); + + match &cors_config { + CorsConfig::Any => info!("CORS: allow any origin (*)"), + CorsConfig::Specific(origins) => { + info!("CORS: allowed origins:"); + for origin in origins { + info!(" - {}", origin); + } + } + } + + let server = HttpServer::new({ + let cors_config = cors_config.clone(); + move || { + let cors = build_cors(&cors_config); + + App::new().wrap(cors).service( web::scope("/api") .service(upload) .service(download) @@ -25,7 +77,9 @@ pub async fn start_http() -> anyhow::Result<()> { .service(version) .service(keep), ) + } }); + let port = std::env::var("PORT").unwrap_or_else(|_| "24153".to_string()); if !port.chars().all(char::is_numeric) { anyhow::bail!("PORT must be a number");