Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions gateway/rpc/proto/gateway_rpc.proto
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ message HostInfo {
string app_id = 3;
// The base domain of the HTTPS endpoint of the host.
string base_domain = 4;
// The external port of the host.
uint32 port = 5;
// The latest handshake time of the host.
uint64 latest_handshake = 6;
// The number of connections of the host.
Expand Down
2 changes: 0 additions & 2 deletions gateway/src/admin_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ impl AdminRpcHandler {
ip: instance.ip.to_string(),
app_id: instance.app_id.clone(),
base_domain: base_domain.clone(),
port: state.config.proxy.listen_port as u32,
latest_handshake: encode_ts(instance.last_seen),
num_connections: instance.num_connections(),
})
Expand Down Expand Up @@ -97,7 +96,6 @@ impl AdminRpc for AdminRpcHandler {
ip: instance.ip.to_string(),
app_id: instance.app_id.clone(),
base_domain: base_domain.clone(),
port: state.config.proxy.listen_port as u32,
latest_handshake: {
let (ts, _) = handshakes
.get(&instance.public_key)
Expand Down
37 changes: 36 additions & 1 deletion gateway/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,40 @@ pub enum TlsVersion {
Tls13,
}

/// Deserialize a port range from either a single integer (443) or a string range ("443-543").
fn deserialize_port_range<'de, D>(deserializer: D) -> std::result::Result<Vec<u16>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de;

#[derive(Deserialize)]
#[serde(untagged)]
enum PortSpec {
Single(u16),
Range(String),
}

match PortSpec::deserialize(deserializer)? {
PortSpec::Single(p) => Ok(vec![p]),
PortSpec::Range(s) => {
if let Some((start, end)) = s.split_once('-') {
let start: u16 = start.trim().parse().map_err(de::Error::custom)?;
let end: u16 = end.trim().parse().map_err(de::Error::custom)?;
if start > end {
return Err(de::Error::custom(format!(
"invalid port range: {start} > {end}"
)));
}
Ok((start..=end).collect())
} else {
let p: u16 = s.trim().parse().map_err(de::Error::custom)?;
Ok(vec![p])
}
}
}
}

#[derive(Debug, Clone, Deserialize)]
pub struct ProxyConfig {
pub cert_chain: String,
Expand All @@ -76,7 +110,8 @@ pub struct ProxyConfig {
pub base_domain: String,
pub external_port: u16,
pub listen_addr: Ipv4Addr,
pub listen_port: u16,
#[serde(deserialize_with = "deserialize_port_range")]
pub listen_port: Vec<u16>,
pub agent_port: u16,
pub timeouts: Timeouts,
pub buffer_size: usize,
Expand Down
45 changes: 30 additions & 15 deletions gateway/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::{
atomic::{AtomicU64, AtomicUsize, Ordering},
Arc,
},
task::Poll,
};

use anyhow::{bail, Context, Result};
Expand Down Expand Up @@ -173,21 +174,35 @@ pub async fn proxy_main(config: &ProxyConfig, proxy: Proxy) -> Result<()> {
let base_domain = base_domain.strip_prefix(".").unwrap_or(base_domain);
Arc::new(format!(".{base_domain}"))
};
let listener = TcpListener::bind((config.listen_addr, config.listen_port))
.await
.with_context(|| {
format!(
"failed to bind {}:{}",
config.listen_addr, config.listen_port
)
})?;
info!(
"tcp bridge listening on {}:{}",
config.listen_addr, config.listen_port
);
let mut tcp_listeners = Vec::new();
for &port in &config.listen_port {
let listener = TcpListener::bind((config.listen_addr, port))
.await
.with_context(|| format!("failed to bind {}:{}", config.listen_addr, port))?;
info!("tcp bridge listening on {}:{}", config.listen_addr, port);
tcp_listeners.push(listener);
}
if tcp_listeners.is_empty() {
bail!("no tcp listen ports configured");
}

let poll_counter = AtomicUsize::new(0);
loop {
match listener.accept().await {
// Accept from any TCP listener via round-robin poll.
let poll_start = poll_counter.fetch_add(1, Ordering::Relaxed);
let n = tcp_listeners.len();
let accepted: std::io::Result<(TcpStream, std::net::SocketAddr)> =
std::future::poll_fn(|cx| {
for j in 0..n {
let i = (poll_start + j) % n;
if let Poll::Ready(result) = tcp_listeners[i].poll_accept(cx) {
return Poll::Ready(result);
}
}
Poll::Pending
})
.await;
match accepted {
Ok((inbound, from)) => {
let span = info_span!("conn", id = next_connection_id());
let _enter = span.enter();
Expand All @@ -210,7 +225,7 @@ pub async fn proxy_main(config: &ProxyConfig, proxy: Proxy) -> Result<()> {
info!("connection closed");
}
Ok(Err(e)) => {
error!("connection error: {e:?}");
error!("connection error: {e:#}");
}
Err(_) => {
error!("connection kept too long, force closing");
Expand Down Expand Up @@ -245,7 +260,7 @@ pub fn start(config: ProxyConfig, app_state: Proxy) -> Result<()> {
// Run the proxy_main function in this runtime
if let Err(err) = rt.block_on(proxy_main(&config, app_state)) {
error!(
"error on {}:{}: {err:?}",
"error on {}:{:?}: {err:?}",
config.listen_addr, config.listen_port
);
}
Expand Down