From 77b9342dd6c882eecd857af39469fdc29772aee3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Tue, 21 Apr 2026 14:26:21 +0200 Subject: [PATCH] perf(spanner): add gRPC channel pooling Applications that use Spanner for high throughput often need more than one gRPC channel in order to achieve the best possible performance. Having more than one channel both increases the number of concurrent streams that an application can have open at any time (the limit is 100 per gRPC channel), and distributes the load that is generated by the application across multiple Spanner frontends. --- .../src/batch_read_only_transaction.rs | 20 +- src/spanner/src/client.rs | 229 +++++++++++++++--- .../src/partitioned_dml_transaction.rs | 16 +- src/spanner/src/read_only_transaction.rs | 34 ++- src/spanner/src/read_write_transaction.rs | 108 ++++++--- src/spanner/src/result_set.rs | 17 +- src/spanner/src/session_maintainer.rs | 4 +- src/spanner/src/write_only_transaction.rs | 14 +- 8 files changed, 352 insertions(+), 90 deletions(-) diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index 5fdb38a221..a979268e5d 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -158,7 +158,11 @@ impl BatchReadOnlyTransaction { .context .client .spanner - .partition_query(request, crate::RequestOptions::default()) + .partition_query( + request, + crate::RequestOptions::default(), + self.inner.context.channel_hint, + ) .await?; Ok(response @@ -215,7 +219,11 @@ impl BatchReadOnlyTransaction { .context .client .spanner - .partition_read(request, crate::RequestOptions::default()) + .partition_read( + request, + crate::RequestOptions::default(), + self.inner.context.channel_hint, + ) .await?; Ok(response @@ -303,9 +311,10 @@ impl Partition { client: &DatabaseClient, req: &crate::model::ExecuteSqlRequest, ) -> crate::Result { + let channel_hint = client.spanner.next_channel_hint(); let stream = client .spanner - .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .execute_streaming_sql(req.clone(), crate::RequestOptions::default(), channel_hint) .send() .await?; @@ -319,6 +328,7 @@ impl Partition { client.clone(), req.session.clone(), StreamOperation::Query(req.clone()), + channel_hint, )) } @@ -326,9 +336,10 @@ impl Partition { client: &DatabaseClient, req: &crate::model::ReadRequest, ) -> crate::Result { + let channel_hint = client.spanner.next_channel_hint(); let stream = client .spanner - .streaming_read(req.clone(), crate::RequestOptions::default()) + .streaming_read(req.clone(), crate::RequestOptions::default(), channel_hint) .send() .await?; @@ -342,6 +353,7 @@ impl Partition { client.clone(), req.session.clone(), StreamOperation::Read(req.clone()), + channel_hint, )) } } diff --git a/src/spanner/src/client.rs b/src/spanner/src/client.rs index f4ff2d4f82..9dc620d738 100644 --- a/src/spanner/src/client.rs +++ b/src/spanner/src/client.rs @@ -15,6 +15,7 @@ use crate::generated::gapic_dataplane::client::Spanner as GapicSpanner; use crate::server_streaming::builder; use gaxi::options::{ClientConfig, Credentials}; +use std::sync::atomic::{AtomicUsize, Ordering}; pub use crate::database_client::DatabaseClient; pub use crate::error::SpannerInternalError; @@ -50,8 +51,8 @@ pub use wkt::{DurationError, TimestampError}; /// [Spanner]: https://docs.cloud.google.com/spanner/docs #[derive(Clone, Debug)] pub struct Spanner { - inner: GapicSpanner, - grpc_client: Option, + pub(crate) channels: Vec, + pub(crate) counter: std::sync::Arc, } pub struct Factory; @@ -61,20 +62,19 @@ impl google_cloud_gax::client_builder::internal::ClientFactory for Factory { type Credentials = Credentials; async fn build(self, config: ClientConfig) -> crate::ClientBuilderResult { - let transport = - crate::generated::gapic_dataplane::transport::Spanner::new(config.clone()).await?; - let grpc_client = transport.inner.clone(); + let num_channels = std::env::var("SPANNER_NUM_CHANNELS") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(4); + + let mut channels = Vec::with_capacity(num_channels); + for _ in 0..num_channels { + channels.push(Channel::create(&config).await?); + } - let inner = if gaxi::options::tracing_enabled(&config) { - GapicSpanner::from_stub(crate::generated::gapic_dataplane::tracing::Spanner::new( - transport, - )) - } else { - GapicSpanner::from_stub(transport) - }; Ok(Spanner { - inner, - grpc_client: Some(grpc_client), + channels, + counter: std::sync::Arc::new(AtomicUsize::new(0)), }) } } @@ -146,17 +146,31 @@ impl Spanner { // This method is primarily for testing and doesn't fully initialize grpc_client. // For production use, prefer `Spanner::builder().build()`. Self { - inner: GapicSpanner::from_stub(stub), - grpc_client: None, + channels: vec![Channel { + inner: GapicSpanner::from_stub(stub), + grpc_client: None, + }], + counter: std::sync::Arc::new(AtomicUsize::new(0)), } } + pub(crate) fn get_channel(&self, hint: usize) -> &Channel { + let idx = hint % self.channels.len(); + &self.channels[idx] + } + + pub(crate) fn next_channel_hint(&self) -> usize { + self.counter.fetch_add(1, Ordering::Relaxed) + } + pub(crate) async fn create_session( &self, request: crate::model::CreateSessionRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .create_session() .with_request(request) .with_options(options) @@ -168,8 +182,10 @@ impl Spanner { &self, request: crate::model::ExecuteSqlRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .execute_sql() .with_request(request) .with_options(options) @@ -181,8 +197,10 @@ impl Spanner { &self, request: crate::model::ExecuteBatchDmlRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .execute_batch_dml() .with_request(request) .with_options(options) @@ -194,8 +212,10 @@ impl Spanner { &self, request: crate::model::ReadRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .read() .with_request(request) .with_options(options) @@ -207,8 +227,10 @@ impl Spanner { &self, request: crate::model::BeginTransactionRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .begin_transaction() .with_request(request) .with_options(options) @@ -220,8 +242,10 @@ impl Spanner { &self, request: crate::model::CommitRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .commit() .with_request(request) .with_options(options) @@ -233,8 +257,10 @@ impl Spanner { &self, request: crate::model::RollbackRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result<()> { - self.inner + self.get_channel(channel_hint) + .inner .rollback() .with_request(request) .with_options(options) @@ -246,8 +272,10 @@ impl Spanner { &self, request: crate::model::PartitionQueryRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .partition_query() .with_request(request) .with_options(options) @@ -259,8 +287,10 @@ impl Spanner { &self, request: crate::model::PartitionReadRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .partition_read() .with_request(request) .with_options(options) @@ -276,8 +306,10 @@ impl Spanner { &self, request: crate::model::ExecuteSqlRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> builder::ExecuteStreamingSql { - let grpc = self + let channel = self.get_channel(channel_hint); + let grpc = channel .grpc_client .as_ref() .expect("Streaming RPCs are not supported when using a stub client"); @@ -294,8 +326,10 @@ impl Spanner { &self, request: crate::model::ReadRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> builder::StreamingRead { - let grpc = self + let channel = self.get_channel(channel_hint); + let grpc = channel .grpc_client .as_ref() .expect("Streaming RPCs are not supported when using a stub client"); @@ -308,8 +342,10 @@ impl Spanner { &self, request: crate::model::BatchWriteRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> builder::BatchWrite { - let grpc = self + let channel = self.get_channel(channel_hint); + let grpc = channel .grpc_client .as_ref() .expect("Streaming RPCs are not supported when using a stub client"); @@ -319,6 +355,32 @@ impl Spanner { } } +#[derive(Clone, Debug)] +pub(crate) struct Channel { + pub(crate) inner: GapicSpanner, + pub(crate) grpc_client: Option, +} + +impl Channel { + pub(crate) async fn create(config: &ClientConfig) -> crate::ClientBuilderResult { + let transport = + crate::generated::gapic_dataplane::transport::Spanner::new(config.clone()).await?; + let grpc_client = transport.inner.clone(); + + let inner = if gaxi::options::tracing_enabled(config) { + GapicSpanner::from_stub(crate::generated::gapic_dataplane::tracing::Spanner::new( + transport, + )) + } else { + GapicSpanner::from_stub(transport) + }; + Ok(Self { + inner, + grpc_client: Some(grpc_client), + }) + } +} + #[cfg(test)] mod tests { use super::*; @@ -336,6 +398,50 @@ mod tests { assert_not_impl_any!(Spanner: std::panic::RefUnwindSafe, std::panic::UnwindSafe); } + #[tokio::test] + async fn channel_pool_default_size() { + let mock = MockSpanner::new(); + let (address, _server) = start("0.0.0.0:0", mock) + .await + .expect("Failed to start mock server"); + + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await + .expect("Failed to build client"); + + assert_eq!(client.channels.len(), 4); + } + + #[tokio::test] + async fn channel_selection() { + let mock = MockSpanner::new(); + let (address, _server) = start("0.0.0.0:0", mock) + .await + .expect("Failed to start mock server"); + + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await + .expect("Failed to build client"); + + let hint0 = client.next_channel_hint(); + let hint1 = client.next_channel_hint(); + let hint2 = client.next_channel_hint(); + let hint3 = client.next_channel_hint(); + let hint4 = client.next_channel_hint(); + + assert_eq!(hint0 % 4, 0); + assert_eq!(hint1 % 4, 1); + assert_eq!(hint2 % 4, 2); + assert_eq!(hint3 % 4, 3); + assert_eq!(hint4 % 4, 0); + } + #[tokio::test] async fn test_create_session() { // 1. Setup Mock Server @@ -368,7 +474,11 @@ mod tests { "projects/test-project/instances/test-instance/databases/test-db".to_string(); let session = client - .create_session(req, crate::RequestOptions::default()) + .create_session( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call create_session"); @@ -422,6 +532,7 @@ mod tests { "projects/test-project/instances/test-instance/databases/test-db".to_string(); let session = client + .get_channel(client.next_channel_hint()) .inner .create_session() .with_request(req) @@ -471,7 +582,11 @@ mod tests { req.sql = "SELECT 1".to_string(); let result_set = client - .execute_sql(req, crate::RequestOptions::default()) + .execute_sql( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call execute_sql"); assert!(result_set.metadata.is_some()); @@ -510,7 +625,11 @@ mod tests { req.session = "test_session".to_string(); let response = client - .execute_batch_dml(req, crate::RequestOptions::default()) + .execute_batch_dml( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call execute_batch_dml"); assert!(response.status.is_some()); @@ -545,7 +664,11 @@ mod tests { req.table = "test_table".to_string(); let result_set = client - .read(req, crate::RequestOptions::default()) + .read( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call read"); assert!(result_set.metadata.is_none()); @@ -579,7 +702,11 @@ mod tests { req.session = "test_session".to_string(); let tx = client - .begin_transaction(req, crate::RequestOptions::default()) + .begin_transaction( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call begin_transaction"); assert_eq!(tx.id, vec![1, 2, 3]); @@ -617,7 +744,11 @@ mod tests { req.session = "test_session".to_string(); let response = client - .commit(req, crate::RequestOptions::default()) + .commit( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call commit"); assert!(response.commit_timestamp.is_some()); @@ -646,7 +777,11 @@ mod tests { req.session = "test_session".to_string(); client - .rollback(req, crate::RequestOptions::default()) + .rollback( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call rollback"); } @@ -688,7 +823,11 @@ mod tests { req.sql = "SELECT 1".to_string(); let mut stream = client - .execute_streaming_sql(req, crate::RequestOptions::default()) + .execute_streaming_sql( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .send() .await .expect("Failed to call execute_streaming_sql"); @@ -736,7 +875,11 @@ mod tests { req.columns = vec!["col1".to_string()]; let mut stream = client - .streaming_read(req, crate::RequestOptions::default()) + .streaming_read( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .send() .await .expect("Failed to call streaming_read"); @@ -774,7 +917,11 @@ mod tests { req.session = "test_session".to_string(); let mut stream = client - .batch_write(req, crate::RequestOptions::default()) + .batch_write( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .send() .await .expect("Failed to call batch_write"); @@ -810,7 +957,11 @@ mod tests { req.sql = "SELECT 1".to_string(); let mut stream = client - .execute_streaming_sql(req, crate::RequestOptions::default()) + .execute_streaming_sql( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .send() .await .expect("Failed to call execute_streaming_sql"); diff --git a/src/spanner/src/partitioned_dml_transaction.rs b/src/spanner/src/partitioned_dml_transaction.rs index d1b4397821..f6944cc6a2 100644 --- a/src/spanner/src/partitioned_dml_transaction.rs +++ b/src/spanner/src/partitioned_dml_transaction.rs @@ -140,13 +140,18 @@ impl PartitionedDmlTransaction { ..Default::default() }; let base_request = statement.into_request(); + let channel_hint = self.client.spanner.next_channel_hint(); // Execute the statement and retry if the transaction is aborted by Spanner. retry_aborted(&*self.retry_policy, || async { let transaction = self .client .spanner - .begin_transaction(begin_request.clone(), crate::RequestOptions::default()) + .begin_transaction( + begin_request.clone(), + crate::RequestOptions::default(), + channel_hint, + ) .await?; let execute_request = base_request @@ -157,10 +162,11 @@ impl PartitionedDmlTransaction { ..Default::default() }); - let stream_builder = self - .client - .spanner - .execute_streaming_sql(execute_request.clone(), crate::RequestOptions::default()); + let stream_builder = self.client.spanner.execute_streaming_sql( + execute_request.clone(), + crate::RequestOptions::default(), + channel_hint, + ); let stream = stream_builder.send().await?; extract_lower_bound_update_count_from_stream(stream).await diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index b4e0ce8c58..5c004124ac 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -90,6 +90,7 @@ impl SingleUseReadOnlyTransactionBuilder { .set_single_use(TransactionOptions::default().set_read_only(read_only)); let session_name = self.client.session_name(); + let channel_hint = self.client.spanner.next_channel_hint(); SingleUseReadOnlyTransaction { context: ReadContext { session_name, @@ -100,6 +101,7 @@ impl SingleUseReadOnlyTransactionBuilder { ), precommit_token_tracker: PrecommitTokenTracker::new_noop(), transaction_tag: None, + channel_hint, }, } } @@ -277,8 +279,10 @@ impl MultiUseReadOnlyTransactionBuilder { &self, session_name: String, options: TransactionOptions, + channel_hint: usize, ) -> crate::Result { - let response = execute_begin_transaction(&self.client, session_name, options).await?; + let response = + execute_begin_transaction(&self.client, session_name, options, channel_hint).await?; let transaction_selector = crate::model::TransactionSelector::default().set_id(response.id); @@ -309,8 +313,10 @@ impl MultiUseReadOnlyTransactionBuilder { let options = TransactionOptions::default().set_read_only(read_only); let session_name = self.client.session_name(); + let channel_hint = self.client.spanner.next_channel_hint(); let selector = if self.explicit_begin { - self.begin(session_name.clone(), options).await? + self.begin(session_name.clone(), options, channel_hint) + .await? } else { ReadContextTransactionSelector::Lazy(Arc::new(Mutex::new( TransactionState::NotStarted(options), @@ -324,6 +330,7 @@ impl MultiUseReadOnlyTransactionBuilder { transaction_selector: selector, precommit_token_tracker: PrecommitTokenTracker::new_noop(), transaction_tag: None, + channel_hint, }, }) } @@ -430,6 +437,7 @@ async fn execute_begin_transaction( client: &crate::database_client::DatabaseClient, session_name: String, options: crate::model::TransactionOptions, + channel_hint: usize, ) -> crate::Result { let request = crate::model::BeginTransactionRequest::default() .set_session(session_name) @@ -438,7 +446,7 @@ async fn execute_begin_transaction( // TODO(#4972): make request options configurable client .spanner - .begin_transaction(request, crate::RequestOptions::default()) + .begin_transaction(request, crate::RequestOptions::default(), channel_hint) .await } @@ -484,6 +492,7 @@ impl ReadContextTransactionSelector { &self, client: &crate::database_client::DatabaseClient, session_name: String, + channel_hint: usize, ) -> crate::Result<()> { let Self::Lazy(lazy) = self else { return Ok(()); @@ -497,7 +506,8 @@ impl ReadContextTransactionSelector { options.clone() }; - let response = execute_begin_transaction(client, session_name, options).await?; + let response = + execute_begin_transaction(client, session_name, options, channel_hint).await?; self.update(response.id, response.read_timestamp); Ok(()) @@ -537,6 +547,7 @@ pub(crate) struct ReadContext { pub(crate) transaction_selector: ReadContextTransactionSelector, pub(crate) precommit_token_tracker: PrecommitTokenTracker, pub(crate) transaction_tag: Option, + pub(crate) channel_hint: usize, } impl ReadContext { @@ -571,7 +582,7 @@ impl ReadContext { } self.transaction_selector - .begin_explicitly(&self.client, self.session_name.clone()) + .begin_explicitly(&self.client, self.session_name.clone(), self.channel_hint) .await?; Ok(true) } @@ -584,7 +595,11 @@ macro_rules! execute_stream_with_retry { .client .spanner // TODO(#4972): make request options configurable - .$rpc_method($request.clone(), crate::RequestOptions::default()) + .$rpc_method( + $request.clone(), + crate::RequestOptions::default(), + $self.channel_hint, + ) .send() .await { @@ -596,7 +611,11 @@ macro_rules! execute_stream_with_retry { .client .spanner // TODO(#4972): make request options configurable - .$rpc_method($request.clone(), crate::RequestOptions::default()) + .$rpc_method( + $request.clone(), + crate::RequestOptions::default(), + $self.channel_hint, + ) .send() .await? } else { @@ -612,6 +631,7 @@ macro_rules! execute_stream_with_retry { $self.client.clone(), $self.session_name.clone(), $operation_variant($request), + $self.channel_hint, )) }}; } diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index 74c3aa1ec4..ef28895f91 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -98,10 +98,11 @@ impl ReadWriteTransactionBuilder { } // TODO(#4972): make request options configurable + let channel_hint = self.client.spanner.next_channel_hint(); let response = self .client .spanner - .begin_transaction(request, RequestOptions::default()) + .begin_transaction(request, RequestOptions::default(), channel_hint) .await?; let transaction_selector = @@ -116,6 +117,7 @@ impl ReadWriteTransactionBuilder { transaction_selector, precommit_token_tracker: PrecommitTokenTracker::new(), transaction_tag: self.transaction_tag.clone(), + channel_hint, }, seqno: Arc::new(AtomicI64::new(1)), }) @@ -161,7 +163,11 @@ impl ReadWriteTransaction { .context .client .spanner - .execute_sql(request, RequestOptions::default()) + .execute_sql( + request, + RequestOptions::default(), + self.context.channel_hint, + ) .await?; self.context .precommit_token_tracker @@ -265,7 +271,11 @@ impl ReadWriteTransaction { .context .client .spanner - .execute_batch_dml(request, RequestOptions::default()) + .execute_batch_dml( + request, + RequestOptions::default(), + self.context.channel_hint, + ) .await; match response_result { @@ -300,7 +310,11 @@ impl ReadWriteTransaction { .context .client .spanner - .commit(request, RequestOptions::default()) + .commit( + request, + RequestOptions::default(), + self.context.channel_hint, + ) .await?; let response = @@ -314,7 +328,11 @@ impl ReadWriteTransaction { self.context .client .spanner - .commit(retry_commit_req, RequestOptions::default()) + .commit( + retry_commit_req, + RequestOptions::default(), + self.context.channel_hint, + ) .await? } else { response @@ -337,7 +355,11 @@ impl ReadWriteTransaction { self.context .client .spanner - .rollback(request, RequestOptions::default()) + .rollback( + request, + RequestOptions::default(), + self.context.channel_hint, + ) .await?; Ok(()) @@ -352,6 +374,7 @@ mod tests { use gaxi::grpc::tonic; use spanner_grpc_mock::google::spanner::v1; use std::fmt::Debug; + use std::sync::Mutex; #[test] fn auto_traits() { @@ -360,23 +383,36 @@ mod tests { } #[tokio::test] - async fn read_write_transaction_commit_retry() { + async fn read_write_transaction_commit_retry() -> anyhow::Result<()> { let mut mock = create_session_mock(); - - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![0, 0, 7], - ..Default::default() - })) - }); + let remotes = Arc::new(Mutex::new(Vec::new())); + + let remotes_clone = remotes.clone(); + mock.expect_begin_transaction() + .once() + .returning(move |req| { + remotes_clone + .lock() + .unwrap() + .push(req.remote_addr().expect("remote_addr should be available")); + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + Ok(tonic::Response::new(v1::Transaction { + id: vec![0, 0, 7], + ..Default::default() + })) + }); // execute_update returns a precommit token. - mock.expect_execute_sql().once().returning(|req| { + let remotes_clone = remotes.clone(); + mock.expect_execute_sql().once().returning(move |req| { + remotes_clone + .lock() + .unwrap() + .push(req.remote_addr().expect("remote_addr should be available")); let req = req.into_inner(); assert_eq!(req.sql, "UPDATE Users SET Name = 'Bob' WHERE Id = 1"); Ok(tonic::Response::new(v1::ResultSet { @@ -395,7 +431,12 @@ mod tests { // Simulate that commit returns a precommit token in the response. // This would normally not happen, but we test it here to verify // that the commit is retried. - mock.expect_commit().once().returning(|req| { + let remotes_clone = remotes.clone(); + mock.expect_commit().once().returning(move |req| { + remotes_clone + .lock() + .unwrap() + .push(req.remote_addr().expect("remote_addr should be available")); let req = req.into_inner(); assert_eq!( req.precommit_token, @@ -422,7 +463,12 @@ mod tests { }); // Second commit retry is automatically issued with the new token - mock.expect_commit().once().returning(|req| { + let remotes_clone = remotes.clone(); + mock.expect_commit().once().returning(move |req| { + remotes_clone + .lock() + .unwrap() + .push(req.remote_addr().expect("remote_addr should be available")); let req = req.into_inner(); assert_eq!( req.precommit_token, @@ -444,17 +490,25 @@ mod tests { let tx = ReadWriteTransactionBuilder::new(db_client.clone()) .begin_transaction() - .await - .expect("Failed to build transaction"); + .await?; let count = tx .execute_update("UPDATE Users SET Name = 'Bob' WHERE Id = 1") - .await - .unwrap(); + .await?; assert_eq!(count, 1); - let timestamp = tx.commit().await.unwrap(); + let timestamp = tx.commit().await?; assert_eq!(timestamp.seconds(), 1001); + + // Verify that all RPCs used the same channel (same remote address) + let remotes = remotes.lock().unwrap(); + assert_eq!(remotes.len(), 4, "Expected exactly 4 RPCs"); + let first = remotes[0]; + for addr in remotes.iter() { + assert_eq!(*addr, first, "All RPCs should use the same gRPC channel"); + } + + Ok(()) } #[tokio::test] diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index e8da3f5ea0..efdfeee3cf 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -61,6 +61,7 @@ pub struct ResultSet { max_buffered_partial_result_sets: usize, retry_count: usize, transaction_selector: Option, + channel_hint: usize, } /// Errors that can occur when interacting with a [`ResultSet`]. @@ -92,6 +93,7 @@ impl ResultSet { client: DatabaseClient, session_name: String, operation: StreamOperation, + channel_hint: usize, ) -> Self { Self { stream, @@ -109,6 +111,7 @@ impl ResultSet { max_buffered_partial_result_sets: MAX_BUFFERED_PARTIAL_RESULT_SETS, retry_count: 0, transaction_selector, + channel_hint, } } @@ -253,7 +256,7 @@ impl ResultSet { self.transaction_selector .as_ref() .unwrap() - .begin_explicitly(&self.client, self.session_name.clone()) + .begin_explicitly(&self.client, self.session_name.clone(), self.channel_hint) .await?; self.partial_result_sets_buffer.clear(); @@ -383,7 +386,11 @@ impl ResultSet { let stream = self .client .spanner - .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .execute_streaming_sql( + req.clone(), + crate::RequestOptions::default(), + self.channel_hint, + ) .send() .await?; self.stream = stream; @@ -396,7 +403,11 @@ impl ResultSet { let stream = self .client .spanner - .streaming_read(req.clone(), crate::RequestOptions::default()) + .streaming_read( + req.clone(), + crate::RequestOptions::default(), + self.channel_hint, + ) .send() .await?; self.stream = stream; diff --git a/src/spanner/src/session_maintainer.rs b/src/spanner/src/session_maintainer.rs index ff32116662..d03ca001c2 100644 --- a/src/spanner/src/session_maintainer.rs +++ b/src/spanner/src/session_maintainer.rs @@ -134,7 +134,9 @@ impl ManagedSessionMaintainer { .set_creator_role(database_role), ); - spanner.create_session(request, options.clone()).await + spanner + .create_session(request, options.clone(), spanner.next_channel_hint()) + .await } async fn maintenance_loop( diff --git a/src/spanner/src/write_only_transaction.rs b/src/spanner/src/write_only_transaction.rs index 4352eada83..7910b99506 100644 --- a/src/spanner/src/write_only_transaction.rs +++ b/src/spanner/src/write_only_transaction.rs @@ -161,6 +161,7 @@ impl WriteOnlyTransaction { let client = self.client; let session_name = self.session_name.clone(); let previous_transaction_id = Arc::new(Mutex::new(Bytes::new())); + let channel_hint = client.spanner.next_channel_hint(); retry_aborted(&*self.retry_policy, || { let client = client.clone(); @@ -186,7 +187,7 @@ impl WriteOnlyTransaction { let tx = client .spanner - .begin_transaction(begin_req, crate::RequestOptions::default()) + .begin_transaction(begin_req, crate::RequestOptions::default(), channel_hint) .await?; *previous_transaction_id.lock().unwrap() = tx.id.clone(); @@ -199,7 +200,7 @@ impl WriteOnlyTransaction { let response = client .spanner - .commit(commit_req, crate::RequestOptions::default()) + .commit(commit_req, crate::RequestOptions::default(), channel_hint) .await?; // If a commit_response with a precommit_token is returned, then we need to @@ -212,7 +213,11 @@ impl WriteOnlyTransaction { .set_precommit_token(new_token); client .spanner - .commit(retry_commit_req, crate::RequestOptions::default()) + .commit( + retry_commit_req, + crate::RequestOptions::default(), + channel_hint, + ) .await } else { Ok(response) @@ -268,6 +273,7 @@ impl WriteOnlyTransaction { .set_single_use_transaction(Box::new(single_use)) .set_request_options(req_options); let client = self.client; + let channel_hint = client.spanner.next_channel_hint(); retry_aborted(&*self.retry_policy, || { let client = client.clone(); @@ -276,7 +282,7 @@ impl WriteOnlyTransaction { async move { client .spanner - .commit(request, crate::RequestOptions::default()) + .commit(request, crate::RequestOptions::default(), channel_hint) .await } })