Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ reqwest = { version = "0.13", default-features = false, optional = true, feature

ureq = { version = "3.0.6", optional = true, default-features = false, features = ["gzip", "json", "socks-proxy", "charset"]}

hmac = { version = "0.12.1", optional = true }
percent-encoding = { version = "2.3.2", optional = true }
sha2 = { version = "0.10.0", optional = true }
time = { version = "0.3.45", optional = true }
url = { version = "2.5.8", optional = true }

[features]
default = ["reqwest", "default-tls"]

Expand All @@ -64,3 +70,4 @@ rustls = ["reqwest?/rustls", "ureq?/rustls"]

reqwest = ["dep:reqwest"]
ureq = ["dep:ureq"]
s3-auth = ["dep:hmac", "dep:percent-encoding", "dep:sha2", "dep:url", "dep:time"]
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ fn update() -> Result<(), Box<::std::error::Error>> {
.asset_prefix("something/self_update")
.region("eu-west-2")
.bin_name("self_update_example")
// .access_key((env!("AWS_ACCESS_KEY_ID"), env!("AWS_SECRET_ACCESS_KEY")))
.show_download_progress(true)
.current_version(cargo_crate_version!())
.build()?
Expand Down
227 changes: 222 additions & 5 deletions src/backends/s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ pub struct ReleaseListBuilder {
asset_prefix: Option<String>,
target: Option<String>,
region: Option<String>,
#[cfg(feature = "s3-auth")]
access_key: Option<auth::AccessKey>,
}

impl ReleaseListBuilder {
Expand Down Expand Up @@ -100,6 +102,13 @@ impl ReleaseListBuilder {
self
}

#[cfg(feature = "s3-auth")]
/// Set the access key
pub fn access_key(&mut self, access_key: impl Into<auth::AccessKey>) -> &mut Self {
self.access_key = Some(access_key.into());
self
}

/// Verify builder args, returning a `ReleaseList`
pub fn build(&self) -> Result<ReleaseList> {
Ok(ReleaseList {
Expand All @@ -112,6 +121,8 @@ impl ReleaseListBuilder {
region: self.region.clone(),
asset_prefix: self.asset_prefix.clone(),
target: self.target.clone(),
#[cfg(feature = "s3-auth")]
access_key: self.access_key.clone(),
})
}
}
Expand All @@ -125,6 +136,8 @@ pub struct ReleaseList {
asset_prefix: Option<String>,
target: Option<String>,
region: Option<String>,
#[cfg(feature = "s3-auth")]
access_key: Option<auth::AccessKey>,
}

impl ReleaseList {
Expand All @@ -136,6 +149,8 @@ impl ReleaseList {
asset_prefix: None,
target: None,
region: None,
#[cfg(feature = "s3-auth")]
access_key: None,
}
}

Expand All @@ -147,6 +162,8 @@ impl ReleaseList {
&self.bucket_name,
&self.region,
&self.asset_prefix,
#[cfg(feature = "s3-auth")]
&self.access_key,
)?;
let releases = match self.target {
None => releases,
Expand All @@ -170,6 +187,8 @@ pub struct UpdateBuilder {
asset_prefix: Option<String>,
target: Option<String>,
region: Option<String>,
#[cfg(feature = "s3-auth")]
access_key: Option<auth::AccessKey>,
bin_name: Option<String>,
bin_install_path: Option<PathBuf>,
bin_path_in_archive: Option<String>,
Expand All @@ -193,6 +212,8 @@ impl Default for UpdateBuilder {
asset_prefix: None,
target: None,
region: None,
#[cfg(feature = "s3-auth")]
access_key: None,
bin_name: None,
bin_install_path: None,
bin_path_in_archive: None,
Expand Down Expand Up @@ -241,6 +262,13 @@ impl UpdateBuilder {
self
}

#[cfg(feature = "s3-auth")]
/// Set the access key id
pub fn access_key_id(&mut self, access_key: impl Into<auth::AccessKey>) -> &mut Self {
self.access_key = Some(access_key.into());
self
}

/// Set the current app version, used to compare against the latest available version.
/// The `cargo_crate_version!` macro can be used to pull the version from your `Cargo.toml`
pub fn current_version(&mut self, ver: &str) -> &mut Self {
Expand Down Expand Up @@ -389,6 +417,8 @@ impl UpdateBuilder {
bail!(Error::Config, "`bucket_name` required")
},
region: self.region.clone(),
#[cfg(feature = "s3-auth")]
access_key: self.access_key.clone(),
asset_prefix: self.asset_prefix.clone(),
target: self
.target
Expand Down Expand Up @@ -432,6 +462,8 @@ pub struct Update {
asset_prefix: Option<String>,
target: String,
region: Option<String>,
#[cfg(feature = "s3-auth")]
access_key: Option<auth::AccessKey>,
current_version: String,
target_version: Option<String>,
bin_name: String,
Expand Down Expand Up @@ -461,6 +493,8 @@ impl ReleaseUpdate for Update {
&self.bucket_name,
&self.region,
&self.asset_prefix,
#[cfg(feature = "s3-auth")]
&self.access_key,
)?;
let rel = releases
.iter()
Expand Down Expand Up @@ -490,6 +524,8 @@ impl ReleaseUpdate for Update {
&self.bucket_name,
&self.region,
&self.asset_prefix,
#[cfg(feature = "s3-auth")]
&self.access_key,
)?;

let mut releases = releases
Expand Down Expand Up @@ -525,6 +561,8 @@ impl ReleaseUpdate for Update {
&self.bucket_name,
&self.region,
&self.asset_prefix,
#[cfg(feature = "s3-auth")]
&self.access_key,
)?;
let rel = releases.iter().find(|x| x.version == ver);
match rel {
Expand Down Expand Up @@ -591,6 +629,172 @@ impl ReleaseUpdate for Update {
}
}

/// Generate S3 auth parameters
#[cfg(feature = "s3-auth")]
mod auth {
use crate::errors::*;
use hmac::{Hmac, Mac};
use percent_encoding::{utf8_percent_encode, AsciiSet, PercentEncode, NON_ALPHANUMERIC};
use sha2::{Digest, Sha256};
use std::{
borrow::Cow,
time::{SystemTime, UNIX_EPOCH},
};
use time::OffsetDateTime;
use url::Url;

#[derive(Clone, Debug)]
pub struct AccessKey {
pub access_key_id: String,
pub secret_access_key: String,
}

impl From<(&str, &str)> for AccessKey {
fn from(value: (&str, &str)) -> Self {
Self {
access_key_id: value.0.to_owned(),
secret_access_key: value.1.to_owned(),
}
}
}

impl From<(String, String)> for AccessKey {
fn from(value: (String, String)) -> Self {
Self {
access_key_id: value.0,
secret_access_key: value.1,
}
}
}

// NON_ALPHANUMERIC Encodes everything except A-Z, a-z, 0-9.
// Remove the last 4 reserved characters that AWS doesn't encode: - . _ ~
const URI_ENCODE: &AsciiSet = &NON_ALPHANUMERIC
.remove(b'-')
.remove(b'.')
.remove(b'_')
.remove(b'~');

// AWS doesn't encode the slash character in the canonical URI, but it does
// encode it in query parameters
const URI_ENCODE_KEEP_SLASH: &AsciiSet = &URI_ENCODE.remove(b'/');

// Encode a string for use in AWS S3 signature v4, encoding reserved
// characters and optionally the slash character
fn uri_encode(input: &str, encode_slash: bool) -> PercentEncode<'_> {
let set = if encode_slash {
URI_ENCODE
} else {
URI_ENCODE_KEEP_SLASH
};
utf8_percent_encode(input, set)
}

fn hex_sha256(data: &[u8]) -> String {
let hash = Sha256::digest(data);
hash.iter().map(|b| format!("{b:02x}")).collect()
}

fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
let mut mac = Hmac::<Sha256>::new_from_slice(key)?;
mac.update(data);
Ok(mac.finalize().into_bytes().to_vec())
}

fn derive_signing_key(secret: &str, date_stamp: &str, region: &str) -> Result<Vec<u8>> {
let k_date = hmac_sha256(format!("AWS4{secret}").as_bytes(), date_stamp.as_bytes())?;
let k_region = hmac_sha256(&k_date, region.as_bytes())?;
let k_service = hmac_sha256(&k_region, b"s3")?;
hmac_sha256(&k_service, b"aws4_request")
}

fn format_timestamp(secs: u64) -> Result<(String, String)> {
let dt = OffsetDateTime::from_unix_timestamp(secs as i64)?;
let date_stamp = format!("{:04}{:02}{:02}", dt.year(), dt.month() as u8, dt.day());
let amz_date = format!(
"{date_stamp}T{:02}{:02}{:02}Z",
dt.hour(),
dt.minute(),
dt.second()
);
Ok((date_stamp, amz_date))
}

pub fn s3_signature_v4(
url_str: &str,
region: &Option<String>,
access_key: &Option<AccessKey>,
ttl_secs: u64,
) -> Result<String> {
let (access_key_id, secret_access_key) = match access_key {
Some(access_key) => (&access_key.access_key_id, &access_key.secret_access_key),
None => return Ok(url_str.to_owned()),
};
let url = Url::parse(url_str)?;
let host = url
.host_str()
.ok_or_else(|| Error::Config("Cannot extract host from {:?url_str}".to_string()))?;
let canonical_uri = if url.path().is_empty() {
"/"
} else {
url.path()
};

let now_secs = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
let (date_stamp, amz_date) = format_timestamp(now_secs)?;

let region = region.as_deref().unwrap_or("us-east-1");

let credential_scope = format!("{date_stamp}/{region}/s3/aws4_request");

// Existing query params (decoded by url crate) + SigV4 params, sans Signature.
let mut params: Vec<_> = url.query_pairs().collect();

params.extend([
(
Cow::Borrowed("X-Amz-Algorithm"),
Cow::Borrowed("AWS4-HMAC-SHA256"),
),
(
Cow::Borrowed("X-Amz-Credential"),
Cow::Owned(format!("{access_key_id}/{credential_scope}")),
),
(Cow::Borrowed("X-Amz-Date"), Cow::Borrowed(&amz_date)),
(
Cow::Borrowed("X-Amz-Expires"),
Cow::Owned(ttl_secs.to_string()),
),
(Cow::Borrowed("X-Amz-SignedHeaders"), Cow::Borrowed("host")),
]);
params.sort_by(|a, b| a.0.cmp(&b.0));

let canonical_qs: String = params
.iter()
.map(|(k, v)| format!("{}={}", uri_encode(k, true), uri_encode(v, true)))
.collect::<Vec<_>>()
.join("&");

let canonical_request = format!(
"GET\n{}\n{canonical_qs}\nhost:{host}\n\nhost\nUNSIGNED-PAYLOAD",
uri_encode(canonical_uri, false),
);

let string_to_sign = format!(
"AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{}",
hex_sha256(canonical_request.as_bytes())
);

let signing_key = derive_signing_key(secret_access_key, &date_stamp, region)?;
let signature: String = hmac_sha256(&signing_key, string_to_sign.as_bytes())?
.iter()
.map(|b| format!("{b:02x}"))
.collect();

let base = &url_str[..url_str.find('?').unwrap_or(url_str.len())];
Ok(format!("{base}?{canonical_qs}&X-Amz-Signature={signature}"))
}
}

/// Obtain list of releases from AWS S3 API, from bucket and region specified,
/// filtering assets which don't match the prefix string if provided.
///
Expand All @@ -600,25 +804,29 @@ fn fetch_releases_from_s3(
bucket_name: &str,
region: &Option<String>,
asset_prefix: &Option<String>,
#[cfg(feature = "s3-auth")] access_key: &Option<auth::AccessKey>,
) -> Result<Vec<Release>> {
let prefix = match asset_prefix {
Some(prefix) => format!("&prefix={}", prefix),
None => "".to_string(),
};

let region = region
let region_result = region
.as_ref()
.ok_or_else(|| Error::Config("`region` required".to_string()));

let download_base_url = match end_point {
EndPoint::S3 => format!("https://{}.s3.{}.amazonaws.com/", bucket_name, region?),
EndPoint::S3 => format!(
"https://{}.s3.{}.amazonaws.com/",
bucket_name, region_result?
),
EndPoint::S3DualStack => format!(
"https://{}.s3.dualstack.{}.amazonaws.com/",
bucket_name, region?
bucket_name, region_result?
),
EndPoint::DigitalOceanSpaces => format!(
"https://{}.{}.digitaloceanspaces.com/",
bucket_name, region?
bucket_name, region_result?
),
EndPoint::GCS => format!("https://storage.googleapis.com/{}/", bucket_name),
EndPoint::Generic { ref end_point } => end_point.clone(),
Expand All @@ -635,6 +843,9 @@ fn fetch_releases_from_s3(
EndPoint::GCS => format!("{}?max-keys={}{}", download_base_url, MAX_KEYS, prefix),
};

#[cfg(feature = "s3-auth")]
let api_url = auth::s3_signature_v4(&api_url, region, access_key, 300)?;

debug!("using api url: {:?}", api_url);

let resp = http_client::get(&api_url, Default::default())?;
Expand Down Expand Up @@ -703,9 +914,15 @@ fn fetch_releases_from_s3(
release.name = captures["name"].to_string();
release.version =
captures["version"].trim_start_matches('v').to_string();
let download_url = format!("{}{}", download_base_url, txt);

#[cfg(feature = "s3-auth")]
let download_url =
auth::s3_signature_v4(&download_url, region, access_key, 300)?;

release.assets = vec![ReleaseAsset {
name: exe_name.to_string(),
download_url: format!("{}{}", download_base_url, txt),
download_url,
}];
debug!("Matched release: {:?}", release);
} else {
Expand Down
Loading
Loading