diff --git a/.node-version b/.node-version
new file mode 100644
index 0000000..91d5f6f
--- /dev/null
+++ b/.node-version
@@ -0,0 +1 @@
+22.18.0
diff --git a/Gemfile.lock b/Gemfile.lock
index 76411c9..acb0dfd 100644
--- a/Gemfile.lock
+++ b/Gemfile.lock
@@ -7,8 +7,7 @@ GEM
   remote: https://rubygems.org/
   specs:
     concurrent-ruby (1.3.4)
-    google-protobuf (3.25.6)
-    google-protobuf (3.25.6-arm64-darwin)
+    google-protobuf (3.25.8)
     googleapis-common-protos-types (1.18.0)
       google-protobuf (>= 3.18, < 5.a)
     grpc (1.70.1)
diff --git a/bin/bench-server.rb b/bin/bench-server.rb
index 263dafd..4a5ba86 100755
--- a/bin/bench-server.rb
+++ b/bin/bench-server.rb
@@ -14,14 +14,15 @@
   bind_address: ENV.fetch("BIND_ADDRESS", "127.0.0.1:3000"),
   tokio_threads: ENV.fetch("TOKIO_THREADS", "1").to_i,
   debug: ENV.fetch("DEBUG", "false") == "true",
-  recv_timeout: ENV.fetch("RECV_TIMEOUT", "30000").to_i
+  recv_timeout: ENV.fetch("RECV_TIMEOUT", "30000").to_i,
+  max_connection_age: ENV.fetch("MAX_CONNECTION_AGE", "30000").to_i
 }
 
 server.configure(config)
 puts "Starting server with config: #{config}"
 
 accept_response = HyperRuby::Response.new(
-  200,
+  202,
   { "Content-Type" => "application/json" },
   { "message" => "Accepted" }.to_json
 )
diff --git a/ext/hyper_ruby/src/lib.rs b/ext/hyper_ruby/src/lib.rs
index e71f070..54f785c 100644
--- a/ext/hyper_ruby/src/lib.rs
+++ b/ext/hyper_ruby/src/lib.rs
@@ -16,6 +16,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
 
 use std::cell::RefCell;
 use std::net::SocketAddr;
+use std::sync::atomic::{AtomicU64, Ordering};
 
 use tokio::net::{TcpListener, UnixListener};
@@ -80,6 +81,7 @@ struct ServerConfig {
     recv_timeout: u64,
     channel_capacity: usize,
     send_timeout: u64,
+    max_connection_age: Option<u64>,
 }
 
 impl ServerConfig {
@@ -91,6 +93,7 @@ impl ServerConfig {
             recv_timeout: 30000, // Default 30 second timeout
             channel_capacity: 5000, // Default capacity for worker channel
             send_timeout: 1000, // Default 1 second timeout for send backpressure
+            max_connection_age: None, // No limit by default
         }
     }
 }
@@ -109,6 +112,7 @@ struct Server {
     work_tx: RefCell>>>,
     runtime: RefCell>>,
     shutdown: RefCell>>,
+    total_connections: Arc<AtomicU64>,
 }
@@ -121,9 +125,14 @@ impl Server {
             work_tx: RefCell::new(None),
             runtime: RefCell::new(None),
             shutdown: RefCell::new(None),
+            total_connections: Arc::new(AtomicU64::new(0)),
         }
     }
 
+    pub fn total_connections(&self) -> u64 {
+        self.total_connections.load(Ordering::Relaxed)
+    }
+
     pub fn configure(&self, config: magnus::RHash) -> Result<(), MagnusError> {
         let mut server_config = self.config.borrow_mut();
         if let Some(bind_address) = config.get(magnus::Symbol::new("bind_address")) {
@@ -150,6 +159,10 @@ impl Server {
             server_config.send_timeout = u64::try_convert(send_timeout)?;
         }
 
+        if let Some(max_connection_age) = config.get(magnus::Symbol::new("max_connection_age")) {
+            server_config.max_connection_age = Some(u64::try_convert(max_connection_age)?);
+        }
+
         // Initialize logging if not already initialized
         LOGGER_INIT.call_once(|| {
             let mut builder = env_logger::Builder::from_env(env_logger::Env::default());
@@ -263,7 +276,9 @@ impl Server {
         *self.work_tx.borrow_mut() = Some(work_tx.clone());
 
         let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
-        *self.shutdown.borrow_mut() = Some(shutdown_tx);
+        *self.shutdown.borrow_mut() = Some(shutdown_tx.clone());
+
+        let total_connections = self.total_connections.clone();
 
         let mut rt_builder = tokio::runtime::Builder::new_multi_thread();
@@ -346,17 +361,19 @@ impl Server {
             };
 
             // Now that we have successfully bound, spawn the server task
+            let max_connection_age = config.max_connection_age;
             let server_task = tokio::spawn(async move {
                 let graceful_shutdown = GracefulShutdown::new();
                 let mut shutdown_rx = shutdown_rx;
 
                 loop {
                     tokio::select! {
-                        Ok((stream, _)) = listener.accept() => {
+                        Ok((stream, _)) = listener.accept() => {
+                            total_connections.fetch_add(1, Ordering::Relaxed);
                             info!("New connection established");
-
+
                             let io = TokioIo::new(stream);
-
+
                             debug!("Setting up connection");
 
                             let builder = builder.clone();
@@ -365,13 +382,50 @@ impl Server {
                                 debug!("Service handling request");
                                 handle_request(req, work_tx.clone(), config.recv_timeout, config.send_timeout)
                             }));
-                            let fut = graceful_shutdown.watch(conn.into_owned());
-                            tokio::task::spawn(async move {
-                                if let Err(err) = fut.await {
-                                    warn!("Error serving connection: {:?}", err);
-                                }
-                            });
-                        },
+                            // If max_connection_age is set, handle the connection with a timeout
+                            // but still integrate with server-wide graceful shutdown via broadcast channel
+                            if let Some(max_age_ms) = max_connection_age {
+                                let conn = conn.into_owned();
+                                let mut conn_shutdown_rx = shutdown_tx.subscribe();
+                                tokio::task::spawn(async move {
+                                    tokio::pin!(conn);
+                                    let sleep = tokio::time::sleep(std::time::Duration::from_millis(max_age_ms));
+                                    tokio::pin!(sleep);
+                                    let mut graceful_shutdown_started = false;
+
+                                    loop {
+                                        tokio::select! {
+                                            result = conn.as_mut() => {
+                                                if let Err(err) = result {
+                                                    warn!("Error serving connection: {:?}", err);
+                                                }
+                                                break;
+                                            }
+                                            _ = &mut sleep, if !graceful_shutdown_started => {
+                                                debug!("Connection reached max age ({}ms), sending GOAWAY", max_age_ms);
+                                                conn.as_mut().graceful_shutdown();
+                                                graceful_shutdown_started = true;
+                                                // Continue the loop to let the connection drain
+                                            }
+                                            _ = conn_shutdown_rx.recv(), if !graceful_shutdown_started => {
+                                                debug!("Server shutdown requested, sending GOAWAY to connection");
+                                                conn.as_mut().graceful_shutdown();
+                                                graceful_shutdown_started = true;
+                                                // Continue the loop to let the connection drain
+                                            }
+                                        }
+                                    }
+                                });
+                            } else {
+                                // No max age, use the graceful shutdown watcher
+                                let fut = graceful_shutdown.watch(conn.into_owned());
+                                tokio::task::spawn(async move {
+                                    if let Err(err) = fut.await {
+                                        warn!("Error serving connection: {:?}", err);
+                                    }
+                                });
+                            }
+                        },
                         _ = shutdown_rx.recv() => {
                             debug!("Graceful shutdown requested; shutting down");
                             break;
@@ -589,6 +643,7 @@ fn init(ruby: &Ruby) -> Result<(), MagnusError> {
     server_class.define_method("start", method!(Server::start, 0))?;
     server_class.define_method("stop", method!(Server::stop, 0))?;
     server_class.define_method("run_worker", method!(Server::run_worker, 0))?;
+    server_class.define_method("total_connections", method!(Server::total_connections, 0))?;
 
     let response_class = module.define_class("Response", ruby.class_object())?;
     response_class.define_singleton_method("new", function!(Response::new, 3))?;
diff --git a/test/test_grpc.rb b/test/test_grpc.rb
index 2fe5530..62e2456 100644
--- a/test/test_grpc.rb
+++ b/test/test_grpc.rb
@@ -144,15 +144,101 @@ def test_grpc_compression
           'grpc.enable_http_proxy' => 0,
         }.merge(compression_channel_args)
       )
-
+
       request = Echo::EchoRequest.new(message: "Hello Compressed GRPC " + ("a" * 10000))
       response = stub.echo(request)
-
+
       assert_instance_of Echo::EchoResponse, response
       assert_equal "Decompressed: Hello Compressed GRPC " + ("a" * 10000), response.message
     end
   end
 
+  def test_max_connection_age_sends_goaway
+    # Test that max_connection_age causes the server to send GOAWAY after the configured time,
+    # forcing the client to establish a new connection
+    buffer = String.new(capacity: 1024)
+    server_config = {
+      bind_address: "127.0.0.1:3010",
+      tokio_threads: 1,
+      max_connection_age: 500 # 500ms max connection age
+    }
+
+    with_configured_server(server_config, -> (request) { handler_grpc(request, buffer) }) do |_client, server|
+      stub = Echo::Echo::Stub.new(
+        "127.0.0.1:3010",
+        :this_channel_is_insecure,
+        channel_args: {
+          'grpc.enable_http_proxy' => 0
+        }
+      )
+
+      # Record initial connection count
+      initial_connections = server.total_connections
+
+      # First request establishes a connection
+      request = Echo::EchoRequest.new(message: "Request 1")
+      response = stub.echo(request)
+      assert_equal "Request 1 response", response.message
+
+      # Should have one connection now
+      assert_equal initial_connections + 1, server.total_connections, "First request should establish one connection"
+
+      # Wait for max_connection_age to expire and GOAWAY to be sent
+      sleep 0.7
+
+      # Make another request - gRPC client should establish a new connection after GOAWAY
+      request = Echo::EchoRequest.new(message: "Request 2")
+      response = stub.echo(request)
+      assert_equal "Request 2 response", response.message
+
+      # Should have a second connection now (client reconnected after GOAWAY)
+      assert_equal initial_connections + 2, server.total_connections, "Second request after GOAWAY should establish a new connection"
+    end
+  end
+
+  def test_long_max_connection_age_reuses_connection
+    # Test that with a long max_connection_age, the connection is reused
+    # (opposite of test_max_connection_age_sends_goaway)
+    buffer = String.new(capacity: 1024)
+    server_config = {
+      bind_address: "127.0.0.1:3010",
+      tokio_threads: 1,
+      max_connection_age: 60000 # 60 seconds - much longer than test duration
+    }
+
+    with_configured_server(server_config, -> (request) { handler_grpc(request, buffer) }) do |_client, server|
+      stub = Echo::Echo::Stub.new(
+        "127.0.0.1:3010",
+        :this_channel_is_insecure,
+        channel_args: {
+          'grpc.enable_http_proxy' => 0
+        }
+      )
+
+      # Record initial connection count
+      initial_connections = server.total_connections
+
+      # First request establishes a connection
+      request = Echo::EchoRequest.new(message: "Request 1")
+      response = stub.echo(request)
+      assert_equal "Request 1 response", response.message
+
+      # Should have one connection now
+      assert_equal initial_connections + 1, server.total_connections, "First request should establish one connection"
+
+      # Wait a bit (but less than max_connection_age)
+      sleep 0.3
+
+      # Make another request - should reuse the same connection
+      request = Echo::EchoRequest.new(message: "Request 2")
+      response = stub.echo(request)
+      assert_equal "Request 2 response", response.message
+
+      # Should still have only one connection (connection reused, no GOAWAY sent)
+      assert_equal initial_connections + 1, server.total_connections, "Second request should reuse existing connection"
+    end
+  end
+
   private
 
   def handler_grpc(request, buffer)
diff --git a/test/test_helper.rb b/test/test_helper.rb
index c290844..62ed8df 100644
--- a/test/test_helper.rb
+++ b/test/test_helper.rb
@@ -31,7 +31,7 @@ def with_configured_server(config, request_handler, &block)
   end
 
   client = HTTPX.with(origin: "http://127.0.0.1:3010")
 
-  block.call(client)
+  block.call(client, server)
 ensure
   server.stop if server
diff --git a/test/test_http.rb b/test/test_http.rb
index 8617a5d..173a837 100644
--- a/test/test_http.rb
+++ b/test/test_http.rb
@@ -270,6 +270,117 @@ def test_http2_host
     end
   end
 
+  def test_chunked_transfer_encoding
+    buffer = String.new(capacity: 65536)
+    with_server(-> (request) { handler_to_json(request, buffer) }) do |_client|
+      require 'socket'
+
+      # Create a large body (64KB) to ensure chunked encoding is meaningful
+      chunk1 = "A" * 16384 # 16KB
+      chunk2 = "B" * 16384 # 16KB
+      chunk3 = "C" * 16384 # 16KB
+      chunk4 = "D" * 16384 # 16KB
+      expected_body = chunk1 + chunk2 + chunk3 + chunk4
+
+      # Build a raw HTTP/1.1 request with Transfer-Encoding: chunked
+      socket = TCPSocket.new("127.0.0.1", 3010)
+
+      # Send headers
+      socket.write("POST / HTTP/1.1\r\n")
+      socket.write("Host: 127.0.0.1:3010\r\n")
+      socket.write("Transfer-Encoding: chunked\r\n")
+      socket.write("Content-Type: text/plain\r\n")
+      socket.write("Connection: close\r\n")
+      socket.write("\r\n")
+
+      # Send body in chunks (chunked encoding format: size in hex, CRLF, data, CRLF)
+      [chunk1, chunk2, chunk3, chunk4].each do |chunk|
+        socket.write("#{chunk.bytesize.to_s(16)}\r\n")
+        socket.write(chunk)
+        socket.write("\r\n")
+      end
+
+      # Send final zero-length chunk to signal end
+      socket.write("0\r\n")
+      socket.write("\r\n")
+
+      # Read response
+      response = socket.read
+      socket.close
+
+      # Parse response - find the JSON body after headers
+      headers_end = response.index("\r\n\r\n")
+      assert headers_end, "Response should have headers"
+
+      status_line = response.lines.first
+      assert_match(/HTTP\/1\.1 200/, status_line, "Should receive 200 OK")
+
+      # Extract body after headers
+      body_part = response[(headers_end + 4)..]
+
+      # Parse the JSON response and verify the body was correctly received
+      json_response = JSON.parse(body_part)
+      received_body = json_response["message"]
+
+      assert_equal expected_body.bytesize, received_body.bytesize,
+        "Body size should match (expected #{expected_body.bytesize}, got #{received_body.bytesize})"
+      assert_equal expected_body, received_body, "Body content should match"
+    end
+  end
+
+  def test_chunked_transfer_encoding_aborted_connection
+    # Test that when a client aborts (RST) mid-chunked-transfer,
+    # the partial body is NOT exposed to the handler
+    handler_called = false
+
+    server_config = {
+      bind_address: "127.0.0.1:3010",
+      tokio_threads: 1,
+      recv_timeout: 1000 # 1 second timeout
+    }
+
+    with_configured_server(server_config, -> (request) {
+      handler_called = true
+      buffer = String.new(capacity: 65536)
+      request.fill_body(buffer)
+      HyperRuby::Response.new(200, { 'Content-Type' => 'text/plain' }, buffer)
+    }) do |_client|
+      chunk1 = "A" * 16384 # 16KB
+      chunk2 = "B" * 16384 # 16KB
+
+      # Create socket and set SO_LINGER to 0 to send RST on close (abort)
+      socket = Socket.new(Socket::AF_INET, Socket::SOCK_STREAM)
+      socket.setsockopt(Socket::SOL_SOCKET, Socket::SO_LINGER, [1, 0].pack("ii"))
+      socket.connect(Socket.sockaddr_in(3010, "127.0.0.1"))
+
+      # Send headers
+      socket.write("POST / HTTP/1.1\r\n")
+      socket.write("Host: 127.0.0.1:3010\r\n")
+      socket.write("Transfer-Encoding: chunked\r\n")
+      socket.write("Content-Type: text/plain\r\n")
+      socket.write("Connection: close\r\n")
+      socket.write("\r\n")
+
+      # Send first chunk successfully
+      socket.write("#{chunk1.bytesize.to_s(16)}\r\n")
+      socket.write(chunk1)
+      socket.write("\r\n")
+
+      # Send partial second chunk
+      socket.write("#{chunk2.bytesize.to_s(16)}\r\n")
+      socket.write(chunk2[0, 8192]) # Only half of chunk2
+
+      # Abort the connection with RST (no graceful close)
+      socket.close
+
+      # Wait for the server to process and timeout
+      sleep 1.5
+
+      # The handler should NOT have been called since the chunked body was incomplete
+      refute handler_called, "Handler should not be called when chunked transfer is aborted mid-stream"
+    end
+  end
+
   private
 
   def handler_simple(request)