diff --git a/README.md b/README.md index d9770b7..1f30089 100644 --- a/README.md +++ b/README.md @@ -12,13 +12,11 @@ The Firebase Admin Rust SDK enables access to Firebase services from privileged use rs_firebase_admin_sdk::{ auth::{FirebaseAuthService, UserIdentifiers}, client::ApiHttpClient, - App, credentials_provider, + App, }; -// Load your GCP SA from env, see https://crates.io/crates/gcp_auth for more details -let gcp_service_account = credentials_provider().await.unwrap(); // Create live (not emulated) context for Firebase app -let live_app = App::live(gcp_service_account.into()).await.unwrap(); +let live_app = App::live().await.unwrap(); // Create Firebase authentication admin client let auth_admin = live_app.auth(); diff --git a/examples/clear_emulator/src/main.rs b/examples/clear_emulator/src/main.rs index 35ac3b9..074176f 100644 --- a/examples/clear_emulator/src/main.rs +++ b/examples/clear_emulator/src/main.rs @@ -12,7 +12,7 @@ where #[tokio::main] async fn main() { - let emulator_app = App::emulated("my_project".into()); + let emulator_app = App::emulated(); let emulator_admin = emulator_app.auth("http://localhost:9099".into()); clear_emulator(&emulator_admin).await; diff --git a/examples/get_users/src/main.rs b/examples/get_users/src/main.rs index 3388820..09c41ad 100644 --- a/examples/get_users/src/main.rs +++ b/examples/get_users/src/main.rs @@ -2,7 +2,6 @@ use rs_firebase_admin_sdk::{ App, auth::{FirebaseAuthService, UserList}, client::ApiHttpClient, - credentials_provider, }; /// Generic method to print out all live users, fetch 10 at a time @@ -27,18 +26,13 @@ where #[tokio::main] async fn main() { - // Live Firebase App - let gcp_service_account = credentials_provider().await.unwrap(); - - let live_app = App::live(gcp_service_account).await.unwrap(); - + let live_app = App::live().await.unwrap(); let live_auth_admin = live_app.auth(); print_all_users(&live_auth_admin).await; // Emulator Firebase App - let emulator_app = App::emulated("my_project".into()); - + let emulator_app = App::emulated(); let emulator_auth_admin = emulator_app.auth("http://localhost:9099".into()); print_all_users(&emulator_auth_admin).await; diff --git a/examples/verify_token/src/main.rs b/examples/verify_token/src/main.rs index d9f0386..1045a7c 100644 --- a/examples/verify_token/src/main.rs +++ b/examples/verify_token/src/main.rs @@ -1,29 +1,27 @@ -use rs_firebase_admin_sdk::{App, auth::token::TokenVerifier, credentials_provider}; +use rs_firebase_admin_sdk::{App, jwt::TokenValidator}; -async fn verify_token(token: &str, verifier: &T) { - match verifier.verify_token(token).await { +async fn verify_token(token: &str, validator: &T) { + match validator.validate(token).await { Ok(token) => { - let user_id = token.critical_claims.sub; + let user_id = token.get("sub").unwrap().as_str().unwrap(); println!("Token for user {user_id} is valid!") } Err(err) => { - println!("Token is invalid because {err}!") + println!("Token is invalid because {err:?}!") } } } #[tokio::main] async fn main() { + // Live let oidc_token = std::env::var("ID_TOKEN").unwrap(); + let live_app = App::live().await.unwrap(); + let live_token_validator = live_app.id_token_verifier().await.unwrap(); + verify_token(&oidc_token, &live_token_validator).await; - // Live Firebase App - let gcp_service_account = credentials_provider().await.unwrap(); - let live_app = App::live(gcp_service_account).await.unwrap(); - let live_token_verifier = live_app.id_token_verifier().await.unwrap(); - verify_token(&oidc_token, &live_token_verifier).await; - - // Emulator Firebase App - let emulator_app = App::emulated("my_project".into()); - let emulator_token_verifier = emulator_app.id_token_verifier(); - verify_token(&oidc_token, &emulator_token_verifier).await; + // Emulator + let emulator_app = App::emulated(); + let emulator_token_validator = emulator_app.id_token_verifier(); + verify_token(&oidc_token, &emulator_token_validator).await; } diff --git a/lib/Cargo.toml b/lib/Cargo.toml index d0323b3..97196c7 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rs-firebase-admin-sdk" -version = "3.0.0" +version = "4.0.0" rust-version = "1.85" edition = "2024" authors = ["Kostas Petrikas"] @@ -16,7 +16,7 @@ doctest = false [features] default = ["tokens", "reqwest/default-tls"] rustls-tls = ["reqwest/rustls-tls"] -tokens = ["dep:openssl"] +tokens = ["dep:jsonwebtoken", "dep:jsonwebtoken-jwks-cache"] [dependencies] tokio = { version = "1.48", features = ["sync"], default-features = false } @@ -29,10 +29,11 @@ headers = "0.4" reqwest = { version = "0.12", features = ["charset", "json"], default-features = false } urlencoding = "2.1" bytes = "1" -gcp_auth = "0.12" +google-cloud-auth = "1.3" time = { version = "0.3", features = ["serde"] } base64 = "0.22" -openssl = { version = "0.10", optional = true } +jsonwebtoken = { version = "10.2", optional = true } +jsonwebtoken-jwks-cache = { version = "0.1", optional = true } [dev-dependencies] tokio = { version = "1.48", features = ["macros", "rt-multi-thread"] } diff --git a/lib/src/auth/mod.rs b/lib/src/auth/mod.rs index d92e606..ab7c148 100644 --- a/lib/src/auth/mod.rs +++ b/lib/src/auth/mod.rs @@ -7,9 +7,6 @@ pub mod claims; pub mod import; pub mod oob_code; -#[cfg(feature = "tokens")] -pub mod token; - use crate::api_uri::{ApiUriBuilder, FirebaseAuthEmulatorRestApi, FirebaseAuthRestApi}; use crate::client::ApiHttpClient; use crate::client::error::ApiClientError; @@ -27,11 +24,6 @@ use time::{Duration, OffsetDateTime}; const FIREBASE_AUTH_REST_AUTHORITY: &str = "identitytoolkit.googleapis.com"; -const FIREBASE_AUTH_SCOPES: [&str; 2] = [ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/userinfo.email", -]; - #[derive(Serialize, Debug, Clone, Default)] #[serde(rename_all = "camelCase")] pub struct NewUser { @@ -363,7 +355,7 @@ pub trait FirebaseAuthService: Send + Sync + 'static { .get_auth_uri_builder() .build(FirebaseAuthRestApi::CreateUser); - client.send_request_body(uri, Method::POST, user, &FIREBASE_AUTH_SCOPES) + client.send_request_body(uri, Method::POST, user) } /// Get first user that matches given identifier filter @@ -413,7 +405,6 @@ pub trait FirebaseAuthService: Send + Sync + 'static { uri_builder.build(FirebaseAuthRestApi::GetUsers), Method::POST, indentifiers, - &FIREBASE_AUTH_SCOPES, ) .await?; @@ -460,7 +451,6 @@ pub trait FirebaseAuthService: Send + Sync + 'static { uri_builder.build(FirebaseAuthRestApi::ListUsers), params.into_iter(), Method::GET, - &FIREBASE_AUTH_SCOPES, ) .await?; @@ -482,7 +472,6 @@ pub trait FirebaseAuthService: Send + Sync + 'static { uri_builder.build(FirebaseAuthRestApi::DeleteUser), Method::POST, UserId { uid }, - &FIREBASE_AUTH_SCOPES, ) .await } @@ -503,7 +492,6 @@ pub trait FirebaseAuthService: Send + Sync + 'static { uri_builder.build(FirebaseAuthRestApi::DeleteUsers), Method::POST, UserIds { uids, force }, - &FIREBASE_AUTH_SCOPES, ) .await } @@ -532,7 +520,6 @@ pub trait FirebaseAuthService: Send + Sync + 'static { uri_builder.build(FirebaseAuthRestApi::UpdateUser), Method::POST, update, - &FIREBASE_AUTH_SCOPES, ) .await } @@ -562,7 +549,6 @@ pub trait FirebaseAuthService: Send + Sync + 'static { uri_builder.build(FirebaseAuthRestApi::ImportUsers), Method::POST, UserImportRecords { users }, - &FIREBASE_AUTH_SCOPES, ) .await?; @@ -593,7 +579,6 @@ pub trait FirebaseAuthService: Send + Sync + 'static { uri_builder.build(FirebaseAuthRestApi::SendOobCode), Method::POST, oob_action, - &FIREBASE_AUTH_SCOPES, ) .await?; @@ -622,7 +607,6 @@ pub trait FirebaseAuthService: Send + Sync + 'static { uri_builder.build(FirebaseAuthRestApi::CreateSessionCookie), Method::POST, create_cookie, - &FIREBASE_AUTH_SCOPES, ) .await?; @@ -689,7 +673,6 @@ where .send_request( uri_builder.build(FirebaseAuthEmulatorRestApi::ClearUserAccounts), Method::DELETE, - &FIREBASE_AUTH_SCOPES, ) .await?; @@ -709,7 +692,6 @@ where .send_request( uri_builder.build(FirebaseAuthEmulatorRestApi::Configuration), Method::GET, - &FIREBASE_AUTH_SCOPES, ) .await } @@ -729,7 +711,6 @@ where uri_builder.build(FirebaseAuthEmulatorRestApi::Configuration), Method::PATCH, configuration, - &FIREBASE_AUTH_SCOPES, ) .await } @@ -747,7 +728,6 @@ where .send_request( uri_builder.build(FirebaseAuthEmulatorRestApi::OobCodes), Method::GET, - &FIREBASE_AUTH_SCOPES, ) .await?; @@ -767,7 +747,6 @@ where .send_request( uri_builder.build(FirebaseAuthEmulatorRestApi::SmsVerificationCodes), Method::GET, - &FIREBASE_AUTH_SCOPES, ) .await } diff --git a/lib/src/auth/test.rs b/lib/src/auth/test.rs index b25fcbb..4ff13bb 100644 --- a/lib/src/auth/test.rs +++ b/lib/src/auth/test.rs @@ -1,6 +1,6 @@ use super::import::{PasswordHash, UserImportRecord}; #[cfg(feature = "tokens")] -use super::token::jwt::JWToken; +// use super::token::jwt::JWToken; use super::{ AttributeOp, Claims, FirebaseAuth, FirebaseAuthService, FirebaseEmulatorAuthService, NewUser, OobCode, OobCodeAction, OobCodeActionType, UserIdentifiers, UserList, UserUpdate, @@ -18,7 +18,7 @@ use time::Duration; use tokio; fn get_auth_service() -> FirebaseAuth> { - App::emulated("demo-firebase-project".into()).auth("http://emulator:9099".parse().unwrap()) + App::emulated().auth("http://emulator:9099".parse().unwrap()) } #[derive(Serialize)] @@ -529,6 +529,7 @@ async fn test_generate_email_action_link() { #[tokio::test] #[serial] async fn test_create_session_cookie() { + use crate::jwt::{EmulatorValidator, TokenValidator}; let auth = get_auth_service(); auth.create_user(NewUser::email_and_password( @@ -544,7 +545,9 @@ async fn test_create_session_cookie() { .await .unwrap(); - JWToken::from_encoded(&cookie).expect("Got invalid session cookie token"); + let claims = EmulatorValidator.validate(&cookie).await.unwrap(); + let email = claims.get("email").unwrap().as_str().unwrap(); + assert_eq!(email, "test@example.com"); auth.clear_all_users().await.unwrap(); } diff --git a/lib/src/auth/token/cache/error.rs b/lib/src/auth/token/cache/error.rs deleted file mode 100644 index 1464268..0000000 --- a/lib/src/auth/token/cache/error.rs +++ /dev/null @@ -1,16 +0,0 @@ -use reqwest::StatusCode; -use thiserror::Error; - -#[derive(Error, Debug, Clone)] -#[error("Failed while caching resource")] -pub struct CacheError; - -#[derive(Error, Debug, Clone)] -pub enum ClientError { - #[error("Failed to fetch HTTP resource")] - FailedToFetch, - #[error("Unexpected HTTP status code {0}")] - BadHttpResponse(StatusCode), - #[error("Failed to deserialize resource")] - FailedToDeserialize, -} diff --git a/lib/src/auth/token/cache/mod.rs b/lib/src/auth/token/cache/mod.rs deleted file mode 100644 index 5aaa1e2..0000000 --- a/lib/src/auth/token/cache/mod.rs +++ /dev/null @@ -1,176 +0,0 @@ -//! Public key caching for use in efficient token verification - -#[cfg(test)] -mod test; - -pub mod error; - -use super::JwtRsaPubKey; -use bytes::Bytes; -use error::{CacheError, ClientError}; -use error_stack::{Report, ResultExt}; -use headers::{CacheControl, HeaderMapExt}; -use reqwest::Client; -use serde::de::DeserializeOwned; -use serde_json::from_slice; -use std::collections::BTreeMap; -use std::future::Future; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; -use tokio::sync::{Mutex, RwLock}; - -#[derive(Clone, Debug)] -struct Cache { - expires_at: SystemTime, - content: ContentT, -} - -impl Cache { - pub fn new(max_age: Duration, content: ContentT) -> Self { - Self { - expires_at: SystemTime::now() + max_age, - content, - } - } - - pub fn is_expired(&self) -> bool { - self.expires_at <= SystemTime::now() - } - - pub fn update(&mut self, max_age: Duration, content: ContentT) { - self.expires_at = SystemTime::now() + max_age; - self.content = content; - } -} - -#[derive(Clone, Debug)] -pub struct Resource { - pub data: Bytes, - pub max_age: Duration, -} - -pub trait CacheClient: Sized + Send + Sync -where - Self::Error: std::error::Error + Send + Sync + 'static, -{ - type Error; - - /// Simple async interface to fetch data and its TTL for an URI - fn fetch( - &self, - uri: &str, - ) -> impl Future>> + Send; -} - -impl CacheClient for Client { - type Error = ClientError; - - async fn fetch(&self, uri: &str) -> Result> { - let response = self - .get(uri) - .send() - .await - .change_context(ClientError::FailedToFetch)?; - - let status = response.status(); - - if !status.is_success() { - return Err(Report::new(ClientError::BadHttpResponse(status))); - } - - let cache_header: Option = response.headers().typed_get(); - let body = response - .bytes() - .await - .change_context(ClientError::FailedToFetch)?; - - if let Some(cache_header) = cache_header { - let ttl = cache_header - .s_max_age() - .unwrap_or_else(|| cache_header.max_age().unwrap_or_default()); - - return Ok(Resource { - data: body, - max_age: ttl, - }); - } - - Ok(Resource { - data: body, - max_age: Duration::default(), - }) - } -} - -pub struct HttpCache { - client: CacheClientT, - path: String, - cache: Arc>>, - refresh: Mutex<()>, -} - -impl HttpCache -where - CacheClientT: CacheClient, - ContentT: DeserializeOwned + Clone + Send + Sync, -{ - pub async fn new(client: CacheClientT, path: String) -> Result> { - let resource = client.fetch(&path).await.change_context(CacheError)?; - - let initial_cache: Cache = Cache::new( - resource.max_age, - from_slice(&resource.data).change_context(CacheError)?, - ); - - Ok(Self { - client, - path, - cache: Arc::new(RwLock::new(initial_cache)), - refresh: Mutex::new(()), - }) - } - - pub async fn get(&self) -> Result> { - let cache = self.cache.read().await.clone(); - if cache.is_expired() { - // to make sure only a single connection is being established to refresh the resource - let _refresh_guard = self.refresh.lock().await; - - // check if the cache has been refreshed by another co-routine - let cache = self.cache.read().await.clone(); - if !cache.is_expired() { - return Ok(cache.content); - } - - // refresh resource - let resource = self - .client - .fetch(&self.path) - .await - .change_context(CacheError)?; - - let content: ContentT = from_slice(&resource.data).change_context(CacheError)?; - - self.cache - .write() - .await - .update(resource.max_age, content.clone()); - - return Ok(content); - } - - Ok(cache.content) - } -} - -pub type PubKeys = BTreeMap; - -pub trait KeyCache { - fn get_keys(&self) -> impl Future>> + Send; -} - -impl KeyCache for HttpCache { - fn get_keys(&self) -> impl Future>> + Send { - self.get() - } -} diff --git a/lib/src/auth/token/cache/test.rs b/lib/src/auth/token/cache/test.rs deleted file mode 100644 index 50c4957..0000000 --- a/lib/src/auth/token/cache/test.rs +++ /dev/null @@ -1,75 +0,0 @@ -use super::{CacheClient, CacheError, HttpCache, Resource}; -use bytes::Bytes; -use error_stack::Report; -use serde_json::to_string; -use std::sync::Arc; -use std::time::Duration; -use tokio::sync::Mutex; - -struct CacheClientMock { - pub calls: Arc>, - response: Resource, -} - -impl CacheClientMock { - pub fn new(response: Resource) -> Self { - Self { - calls: Arc::new(Mutex::new(0)), - response, - } - } -} - -impl CacheClient for CacheClientMock { - type Error = CacheError; - - async fn fetch(&self, _uri: &str) -> Result> { - *self.calls.lock().await += 1; - - Ok(self.response.clone()) - } -} - -#[tokio::test] -async fn test_http_cache() { - let json = Bytes::copy_from_slice(to_string(&123).unwrap().as_bytes()); - let response = Resource { - data: json, - max_age: Duration::from_secs(999), - }; - let client = CacheClientMock::new(response); - let calls = client.calls.clone(); - - let http_cache = HttpCache::new(client, "http://localhost".parse().unwrap()) - .await - .unwrap(); - - let _: i32 = http_cache.get().await.unwrap(); - let _: i32 = http_cache.get().await.unwrap(); - let cached: i32 = http_cache.get().await.unwrap(); - - assert_eq!(cached, 123); - assert_eq!(*calls.lock().await, 1); -} - -#[tokio::test] -async fn test_http_cache_zero_ttl() { - let json = Bytes::copy_from_slice(to_string(&123).unwrap().as_bytes()); - let response = Resource { - data: json, - max_age: Duration::from_secs(0), - }; - let client = CacheClientMock::new(response); - let calls = client.calls.clone(); - - let http_cache = HttpCache::new(client, "http://localhost".parse().unwrap()) - .await - .unwrap(); - - let _: i32 = http_cache.get().await.unwrap(); - let _: i32 = http_cache.get().await.unwrap(); - let cached: i32 = http_cache.get().await.unwrap(); - - assert_eq!(cached, 123); - assert_eq!(*calls.lock().await, 4); -} diff --git a/lib/src/auth/token/crypto/mod.rs b/lib/src/auth/token/crypto/mod.rs deleted file mode 100644 index 0be6677..0000000 --- a/lib/src/auth/token/crypto/mod.rs +++ /dev/null @@ -1,110 +0,0 @@ -use super::jwt::{JwtSigner, error::JWTError}; -use base64::{self, Engine}; -use error_stack::{Report, ResultExt}; -use openssl::{ - asn1::Asn1Time, - bn::{BigNum, MsbOption}, - error::ErrorStack, - hash::MessageDigest, - pkey::{PKey, Private, Public}, - rsa::Rsa, - sign::{Signer, Verifier}, - x509::{X509, X509Name}, -}; -use serde::de::{self, Visitor}; -use std::fmt; - -impl JwtSigner for Signer<'_> { - fn sign_jwt(&mut self, header: &str, payload: &str) -> Result> { - self.update(header.as_bytes()) - .change_context(JWTError::FailedToEncode)?; - self.update(b".").change_context(JWTError::FailedToEncode)?; - self.update(payload.as_bytes()) - .change_context(JWTError::FailedToEncode)?; - - let signature = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode( - self.sign_to_vec() - .change_context(JWTError::FailedToEncode)?, - ); - - Ok(signature) - } -} - -#[derive(Debug, Clone)] -pub struct JwtRsaPubKey { - key: PKey, -} - -impl JwtRsaPubKey { - pub fn new(key: PKey) -> Self { - Self { key } - } - - pub fn verify(&self, payload: &[u8], signature: &[u8]) -> Result> { - let mut verifier = Verifier::new(MessageDigest::sha256(), &self.key)?; - verifier.update(payload)?; - - verifier.verify(signature).map_err(error_stack::Report::new) - } -} - -struct JwtRsaPubKeyVisitor; - -impl Visitor<'_> for JwtRsaPubKeyVisitor { - type Value = JwtRsaPubKey; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string with public key in PEM format.") - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - let cert = X509::from_pem(value.as_bytes()).map_err(|e| E::custom(format!("{e:?}")))?; - let key = cert.public_key().map_err(|e| E::custom(format!("{e:?}")))?; - - Ok(JwtRsaPubKey { key }) - } -} - -impl<'de> de::Deserialize<'de> for JwtRsaPubKey { - fn deserialize(deserializer: D) -> Result - where - D: de::Deserializer<'de>, - { - deserializer.deserialize_str(JwtRsaPubKeyVisitor) - } -} - -/// Utility method for generating x.509 certificate for testing purposes -pub fn generate_test_cert() -> Result<(X509, PKey), Report> { - let rsa = Rsa::generate(2048)?; - let key_pair = PKey::from_rsa(rsa)?; - - let mut name_builder = X509Name::builder()?; - name_builder.append_entry_by_text("C", "JP")?; - name_builder.append_entry_by_text("O", "Firebase")?; - name_builder.append_entry_by_text("CN", "Firebase test")?; - let cert_name = name_builder.build(); - - let serial_number = { - let mut serial = BigNum::new()?; - serial.rand(159, MsbOption::MAYBE_ZERO, false)?; - serial.to_asn1_integer()? - }; - - let mut cert_builder = X509::builder()?; - cert_builder.set_version(1)?; - cert_builder.set_serial_number(&serial_number)?; - cert_builder.set_not_after(Asn1Time::days_from_now(1)?.as_ref())?; - cert_builder.set_not_before(Asn1Time::days_from_now(0)?.as_ref())?; - cert_builder.set_subject_name(&cert_name)?; - cert_builder.set_issuer_name(&cert_name)?; - cert_builder.set_pubkey(&key_pair)?; - cert_builder.sign(&key_pair, MessageDigest::sha256())?; - let cert = cert_builder.build(); - - Ok((cert, key_pair)) -} diff --git a/lib/src/auth/token/error.rs b/lib/src/auth/token/error.rs deleted file mode 100644 index f8e4775..0000000 --- a/lib/src/auth/token/error.rs +++ /dev/null @@ -1,25 +0,0 @@ -use thiserror::Error; - -#[derive(Error, Debug, Clone)] -pub enum TokenVerificationError { - #[error("Error happened while parsing the token")] - FailedParsing, - #[error("Error happened while fetching public keys")] - FailedGettingKeys, - #[error("Invalid key for token's signature")] - InvalidSignatureKey, - #[error("Invalid token's signature")] - InvalidSignature, - #[error("Invalid token's signature algorithm")] - InvalidSignatureAlgorithm, - #[error("Token is expired")] - Expired, - #[error("Token was issued in the future")] - IssuedInFuture, - #[error("Token has invalid audience")] - InvalidAudience, - #[error("Token has invalid issuer")] - InvalidIssuer, - #[error("Token has empty subject")] - MissingSubject, -} diff --git a/lib/src/auth/token/jwt/error.rs b/lib/src/auth/token/jwt/error.rs deleted file mode 100644 index 55dc1c4..0000000 --- a/lib/src/auth/token/jwt/error.rs +++ /dev/null @@ -1,15 +0,0 @@ -use thiserror::Error; - -#[derive(Error, Debug, Clone)] -pub enum JWTError { - #[error("Failed to parse token")] - FailedToParse, - #[error("Failed to encode token")] - FailedToEncode, - #[error("Token is missing header")] - MissingHeader, - #[error("Token is missing payload")] - MissingPayload, - #[error("Token is missing signature")] - MissingSignature, -} diff --git a/lib/src/auth/token/jwt/mod.rs b/lib/src/auth/token/jwt/mod.rs deleted file mode 100644 index 1864d7b..0000000 --- a/lib/src/auth/token/jwt/mod.rs +++ /dev/null @@ -1,118 +0,0 @@ -#[cfg(test)] -mod test; - -pub mod error; -pub mod util; - -use base64::{self, Engine}; -use error::JWTError; -use error_stack::{Report, ResultExt}; -use serde::{Deserialize, Serialize}; -use serde_json::{Value, from_slice, to_string}; -use std::collections::BTreeMap; -use time::{OffsetDateTime, serde::timestamp}; - -#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)] -pub enum JWTAlgorithm { - #[serde(rename = "none")] - NONE, - HS256, - HS384, - HS512, - RS256, - RS384, - RS512, - ES256, - ES384, - ES512, -} - -#[derive(Debug, Deserialize, Serialize, Clone)] -pub struct TokenHeader { - pub alg: JWTAlgorithm, - pub kid: Option, -} - -#[derive(Debug, Deserialize, Serialize, Clone)] -pub struct TokenClaims { - #[serde(with = "timestamp")] - pub exp: OffsetDateTime, - #[serde(with = "timestamp")] - pub iat: OffsetDateTime, - pub aud: String, - pub iss: String, - pub sub: String, - #[serde(with = "timestamp")] - pub auth_time: OffsetDateTime, -} - -#[derive(Debug, Clone)] -pub struct JWToken { - pub header: TokenHeader, - pub critical_claims: TokenClaims, - pub all_claims: BTreeMap, - pub payload: String, - pub signature: Vec, -} - -impl JWToken { - pub fn from_encoded(encoded: &str) -> Result> { - let mut parts = encoded.split('.'); - - let header_slice = parts.next().ok_or(Report::new(JWTError::MissingHeader))?; - - let header: TokenHeader = from_slice( - &base64::engine::general_purpose::URL_SAFE_NO_PAD - .decode(header_slice) - .change_context(JWTError::FailedToParse)?, - ) - .change_context(JWTError::FailedToParse)?; - - let claims_slice = parts.next().ok_or(Report::new(JWTError::MissingHeader))?; - let claims = base64::engine::general_purpose::URL_SAFE_NO_PAD - .decode(claims_slice) - .change_context(JWTError::FailedToParse)?; - - let critical_claims: TokenClaims = - from_slice(&claims).change_context(JWTError::FailedToParse)?; - let all_claims: BTreeMap = - from_slice(&claims).change_context(JWTError::FailedToParse)?; - - let signature = base64::engine::general_purpose::URL_SAFE_NO_PAD - .decode( - parts - .next() - .ok_or(Report::new(JWTError::MissingSignature))?, - ) - .change_context(JWTError::FailedToParse)?; - - Ok(Self { - header, - critical_claims, - all_claims, - payload: String::new() + header_slice + "." + claims_slice, - signature, - }) - } -} - -pub trait JwtSigner { - fn sign_jwt(&mut self, header: &str, payload: &str) -> Result>; -} - -/// Utility method for generating JWTs -pub fn encode_jwt( - header: &HeaderT, - payload: &PayloadT, - mut signer: SignerT, -) -> Result> { - let encoded_header = base64::engine::general_purpose::URL_SAFE_NO_PAD - .encode(to_string(header).change_context(JWTError::FailedToEncode)?); - - let encoded_payload = base64::engine::general_purpose::URL_SAFE_NO_PAD - .encode(to_string(payload).change_context(JWTError::FailedToEncode)?); - - let encoded_signature = signer.sign_jwt(&encoded_header, &encoded_payload)?; - - Ok(encoded_header + "." + &encoded_payload + "." + &encoded_signature) -} diff --git a/lib/src/auth/token/jwt/test.rs b/lib/src/auth/token/jwt/test.rs deleted file mode 100644 index 3fcdfd7..0000000 --- a/lib/src/auth/token/jwt/test.rs +++ /dev/null @@ -1,62 +0,0 @@ -use super::util::generate_test_token; -use super::{JWTAlgorithm, JWToken, TokenClaims, TokenHeader}; -use serde_json::Value; -use std::collections::BTreeMap; -use time::{Duration, OffsetDateTime}; - -#[test] -fn test_jwt_parse() { - let issued_at = OffsetDateTime::now_utc() - .replace_microsecond(0) - .unwrap() - .replace_millisecond(0) - .unwrap(); - let valid_until = issued_at + Duration::days(1); - - let (encoded_token, _) = generate_test_token( - TokenHeader { - alg: JWTAlgorithm::RS256, - kid: Some("123".into()), - }, - TokenClaims { - exp: valid_until, - iat: issued_at, - aud: "FB aud".into(), - iss: "FB iss".into(), - sub: "FB sub".into(), - auth_time: issued_at, - }, - ); - let decoded = JWToken::from_encoded(&encoded_token).unwrap(); - - assert_eq!(decoded.header.alg, JWTAlgorithm::RS256); - assert_eq!(&decoded.header.kid, &Some("123".into())); - assert_eq!(&decoded.critical_claims.exp, &valid_until); - assert_eq!(&decoded.critical_claims.iat, &issued_at); - assert_eq!(&decoded.critical_claims.auth_time, &issued_at); - assert_eq!(&decoded.critical_claims.aud, "FB aud"); - assert_eq!(&decoded.critical_claims.iss, "FB iss"); - assert_eq!(&decoded.critical_claims.sub, "FB sub"); - - let expected_all_claims: BTreeMap = vec![ - ( - "exp".into(), - Value::Number(valid_until.unix_timestamp().into()), - ), - ( - "iat".into(), - Value::Number(issued_at.unix_timestamp().into()), - ), - ( - "auth_time".into(), - Value::Number(issued_at.unix_timestamp().into()), - ), - ("aud".into(), Value::String("FB aud".into())), - ("iss".into(), Value::String("FB iss".into())), - ("sub".into(), Value::String("FB sub".into())), - ] - .into_iter() - .collect(); - - assert_eq!(decoded.all_claims, expected_all_claims); -} diff --git a/lib/src/auth/token/jwt/util.rs b/lib/src/auth/token/jwt/util.rs deleted file mode 100644 index bab63b4..0000000 --- a/lib/src/auth/token/jwt/util.rs +++ /dev/null @@ -1,11 +0,0 @@ -use super::{TokenClaims, TokenHeader, encode_jwt}; -use crate::auth::token::crypto::generate_test_cert; -use openssl::{hash::MessageDigest, sign::Signer, x509::X509}; - -/// Utility method for generating signed RS256 JWTs to be used in tests -pub fn generate_test_token(header: TokenHeader, critical_claims: TokenClaims) -> (String, X509) { - let (cert, key_pair) = generate_test_cert().unwrap(); - let signer = Signer::new(MessageDigest::sha256(), &key_pair).unwrap(); - - (encode_jwt(&header, &critical_claims, signer).unwrap(), cert) -} diff --git a/lib/src/auth/token/mod.rs b/lib/src/auth/token/mod.rs deleted file mode 100644 index 1748abd..0000000 --- a/lib/src/auth/token/mod.rs +++ /dev/null @@ -1,184 +0,0 @@ -#[cfg(test)] -mod test; - -pub mod cache; -pub mod crypto; -pub mod error; -pub mod jwt; - -use cache::KeyCache; -use crypto::JwtRsaPubKey; -use error::TokenVerificationError; -use error_stack::{Report, ResultExt}; -use jwt::{JWTAlgorithm, JWToken}; -use std::future::Future; -use time::{Duration, OffsetDateTime}; - -const GOOGLE_ID_TOKEN_ISSUER_PREFIX: &str = "https://securetoken.google.com/"; -const GOOGLE_COOKIE_ISSUER_PREFIX: &str = "https://session.firebase.google.com/"; - -#[cfg(feature = "tokens")] -pub(crate) const GOOGLE_PUB_KEY_URI: &str = - "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com"; -#[cfg(feature = "tokens")] -pub(crate) const GOOGLE_COOKIE_PUB_KEY_URI: &str = - "https://www.googleapis.com/identitytoolkit/v3/relyingparty/publicKeys"; - -pub trait TokenVerifier: Sized + Sync + Send { - fn verify_token( - &self, - id_token: &str, - ) -> impl Future>> + Send; -} - -pub struct EmulatedTokenVerifier { - _project_id: String, - _issuer: String, -} - -impl EmulatedTokenVerifier { - pub fn new(project_id: String) -> Self { - Self { - _project_id: project_id.clone(), - _issuer: project_id, - } - } -} - -impl TokenVerifier for EmulatedTokenVerifier { - async fn verify_token( - &self, - id_token: &str, - ) -> Result> { - let token = JWToken::from_encoded(id_token) - .change_context(TokenVerificationError::FailedParsing)?; - - // TODO: implement claim checks for emulator - - Ok(token) - } -} - -pub struct LiveTokenVerifier { - project_id: String, - issuer: String, - key_cache: CacheT, -} - -impl TokenVerifier for LiveTokenVerifier { - async fn verify_token( - &self, - id_token: &str, - ) -> Result> { - let token = JWToken::from_encoded(id_token) - .change_context(TokenVerificationError::FailedParsing)?; - - self.verify(&token).await?; - - Ok(token) - } -} - -impl LiveTokenVerifier { - /// Create new ID token verifier - pub fn new_id_verifier( - project_id: String, - key_cache: CacheT, - ) -> Result> { - Ok(Self { - issuer: String::new() + GOOGLE_ID_TOKEN_ISSUER_PREFIX + &project_id, - project_id, - key_cache, - }) - } - - /// Create new cookie token verifier - pub fn new_cookie_verifier( - project_id: String, - key_cache: CacheT, - ) -> Result> { - Ok(Self { - issuer: String::new() + GOOGLE_COOKIE_ISSUER_PREFIX + &project_id, - project_id, - key_cache, - }) - } - - async fn verify_signature( - &self, - token: &JWToken, - ) -> Result<(), Report> { - let keys = self - .key_cache - .get_keys() - .await - .change_context(TokenVerificationError::FailedGettingKeys)?; - - let key_id = token - .header - .kid - .as_ref() - .ok_or(TokenVerificationError::FailedGettingKeys)?; - - let key = keys - .get(key_id) - .ok_or(Report::new(TokenVerificationError::InvalidSignatureKey))?; - - let is_valid = key - .verify(token.payload.as_bytes(), &token.signature) - .change_context(TokenVerificationError::InvalidSignature)?; - - if !is_valid { - return Err(Report::new(TokenVerificationError::InvalidSignature)); - } - - Ok(()) - } - - fn verify_header(&self, token: &JWToken) -> Result<(), Report> { - match token.header.alg { - JWTAlgorithm::RS256 => Ok(()), - _ => Err(Report::new( - TokenVerificationError::InvalidSignatureAlgorithm, - )), - } - } - - fn verify_claims(&self, token: &JWToken) -> Result<(), Report> { - let now = OffsetDateTime::now_utc(); - - if token.critical_claims.exp <= now { - return Err(Report::new(TokenVerificationError::Expired)); - } - - // Firebase sometimes has wonky iat, pad with 10secs - if token.critical_claims.iat > now + Duration::seconds(10) { - return Err(Report::new(TokenVerificationError::IssuedInFuture)); - } - - if token.critical_claims.auth_time > now { - return Err(Report::new(TokenVerificationError::IssuedInFuture)); - } - - if token.critical_claims.aud != self.project_id { - return Err(Report::new(TokenVerificationError::InvalidAudience)); - } - - if token.critical_claims.iss != self.issuer { - return Err(Report::new(TokenVerificationError::InvalidIssuer)); - } - - if token.critical_claims.sub.is_empty() { - return Err(Report::new(TokenVerificationError::MissingSubject)); - } - - Ok(()) - } - - /// verify JWToken's attributes and signature - pub async fn verify(&self, token: &JWToken) -> Result<(), Report> { - self.verify_header(token)?; - self.verify_claims(token)?; - self.verify_signature(token).await - } -} diff --git a/lib/src/auth/token/test.rs b/lib/src/auth/token/test.rs deleted file mode 100644 index ea5ba5d..0000000 --- a/lib/src/auth/token/test.rs +++ /dev/null @@ -1,310 +0,0 @@ -use super::cache::{CacheClient, HttpCache, Resource}; -use super::crypto::generate_test_cert; -use super::jwt::{JWTAlgorithm, JWToken, TokenClaims, TokenHeader, util::generate_test_token}; -use super::{LiveTokenVerifier, TokenVerificationError}; -use error_stack::Report; -use serde_json::to_string; -use std::collections::BTreeMap; -use thiserror::Error; -use time::{Duration, OffsetDateTime}; - -#[derive(Error, Debug, Clone)] -#[error("CertCacheClientMockError")] -pub struct CertCacheClientMockError; - -/// Mock for public x.509 certificate cache -struct CertCacheClientMock { - keys: Vec, -} - -impl CertCacheClientMock { - pub fn mock(keys: Vec) -> Self { - Self { keys } - } -} - -impl CacheClient for CertCacheClientMock { - type Error = CertCacheClientMockError; - - async fn fetch(&self, _: &str) -> Result> { - Ok(Resource { - data: self.keys.clone().into(), - max_age: std::time::Duration::from_secs(60), - }) - } -} - -/// Mock and test correct token verification -#[tokio::test] -async fn test_verify_correct_token() { - let issued_at = OffsetDateTime::now_utc() - .replace_microsecond(0) - .unwrap() - .replace_millisecond(0) - .unwrap(); - let valid_until = issued_at + Duration::days(1); - let project_id = String::from("test_project"); - - let (encoded_token, cert) = generate_test_token( - TokenHeader { - alg: JWTAlgorithm::RS256, - kid: Some("123".into()), - }, - TokenClaims { - exp: valid_until, - iat: issued_at, - aud: project_id.clone(), - iss: format!("https://securetoken.google.com/{project_id}"), - sub: "user123".into(), - auth_time: issued_at, - }, - ); - - let cert_pem = String::from_utf8(cert.to_pem().unwrap()).unwrap(); - let key_map: BTreeMap = - vec![(String::from("123"), cert_pem)].into_iter().collect(); - let key_map_json: Vec = to_string(&key_map).unwrap().as_bytes().to_vec(); - - let decoded_token = JWToken::from_encoded(&encoded_token).unwrap(); - let cache_client = HttpCache::new(CertCacheClientMock::mock(key_map_json), String::default()) - .await - .unwrap(); - - let verifier = LiveTokenVerifier::new_id_verifier(project_id, cache_client).unwrap(); - - verifier.verify(&decoded_token).await.unwrap(); -} - -/// Mock and test token with incorrect signature verification -#[tokio::test] -async fn test_verify_incorrect_token_signature_key() { - let issued_at = OffsetDateTime::now_utc() - .replace_microsecond(0) - .unwrap() - .replace_millisecond(0) - .unwrap(); - let valid_until = issued_at + Duration::days(1); - let project_id = String::from("test_project"); - - let (encoded_token, _) = generate_test_token( - TokenHeader { - alg: JWTAlgorithm::RS256, - kid: Some("123".into()), - }, - TokenClaims { - exp: valid_until, - iat: issued_at, - aud: project_id.clone(), - iss: format!("https://securetoken.google.com/{project_id}"), - sub: "user123".into(), - auth_time: issued_at, - }, - ); - - // Put different certificate than the one used to sign token into the cache - let (cert, _) = generate_test_cert().unwrap(); - - let cert_pem = String::from_utf8(cert.to_pem().unwrap()).unwrap(); - let key_map: BTreeMap = - vec![(String::from("123"), cert_pem)].into_iter().collect(); - let key_map_json: Vec = to_string(&key_map).unwrap().as_bytes().to_vec(); - - let decoded_token = JWToken::from_encoded(&encoded_token).unwrap(); - let cache_client = HttpCache::new(CertCacheClientMock::mock(key_map_json), String::default()) - .await - .unwrap(); - - let verifier = LiveTokenVerifier::new_id_verifier(project_id, cache_client).unwrap(); - - let result = verifier.verify(&decoded_token).await; - - if let Err(err) = result { - match err.current_context() { - TokenVerificationError::InvalidSignature => {} - _ => panic!("Expected invalid signature error but got {err}"), - } - } else { - panic!("Should not be a valid token because of incorrect certificate for signature used"); - } -} - -/// Mock and test token with incorrect expiration verification -#[tokio::test] -async fn test_verify_token_expiration() { - let issued_at = OffsetDateTime::now_utc() - .replace_microsecond(0) - .unwrap() - .replace_millisecond(0) - .unwrap(); - let valid_until = issued_at - Duration::days(1); - let project_id = String::from("test_project"); - - let (encoded_token, cert) = generate_test_token( - TokenHeader { - alg: JWTAlgorithm::RS256, - kid: Some("123".into()), - }, - TokenClaims { - exp: valid_until, - iat: issued_at, - aud: project_id.clone(), - iss: format!("https://securetoken.google.com/{project_id}"), - sub: "user123".into(), - auth_time: issued_at, - }, - ); - - let cert_pem = String::from_utf8(cert.to_pem().unwrap()).unwrap(); - let key_map: BTreeMap = - vec![(String::from("123"), cert_pem)].into_iter().collect(); - let key_map_json: Vec = to_string(&key_map).unwrap().as_bytes().to_vec(); - - let decoded_token = JWToken::from_encoded(&encoded_token).unwrap(); - let cache_client = HttpCache::new(CertCacheClientMock::mock(key_map_json), String::default()) - .await - .unwrap(); - - let verifier = LiveTokenVerifier::new_id_verifier(project_id.clone(), cache_client).unwrap(); - - let result = verifier.verify(&decoded_token).await; - - if let Err(err) = result { - match err.current_context() { - TokenVerificationError::Expired => {} - _ => panic!("Expected expired token error but got {err}"), - } - } else { - panic!("Should not be a valid token because the token is expired"); - } - - // test with issuing date in the future - let issued_at = issued_at + Duration::days(1); - let valid_until = issued_at + Duration::days(1); - - let (encoded_token, cert) = generate_test_token( - TokenHeader { - alg: JWTAlgorithm::RS256, - kid: Some("123".into()), - }, - TokenClaims { - exp: valid_until, - iat: issued_at, - aud: project_id.clone(), - iss: format!("https://securetoken.google.com/{project_id}"), - sub: "user123".into(), - auth_time: issued_at, - }, - ); - - let cert_pem = String::from_utf8(cert.to_pem().unwrap()).unwrap(); - let key_map: BTreeMap = - vec![(String::from("123"), cert_pem)].into_iter().collect(); - let key_map_json: Vec = to_string(&key_map).unwrap().as_bytes().to_vec(); - - let decoded_token = JWToken::from_encoded(&encoded_token).unwrap(); - let cache_client = HttpCache::new(CertCacheClientMock::mock(key_map_json), String::default()) - .await - .unwrap(); - - let verifier = LiveTokenVerifier::new_id_verifier(project_id, cache_client).unwrap(); - - let result = verifier.verify(&decoded_token).await; - - if let Err(err) = result { - match err.current_context() { - TokenVerificationError::IssuedInFuture => {} - _ => panic!("Expected token issued in the future error but got {err}"), - } - } else { - panic!("Should not be a valid token because the token was issued in the future"); - } -} - -/// Mock and test token with incorrect claims verification -#[tokio::test] -async fn test_verify_token_claims() { - let issued_at = OffsetDateTime::now_utc() - .replace_microsecond(0) - .unwrap() - .replace_millisecond(0) - .unwrap(); - let valid_until = issued_at + Duration::days(1); - let project_id = String::from("test_project"); - - let (encoded_token, cert) = generate_test_token( - TokenHeader { - alg: JWTAlgorithm::RS256, - kid: Some("123".into()), - }, - TokenClaims { - exp: valid_until, - iat: issued_at, - aud: "another_project".into(), - iss: format!("https://securetoken.google.com/{project_id}"), - sub: "user123".into(), - auth_time: issued_at, - }, - ); - - let cert_pem = String::from_utf8(cert.to_pem().unwrap()).unwrap(); - let key_map: BTreeMap = - vec![(String::from("123"), cert_pem)].into_iter().collect(); - let key_map_json: Vec = to_string(&key_map).unwrap().as_bytes().to_vec(); - - let decoded_token = JWToken::from_encoded(&encoded_token).unwrap(); - let cache_client = HttpCache::new(CertCacheClientMock::mock(key_map_json), String::default()) - .await - .unwrap(); - - let verifier = LiveTokenVerifier::new_id_verifier(project_id.clone(), cache_client).unwrap(); - - let result = verifier.verify(&decoded_token).await; - - if let Err(err) = result { - match err.current_context() { - TokenVerificationError::InvalidAudience => {} - _ => panic!("Expected invalid audience error but got {err}"), - } - } else { - panic!("Should not be a valid token because the audience is invalid"); - } - - // test with wrong issuer claim - let (encoded_token, cert) = generate_test_token( - TokenHeader { - alg: JWTAlgorithm::RS256, - kid: Some("123".into()), - }, - TokenClaims { - exp: valid_until, - iat: issued_at, - aud: project_id.clone(), - iss: "https://securetoken.google.com/another_project".into(), - sub: "user123".into(), - auth_time: issued_at, - }, - ); - - let cert_pem = String::from_utf8(cert.to_pem().unwrap()).unwrap(); - let key_map: BTreeMap = - vec![(String::from("123"), cert_pem)].into_iter().collect(); - let key_map_json: Vec = to_string(&key_map).unwrap().as_bytes().to_vec(); - - let decoded_token = JWToken::from_encoded(&encoded_token).unwrap(); - let cache_client = HttpCache::new(CertCacheClientMock::mock(key_map_json), String::default()) - .await - .unwrap(); - - let verifier = LiveTokenVerifier::new_id_verifier(project_id, cache_client).unwrap(); - - let result = verifier.verify(&decoded_token).await; - - if let Err(err) = result { - match err.current_context() { - TokenVerificationError::InvalidIssuer => {} - _ => panic!("Expected invalid token issuer error but got {err}"), - } - } else { - panic!("Should not be a valid token because the token has invalid issuer"); - } -} diff --git a/lib/src/client/mod.rs b/lib/src/client/mod.rs index ec48b9e..fa4b962 100644 --- a/lib/src/client/mod.rs +++ b/lib/src/client/mod.rs @@ -3,10 +3,11 @@ pub mod error; pub mod url_params; -use crate::credentials::Credentials; +use crate::credentials::get_headers; use bytes::Bytes; use error::{ApiClientError, FireBaseAPIErrorResponse}; use error_stack::{Report, ResultExt}; +use google_cloud_auth::credentials::CredentialsProvider; use http::Method; use serde::{Serialize, de::DeserializeOwned}; use std::future::Future; @@ -18,7 +19,6 @@ pub trait ApiHttpClient: Send + Sync + 'static { &self, uri: String, method: Method, - oauth_scopes: &[&str], ) -> impl Future>> + Send; fn send_request_with_params< @@ -29,7 +29,6 @@ pub trait ApiHttpClient: Send + Sync + 'static { uri: String, params: ParamsT, method: Method, - oauth_scopes: &[&str], ) -> impl Future>> + Send; fn send_request_body( @@ -37,7 +36,6 @@ pub trait ApiHttpClient: Send + Sync + 'static { uri: String, method: Method, request_body: RequestT, - oauth_scopes: &[&str], ) -> impl Future>> + Send; fn send_request_body_get_bytes( @@ -45,7 +43,6 @@ pub trait ApiHttpClient: Send + Sync + 'static { uri: String, method: Method, request_body: RequestT, - oauth_scopes: &[&str], ) -> impl Future>> + Send; fn send_request_body_empty_response( @@ -53,7 +50,6 @@ pub trait ApiHttpClient: Send + Sync + 'static { uri: String, method: Method, request_body: RequestT, - oauth_scopes: &[&str], ) -> impl Future>> + Send; } @@ -76,7 +72,7 @@ pub struct ReqwestApiClient { credentials: C, } -impl ReqwestApiClient { +impl ReqwestApiClient { pub fn new(client: reqwest::Client, credentials: C) -> Self { Self { client, @@ -105,14 +101,12 @@ impl ReqwestApiClient { &self, url: &str, method: Method, - oauth_scopes: &[&str], body: Option, ) -> Result> { self.client .request(method, url) - .bearer_auth( - self.credentials - .get_access_token(oauth_scopes) + .headers( + get_headers(&self.credentials) .await .change_context(ApiClientError::FailedToSendRequest)?, ) @@ -123,21 +117,17 @@ impl ReqwestApiClient { } } -impl ApiHttpClient for ReqwestApiClient { +impl ApiHttpClient for ReqwestApiClient { async fn send_request( &self, url: String, method: Method, - oauth_scopes: &[&str], ) -> Result> { - Self::handle_response( - self.handle_request::<()>(&url, method, oauth_scopes, None) - .await?, - ) - .await? - .json() - .await - .change_context(ApiClientError::FailedToReceiveResponse) + Self::handle_response(self.handle_request::<()>(&url, method, None).await?) + .await? + .json() + .await + .change_context(ApiClientError::FailedToReceiveResponse) } async fn send_request_with_params< @@ -148,17 +138,13 @@ impl ApiHttpClient for ReqwestApiClient { url: String, params: ParamsT, method: Method, - oauth_scopes: &[&str], ) -> Result> { let url: String = url + ¶ms.into_url_params(); - Self::handle_response( - self.handle_request::<()>(&url, method, oauth_scopes, None) - .await?, - ) - .await? - .json() - .await - .change_context(ApiClientError::FailedToReceiveResponse) + Self::handle_response(self.handle_request::<()>(&url, method, None).await?) + .await? + .json() + .await + .change_context(ApiClientError::FailedToReceiveResponse) } async fn send_request_body( @@ -166,10 +152,9 @@ impl ApiHttpClient for ReqwestApiClient { url: String, method: Method, request_body: RequestT, - oauth_scopes: &[&str], ) -> Result> { Self::handle_response( - self.handle_request(&url, method, oauth_scopes, Some(request_body)) + self.handle_request(&url, method, Some(request_body)) .await?, ) .await? @@ -183,10 +168,9 @@ impl ApiHttpClient for ReqwestApiClient { url: String, method: Method, request_body: RequestT, - oauth_scopes: &[&str], ) -> Result> { Self::handle_response( - self.handle_request(&url, method, oauth_scopes, Some(request_body)) + self.handle_request(&url, method, Some(request_body)) .await?, ) .await? @@ -200,10 +184,9 @@ impl ApiHttpClient for ReqwestApiClient { url: String, method: Method, request_body: RequestT, - oauth_scopes: &[&str], ) -> Result<(), Report> { Self::handle_response( - self.handle_request(&url, method, oauth_scopes, Some(request_body)) + self.handle_request(&url, method, Some(request_body)) .await?, ) .await?; diff --git a/lib/src/credentials/emulator/mod.rs b/lib/src/credentials/emulator/mod.rs index 3abe6d1..585cd8e 100644 --- a/lib/src/credentials/emulator/mod.rs +++ b/lib/src/credentials/emulator/mod.rs @@ -1,25 +1,46 @@ -use super::{Credentials, CredentialsError}; -use error_stack::Report; +#[cfg(test)] +mod test; + +use super::GoogleUserProject; +use google_cloud_auth::{ + credentials::{CacheableResource, CredentialsProvider, EntityTag}, + errors::CredentialsError, +}; +use headers::{Authorization, HeaderMapExt}; +use http::HeaderMap; #[derive(Debug, Clone)] pub struct EmulatorCredentials { - project_id: String, + pub(crate) project_id: String, } impl Default for EmulatorCredentials { fn default() -> Self { Self { - project_id: std::env::var("PROJECT_ID").unwrap_or("demo-firebase-project".into()), + project_id: std::env::var("GOOGLE_CLOUD_PROJECT").unwrap_or_else(|_| { + std::env::var("PROJECT_ID").unwrap_or("demo-firebase-project".into()) + }), } } } -impl Credentials for EmulatorCredentials { - async fn get_access_token(&self, _scopes: &[&str]) -> Result> { - Ok("owner".into()) +impl CredentialsProvider for EmulatorCredentials { + async fn headers( + &self, + _extensions: http::Extensions, + ) -> Result, CredentialsError> { + let mut headers = HeaderMap::with_capacity(2); + headers.typed_insert(Authorization::bearer("owner").expect("Should always be valid")); + + headers.typed_insert(GoogleUserProject(self.project_id.clone())); + + Ok(CacheableResource::New { + entity_tag: EntityTag::new(), + data: headers, + }) } - async fn get_project_id(&self) -> Result> { - Ok(self.project_id.clone()) + async fn universe_domain(&self) -> Option { + unimplemented!("unimplemented") } } diff --git a/lib/src/credentials/test.rs b/lib/src/credentials/emulator/test.rs similarity index 52% rename from lib/src/credentials/test.rs rename to lib/src/credentials/emulator/test.rs index e339915..076a86a 100644 --- a/lib/src/credentials/test.rs +++ b/lib/src/credentials/emulator/test.rs @@ -1,13 +1,18 @@ -use super::{Credentials, GoogleUserProject, emulator::EmulatorCredentials}; +use super::{super::GoogleUserProject, EmulatorCredentials}; +use google_cloud_auth::credentials::{CacheableResource, CredentialsProvider}; use headers::{Authorization, HeaderMapExt, authorization::Bearer}; -use http::header::{HeaderMap, HeaderValue}; +use http::Extensions; #[tokio::test] async fn test_credentials() { - let mut headers: HeaderMap = HeaderMap::default(); let creds = EmulatorCredentials::default(); - - creds.set_credentials(&mut headers, &[]).await.unwrap(); + let headers = match creds.headers(Extensions::new()).await.unwrap() { + CacheableResource::New { + entity_tag: _, + data, + } => data, + _ => unreachable!(), + }; let project_id: GoogleUserProject = headers.typed_get().unwrap(); let token: Authorization = headers.typed_get().unwrap(); diff --git a/lib/src/credentials/error.rs b/lib/src/credentials/error.rs deleted file mode 100644 index d5f3bcd..0000000 --- a/lib/src/credentials/error.rs +++ /dev/null @@ -1,11 +0,0 @@ -use thiserror::Error; - -#[derive(Error, Debug, Clone)] -pub enum CredentialsError { - #[error("Failed while parsing service credential JSON")] - FailedParsingServiceCredentials, - #[error("Received invalid access token")] - InvalidAccessToken, - #[error("Something unexpected happened")] - Internal, -} diff --git a/lib/src/credentials/gcp/mod.rs b/lib/src/credentials/gcp/mod.rs deleted file mode 100644 index 1d38cc9..0000000 --- a/lib/src/credentials/gcp/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -use super::{Credentials, CredentialsError}; -use crate::GcpCredentials; -use error_stack::{Report, ResultExt}; - -impl Credentials for GcpCredentials { - async fn get_access_token(&self, scopes: &[&str]) -> Result> { - let token = self - .token(scopes) - .await - .change_context(CredentialsError::Internal)?; - - Ok(token.as_str().into()) - } - - async fn get_project_id(&self) -> Result> { - self.project_id() - .await - .change_context(CredentialsError::Internal) - .map(|t| (*t).to_owned()) - } -} diff --git a/lib/src/credentials/mod.rs b/lib/src/credentials/mod.rs index eb73eed..96b746d 100644 --- a/lib/src/credentials/mod.rs +++ b/lib/src/credentials/mod.rs @@ -1,18 +1,16 @@ //! OAuth2 credential managers for GCP and Firebase Emulator pub mod emulator; -pub mod error; -pub mod gcp; -#[cfg(test)] -mod test; - -use error::CredentialsError; use error_stack::{Report, ResultExt}; -use headers::{Authorization, HeaderMapExt, authorization::Bearer}; +use google_cloud_auth::credentials::{CacheableResource, CredentialsProvider}; +use headers::HeaderMapExt; use headers::{Header, HeaderName, HeaderValue}; -use http::header::HeaderMap; -use std::future::Future; +use http::{Extensions, HeaderMap}; + +#[derive(thiserror::Error, Debug, Clone)] +#[error("Failed to extract GCP credentials")] +pub struct GCPCredentialsError; static X_GOOG_USER_PROJECT: HeaderName = HeaderName::from_static("x-goog-user-project"); @@ -48,35 +46,33 @@ impl Header for GoogleUserProject { } } -pub trait Credentials: Send + Sync + 'static { - /// Implementation for generation of OAuth2 access token - fn get_access_token( - &self, - scopes: &[&str], - ) -> impl Future>> + Send; - - /// Implementation for getting GCP project id - fn get_project_id( - &self, - ) -> impl Future>> + Send; - - /// Set credentials for a API request, by default use bearer authorization for passing access token - fn set_credentials( - &self, - headers: &mut HeaderMap, - scopes: &[&str], - ) -> impl Future>> + Send { - async move { - let token = self.get_access_token(scopes).await?; - - headers.typed_insert( - Authorization::::bearer(&token) - .change_context(CredentialsError::InvalidAccessToken)?, - ); - - headers.typed_insert(GoogleUserProject(self.get_project_id().await?)); - - Ok(()) - } - } +pub(crate) async fn get_project_id( + creds: &impl CredentialsProvider, +) -> Result> { + let headers = get_headers(creds).await?; + + let user_project: GoogleUserProject = headers + .typed_get() + .ok_or(Report::new(GCPCredentialsError))?; + + Ok(user_project.0) +} + +pub(crate) async fn get_headers( + creds: &impl CredentialsProvider, +) -> Result> { + let headers = creds + .headers(Extensions::new()) + .await + .change_context(GCPCredentialsError)?; + + let headers = match headers { + CacheableResource::New { + entity_tag: _, + data, + } => data, + _ => unreachable!(), + }; + + Ok(headers) } diff --git a/lib/src/jwt/mod.rs b/lib/src/jwt/mod.rs new file mode 100644 index 0000000..8f5e68e --- /dev/null +++ b/lib/src/jwt/mod.rs @@ -0,0 +1,115 @@ +use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; +use core::future::Future; +use error_stack::{Report, ResultExt}; +use jsonwebtoken::{DecodingKey, Validation, decode, decode_header}; +use jsonwebtoken_jwks_cache::{CachedJWKS, TimeoutSpec}; +use serde_json::{Value, from_slice}; +use std::collections::HashMap; +use std::time::Duration; +use thiserror::Error; + +const GOOGLE_JWKS_URI: &str = + "https://www.googleapis.com/service_accounts/v1/jwk/securetoken@system.gserviceaccount.com"; +const GOOGLE_ID_TOKEN_ISSUER_PREFIX: &str = "https://securetoken.google.com/"; +const GOOGLE_COOKIE_ISSUER_PREFIX: &str = "https://session.firebase.google.com/"; + +#[derive(Error, Debug, Clone)] +pub enum TokenVerificationError { + #[error("Token's key is missing")] + MissingKey, + #[error("Invalid token")] + Invalid, + #[error("Unexpected error")] + Internal, +} + +pub trait TokenValidator { + /// Validate JWT returning all claims on success + fn validate( + &self, + token: &str, + ) -> impl Future, Report>> + Send + Sync; +} + +pub struct LiveValidator { + project_id: String, + issuer: String, + jwks: CachedJWKS, +} + +impl LiveValidator { + pub fn new_jwt_validator(project_id: String) -> Result { + Ok(Self { + issuer: format!("{GOOGLE_ID_TOKEN_ISSUER_PREFIX}{project_id}"), + project_id, + jwks: CachedJWKS::new( + // should always succeed + GOOGLE_JWKS_URI.parse().unwrap(), + Duration::from_secs(60), + TimeoutSpec::default(), + )?, + }) + } + + pub fn new_cookie_validator(project_id: String) -> Result { + Ok(Self { + issuer: format!("{GOOGLE_COOKIE_ISSUER_PREFIX}{project_id}"), + project_id, + jwks: CachedJWKS::new( + // should always succeed + GOOGLE_JWKS_URI.parse().unwrap(), + Duration::from_secs(60), + TimeoutSpec::default(), + )?, + }) + } +} + +impl TokenValidator for LiveValidator { + async fn validate( + &self, + token: &str, + ) -> Result, Report> { + let jwks = self + .jwks + .get() + .await + .change_context(TokenVerificationError::Internal)?; + let jwt_header = decode_header(token).change_context(TokenVerificationError::Invalid)?; + + let jwk: DecodingKey = jwks + .find(&jwt_header.kid.ok_or(TokenVerificationError::MissingKey)?) + .ok_or(TokenVerificationError::MissingKey)? + .try_into() + .change_context(TokenVerificationError::Internal)?; + + let mut validator = Validation::new(jwt_header.alg); + validator.set_audience(&[&self.project_id]); + validator.set_issuer(&[&self.issuer]); + + decode::>(token, &jwk, &validator) + .change_context(TokenVerificationError::Invalid) + .map(|t| t.claims) + } +} + +#[derive(Default)] +pub struct EmulatorValidator; + +impl TokenValidator for EmulatorValidator { + async fn validate( + &self, + token: &str, + ) -> Result, Report> { + let header = token + .split(".") + .nth(1) + .ok_or(TokenVerificationError::Invalid)?; + + let header = URL_SAFE_NO_PAD + .decode(header) + .change_context(TokenVerificationError::Invalid)?; + + from_slice(&header).change_context(TokenVerificationError::Invalid) + } +} diff --git a/lib/src/lib.rs b/lib/src/lib.rs index cf130c2..c2a0c67 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -2,42 +2,38 @@ pub mod api_uri; pub mod auth; pub mod client; pub mod credentials; +#[cfg(feature = "tokens")] +pub mod jwt; pub mod util; use auth::FirebaseAuth; - -#[cfg(feature = "tokens")] -use auth::token::{ - EmulatedTokenVerifier, GOOGLE_COOKIE_PUB_KEY_URI, GOOGLE_PUB_KEY_URI, LiveTokenVerifier, - cache::{HttpCache, PubKeys}, - error::TokenVerificationError, -}; use client::ReqwestApiClient; -use credentials::emulator::EmulatorCredentials; -pub use credentials::{Credentials, error::CredentialsError}; +use credentials::{GCPCredentialsError, emulator::EmulatorCredentials, get_project_id}; use error_stack::{Report, ResultExt}; -use gcp_auth::TokenProvider; -pub use gcp_auth::provider as credentials_provider; -use std::sync::Arc; +use google_cloud_auth::credentials::{AccessTokenCredentials, Builder}; -/// Default Firebase Auth admin manager -pub type GcpCredentials = Arc; -pub type LiveAuthAdmin = FirebaseAuth>; +const FIREBASE_AUTH_SCOPES: [&str; 2] = [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", +]; + +pub type LiveAuthAdmin = FirebaseAuth>; /// Default Firebase Auth Emulator admin manager pub type EmulatorAuthAdmin = FirebaseAuth>; /// Base privileged manager for Firebase -pub struct App { - credentials: CredentialsT, +pub struct App { + credentials: C, project_id: String, } impl App { /// Firebase app backend by emulator - pub fn emulated(project_id: String) -> Self { + pub fn emulated() -> Self { + let credentials = EmulatorCredentials::default(); Self { - credentials: EmulatorCredentials::default(), - project_id, + project_id: credentials.project_id.clone(), + credentials, } } @@ -45,30 +41,27 @@ impl App { pub fn auth(&self, emulator_url: String) -> EmulatorAuthAdmin { let client = ReqwestApiClient::new(reqwest::Client::new(), self.credentials.clone()); - FirebaseAuth::emulated(emulator_url, &self.project_id, client) + FirebaseAuth::emulated(emulator_url, &self.credentials.project_id, client) } /// OIDC token verifier for emulator #[cfg(feature = "tokens")] - pub fn id_token_verifier(&self) -> EmulatedTokenVerifier { - EmulatedTokenVerifier::new(self.project_id.clone()) + pub fn id_token_verifier(&self) -> impl jwt::TokenValidator { + jwt::EmulatorValidator } } -impl App { +impl App { /// Create instance of Firebase app for live project - pub async fn live(credentials: GcpCredentials) -> Result> { - Self::live_shared(credentials).await - } + pub async fn live() -> Result> { + let credentials = Builder::default() + .with_scopes(FIREBASE_AUTH_SCOPES) + .build_access_token_credentials() + .change_context(GCPCredentialsError)?; - pub async fn live_shared( - credentials: GcpCredentials, - ) -> Result> { - let project_id = credentials - .project_id() + let project_id = get_project_id(&credentials) .await - .change_context(CredentialsError::Internal)? - .to_string(); + .change_context(GCPCredentialsError)?; Ok(Self { credentials, @@ -87,41 +80,21 @@ impl App { #[cfg(feature = "tokens")] pub async fn id_token_verifier( &self, - ) -> Result< - LiveTokenVerifier>, - Report, - > { - let cache_client = HttpCache::new( - reqwest::Client::new(), - GOOGLE_PUB_KEY_URI - .parse() - .map_err(error_stack::Report::new) - .change_context(TokenVerificationError::FailedGettingKeys)?, - ) - .await - .change_context(TokenVerificationError::FailedGettingKeys)?; - - LiveTokenVerifier::new_id_verifier(self.project_id.clone(), cache_client) + ) -> Result> { + let project_id = credentials::get_project_id(&self.credentials).await?; + + jwt::LiveValidator::new_jwt_validator(project_id) + .change_context(credentials::GCPCredentialsError) } - /// Create cookie token verifier + // /// Create cookie token verifier #[cfg(feature = "tokens")] pub async fn cookie_token_verifier( &self, - ) -> Result< - LiveTokenVerifier>, - Report, - > { - let cache_client = HttpCache::new( - reqwest::Client::new(), - GOOGLE_COOKIE_PUB_KEY_URI - .parse() - .map_err(error_stack::Report::new) - .change_context(TokenVerificationError::FailedGettingKeys)?, - ) - .await - .change_context(TokenVerificationError::FailedGettingKeys)?; - - LiveTokenVerifier::new_cookie_verifier(self.project_id.clone(), cache_client) + ) -> Result> { + let project_id = credentials::get_project_id(&self.credentials).await?; + + jwt::LiveValidator::new_cookie_validator(project_id) + .change_context(credentials::GCPCredentialsError) } }