diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index 788e4b22ef..3c500b174d 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -158,7 +158,11 @@ impl BatchReadOnlyTransaction { .context .client .spanner - .partition_query(request, crate::RequestOptions::default()) + .partition_query( + request, + crate::RequestOptions::default(), + self.inner.context.channel_hint, + ) .await?; Ok(response @@ -215,7 +219,11 @@ impl BatchReadOnlyTransaction { .context .client .spanner - .partition_read(request, crate::RequestOptions::default()) + .partition_read( + request, + crate::RequestOptions::default(), + self.inner.context.channel_hint, + ) .await?; Ok(response @@ -329,9 +337,10 @@ impl Partition { client: &DatabaseClient, req: &crate::model::ExecuteSqlRequest, ) -> crate::Result { + let channel_hint = client.spanner.next_channel_hint(); let stream = client .spanner - .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .execute_streaming_sql(req.clone(), crate::RequestOptions::default(), channel_hint) .send() .await?; @@ -345,6 +354,7 @@ impl Partition { client.clone(), req.session.clone(), StreamOperation::Query(req.clone()), + channel_hint, )) } @@ -352,9 +362,10 @@ impl Partition { client: &DatabaseClient, req: &crate::model::ReadRequest, ) -> crate::Result { + let channel_hint = client.spanner.next_channel_hint(); let stream = client .spanner - .streaming_read(req.clone(), crate::RequestOptions::default()) + .streaming_read(req.clone(), crate::RequestOptions::default(), channel_hint) .send() .await?; @@ -368,6 +379,7 @@ impl Partition { client.clone(), req.session.clone(), StreamOperation::Read(req.clone()), + channel_hint, )) } } diff --git a/src/spanner/src/client.rs b/src/spanner/src/client.rs index f4ff2d4f82..9dc620d738 100644 --- a/src/spanner/src/client.rs +++ b/src/spanner/src/client.rs @@ -15,6 +15,7 @@ use crate::generated::gapic_dataplane::client::Spanner as GapicSpanner; use crate::server_streaming::builder; use gaxi::options::{ClientConfig, Credentials}; +use std::sync::atomic::{AtomicUsize, Ordering}; pub use crate::database_client::DatabaseClient; pub use crate::error::SpannerInternalError; @@ -50,8 +51,8 @@ pub use wkt::{DurationError, TimestampError}; /// [Spanner]: https://docs.cloud.google.com/spanner/docs #[derive(Clone, Debug)] pub struct Spanner { - inner: GapicSpanner, - grpc_client: Option, + pub(crate) channels: Vec, + pub(crate) counter: std::sync::Arc, } pub struct Factory; @@ -61,20 +62,19 @@ impl google_cloud_gax::client_builder::internal::ClientFactory for Factory { type Credentials = Credentials; async fn build(self, config: ClientConfig) -> crate::ClientBuilderResult { - let transport = - crate::generated::gapic_dataplane::transport::Spanner::new(config.clone()).await?; - let grpc_client = transport.inner.clone(); + let num_channels = std::env::var("SPANNER_NUM_CHANNELS") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(4); + + let mut channels = Vec::with_capacity(num_channels); + for _ in 0..num_channels { + channels.push(Channel::create(&config).await?); + } - let inner = if gaxi::options::tracing_enabled(&config) { - GapicSpanner::from_stub(crate::generated::gapic_dataplane::tracing::Spanner::new( - transport, - )) - } else { - GapicSpanner::from_stub(transport) - }; Ok(Spanner { - inner, - grpc_client: Some(grpc_client), + channels, + counter: std::sync::Arc::new(AtomicUsize::new(0)), }) } } @@ -146,17 +146,31 @@ impl Spanner { // This method is primarily for testing and doesn't fully initialize grpc_client. // For production use, prefer `Spanner::builder().build()`. Self { - inner: GapicSpanner::from_stub(stub), - grpc_client: None, + channels: vec![Channel { + inner: GapicSpanner::from_stub(stub), + grpc_client: None, + }], + counter: std::sync::Arc::new(AtomicUsize::new(0)), } } + pub(crate) fn get_channel(&self, hint: usize) -> &Channel { + let idx = hint % self.channels.len(); + &self.channels[idx] + } + + pub(crate) fn next_channel_hint(&self) -> usize { + self.counter.fetch_add(1, Ordering::Relaxed) + } + pub(crate) async fn create_session( &self, request: crate::model::CreateSessionRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .create_session() .with_request(request) .with_options(options) @@ -168,8 +182,10 @@ impl Spanner { &self, request: crate::model::ExecuteSqlRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .execute_sql() .with_request(request) .with_options(options) @@ -181,8 +197,10 @@ impl Spanner { &self, request: crate::model::ExecuteBatchDmlRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .execute_batch_dml() .with_request(request) .with_options(options) @@ -194,8 +212,10 @@ impl Spanner { &self, request: crate::model::ReadRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .read() .with_request(request) .with_options(options) @@ -207,8 +227,10 @@ impl Spanner { &self, request: crate::model::BeginTransactionRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .begin_transaction() .with_request(request) .with_options(options) @@ -220,8 +242,10 @@ impl Spanner { &self, request: crate::model::CommitRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .commit() .with_request(request) .with_options(options) @@ -233,8 +257,10 @@ impl Spanner { &self, request: crate::model::RollbackRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result<()> { - self.inner + self.get_channel(channel_hint) + .inner .rollback() .with_request(request) .with_options(options) @@ -246,8 +272,10 @@ impl Spanner { &self, request: crate::model::PartitionQueryRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .partition_query() .with_request(request) .with_options(options) @@ -259,8 +287,10 @@ impl Spanner { &self, request: crate::model::PartitionReadRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result { - self.inner + self.get_channel(channel_hint) + .inner .partition_read() .with_request(request) .with_options(options) @@ -276,8 +306,10 @@ impl Spanner { &self, request: crate::model::ExecuteSqlRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> builder::ExecuteStreamingSql { - let grpc = self + let channel = self.get_channel(channel_hint); + let grpc = channel .grpc_client .as_ref() .expect("Streaming RPCs are not supported when using a stub client"); @@ -294,8 +326,10 @@ impl Spanner { &self, request: crate::model::ReadRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> builder::StreamingRead { - let grpc = self + let channel = self.get_channel(channel_hint); + let grpc = channel .grpc_client .as_ref() .expect("Streaming RPCs are not supported when using a stub client"); @@ -308,8 +342,10 @@ impl Spanner { &self, request: crate::model::BatchWriteRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> builder::BatchWrite { - let grpc = self + let channel = self.get_channel(channel_hint); + let grpc = channel .grpc_client .as_ref() .expect("Streaming RPCs are not supported when using a stub client"); @@ -319,6 +355,32 @@ impl Spanner { } } +#[derive(Clone, Debug)] +pub(crate) struct Channel { + pub(crate) inner: GapicSpanner, + pub(crate) grpc_client: Option, +} + +impl Channel { + pub(crate) async fn create(config: &ClientConfig) -> crate::ClientBuilderResult { + let transport = + crate::generated::gapic_dataplane::transport::Spanner::new(config.clone()).await?; + let grpc_client = transport.inner.clone(); + + let inner = if gaxi::options::tracing_enabled(config) { + GapicSpanner::from_stub(crate::generated::gapic_dataplane::tracing::Spanner::new( + transport, + )) + } else { + GapicSpanner::from_stub(transport) + }; + Ok(Self { + inner, + grpc_client: Some(grpc_client), + }) + } +} + #[cfg(test)] mod tests { use super::*; @@ -336,6 +398,50 @@ mod tests { assert_not_impl_any!(Spanner: std::panic::RefUnwindSafe, std::panic::UnwindSafe); } + #[tokio::test] + async fn channel_pool_default_size() { + let mock = MockSpanner::new(); + let (address, _server) = start("0.0.0.0:0", mock) + .await + .expect("Failed to start mock server"); + + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await + .expect("Failed to build client"); + + assert_eq!(client.channels.len(), 4); + } + + #[tokio::test] + async fn channel_selection() { + let mock = MockSpanner::new(); + let (address, _server) = start("0.0.0.0:0", mock) + .await + .expect("Failed to start mock server"); + + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await + .expect("Failed to build client"); + + let hint0 = client.next_channel_hint(); + let hint1 = client.next_channel_hint(); + let hint2 = client.next_channel_hint(); + let hint3 = client.next_channel_hint(); + let hint4 = client.next_channel_hint(); + + assert_eq!(hint0 % 4, 0); + assert_eq!(hint1 % 4, 1); + assert_eq!(hint2 % 4, 2); + assert_eq!(hint3 % 4, 3); + assert_eq!(hint4 % 4, 0); + } + #[tokio::test] async fn test_create_session() { // 1. Setup Mock Server @@ -368,7 +474,11 @@ mod tests { "projects/test-project/instances/test-instance/databases/test-db".to_string(); let session = client - .create_session(req, crate::RequestOptions::default()) + .create_session( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call create_session"); @@ -422,6 +532,7 @@ mod tests { "projects/test-project/instances/test-instance/databases/test-db".to_string(); let session = client + .get_channel(client.next_channel_hint()) .inner .create_session() .with_request(req) @@ -471,7 +582,11 @@ mod tests { req.sql = "SELECT 1".to_string(); let result_set = client - .execute_sql(req, crate::RequestOptions::default()) + .execute_sql( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call execute_sql"); assert!(result_set.metadata.is_some()); @@ -510,7 +625,11 @@ mod tests { req.session = "test_session".to_string(); let response = client - .execute_batch_dml(req, crate::RequestOptions::default()) + .execute_batch_dml( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call execute_batch_dml"); assert!(response.status.is_some()); @@ -545,7 +664,11 @@ mod tests { req.table = "test_table".to_string(); let result_set = client - .read(req, crate::RequestOptions::default()) + .read( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call read"); assert!(result_set.metadata.is_none()); @@ -579,7 +702,11 @@ mod tests { req.session = "test_session".to_string(); let tx = client - .begin_transaction(req, crate::RequestOptions::default()) + .begin_transaction( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call begin_transaction"); assert_eq!(tx.id, vec![1, 2, 3]); @@ -617,7 +744,11 @@ mod tests { req.session = "test_session".to_string(); let response = client - .commit(req, crate::RequestOptions::default()) + .commit( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call commit"); assert!(response.commit_timestamp.is_some()); @@ -646,7 +777,11 @@ mod tests { req.session = "test_session".to_string(); client - .rollback(req, crate::RequestOptions::default()) + .rollback( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call rollback"); } @@ -688,7 +823,11 @@ mod tests { req.sql = "SELECT 1".to_string(); let mut stream = client - .execute_streaming_sql(req, crate::RequestOptions::default()) + .execute_streaming_sql( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .send() .await .expect("Failed to call execute_streaming_sql"); @@ -736,7 +875,11 @@ mod tests { req.columns = vec!["col1".to_string()]; let mut stream = client - .streaming_read(req, crate::RequestOptions::default()) + .streaming_read( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .send() .await .expect("Failed to call streaming_read"); @@ -774,7 +917,11 @@ mod tests { req.session = "test_session".to_string(); let mut stream = client - .batch_write(req, crate::RequestOptions::default()) + .batch_write( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .send() .await .expect("Failed to call batch_write"); @@ -810,7 +957,11 @@ mod tests { req.sql = "SELECT 1".to_string(); let mut stream = client - .execute_streaming_sql(req, crate::RequestOptions::default()) + .execute_streaming_sql( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .send() .await .expect("Failed to call execute_streaming_sql"); diff --git a/src/spanner/src/partitioned_dml_transaction.rs b/src/spanner/src/partitioned_dml_transaction.rs index c28a271b17..42b7d083d7 100644 --- a/src/spanner/src/partitioned_dml_transaction.rs +++ b/src/spanner/src/partitioned_dml_transaction.rs @@ -173,13 +173,18 @@ impl PartitionedDmlTransaction { ..Default::default() }; let base_request = statement.into_request(); + let channel_hint = self.client.spanner.next_channel_hint(); // Execute the statement and retry if the transaction is aborted by Spanner. retry_aborted(&*self.retry_policy, || async { let transaction = self .client .spanner - .begin_transaction(begin_request.clone(), crate::RequestOptions::default()) + .begin_transaction( + begin_request.clone(), + crate::RequestOptions::default(), + channel_hint, + ) .await?; let execute_request = base_request @@ -190,10 +195,11 @@ impl PartitionedDmlTransaction { ..Default::default() }); - let stream_builder = self - .client - .spanner - .execute_streaming_sql(execute_request.clone(), crate::RequestOptions::default()); + let stream_builder = self.client.spanner.execute_streaming_sql( + execute_request.clone(), + crate::RequestOptions::default(), + channel_hint, + ); let stream = stream_builder.send().await?; extract_lower_bound_update_count_from_stream(stream).await diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index b4e0ce8c58..5c004124ac 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -90,6 +90,7 @@ impl SingleUseReadOnlyTransactionBuilder { .set_single_use(TransactionOptions::default().set_read_only(read_only)); let session_name = self.client.session_name(); + let channel_hint = self.client.spanner.next_channel_hint(); SingleUseReadOnlyTransaction { context: ReadContext { session_name, @@ -100,6 +101,7 @@ impl SingleUseReadOnlyTransactionBuilder { ), precommit_token_tracker: PrecommitTokenTracker::new_noop(), transaction_tag: None, + channel_hint, }, } } @@ -277,8 +279,10 @@ impl MultiUseReadOnlyTransactionBuilder { &self, session_name: String, options: TransactionOptions, + channel_hint: usize, ) -> crate::Result { - let response = execute_begin_transaction(&self.client, session_name, options).await?; + let response = + execute_begin_transaction(&self.client, session_name, options, channel_hint).await?; let transaction_selector = crate::model::TransactionSelector::default().set_id(response.id); @@ -309,8 +313,10 @@ impl MultiUseReadOnlyTransactionBuilder { let options = TransactionOptions::default().set_read_only(read_only); let session_name = self.client.session_name(); + let channel_hint = self.client.spanner.next_channel_hint(); let selector = if self.explicit_begin { - self.begin(session_name.clone(), options).await? + self.begin(session_name.clone(), options, channel_hint) + .await? } else { ReadContextTransactionSelector::Lazy(Arc::new(Mutex::new( TransactionState::NotStarted(options), @@ -324,6 +330,7 @@ impl MultiUseReadOnlyTransactionBuilder { transaction_selector: selector, precommit_token_tracker: PrecommitTokenTracker::new_noop(), transaction_tag: None, + channel_hint, }, }) } @@ -430,6 +437,7 @@ async fn execute_begin_transaction( client: &crate::database_client::DatabaseClient, session_name: String, options: crate::model::TransactionOptions, + channel_hint: usize, ) -> crate::Result { let request = crate::model::BeginTransactionRequest::default() .set_session(session_name) @@ -438,7 +446,7 @@ async fn execute_begin_transaction( // TODO(#4972): make request options configurable client .spanner - .begin_transaction(request, crate::RequestOptions::default()) + .begin_transaction(request, crate::RequestOptions::default(), channel_hint) .await } @@ -484,6 +492,7 @@ impl ReadContextTransactionSelector { &self, client: &crate::database_client::DatabaseClient, session_name: String, + channel_hint: usize, ) -> crate::Result<()> { let Self::Lazy(lazy) = self else { return Ok(()); @@ -497,7 +506,8 @@ impl ReadContextTransactionSelector { options.clone() }; - let response = execute_begin_transaction(client, session_name, options).await?; + let response = + execute_begin_transaction(client, session_name, options, channel_hint).await?; self.update(response.id, response.read_timestamp); Ok(()) @@ -537,6 +547,7 @@ pub(crate) struct ReadContext { pub(crate) transaction_selector: ReadContextTransactionSelector, pub(crate) precommit_token_tracker: PrecommitTokenTracker, pub(crate) transaction_tag: Option, + pub(crate) channel_hint: usize, } impl ReadContext { @@ -571,7 +582,7 @@ impl ReadContext { } self.transaction_selector - .begin_explicitly(&self.client, self.session_name.clone()) + .begin_explicitly(&self.client, self.session_name.clone(), self.channel_hint) .await?; Ok(true) } @@ -584,7 +595,11 @@ macro_rules! execute_stream_with_retry { .client .spanner // TODO(#4972): make request options configurable - .$rpc_method($request.clone(), crate::RequestOptions::default()) + .$rpc_method( + $request.clone(), + crate::RequestOptions::default(), + $self.channel_hint, + ) .send() .await { @@ -596,7 +611,11 @@ macro_rules! execute_stream_with_retry { .client .spanner // TODO(#4972): make request options configurable - .$rpc_method($request.clone(), crate::RequestOptions::default()) + .$rpc_method( + $request.clone(), + crate::RequestOptions::default(), + $self.channel_hint, + ) .send() .await? } else { @@ -612,6 +631,7 @@ macro_rules! execute_stream_with_retry { $self.client.clone(), $self.session_name.clone(), $operation_variant($request), + $self.channel_hint, )) }}; } diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index fd3c7b28cc..b3e9e8e310 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -111,10 +111,11 @@ impl ReadWriteTransactionBuilder { } // TODO(#4972): make request options configurable + let channel_hint = self.client.spanner.next_channel_hint(); let response = self .client .spanner - .begin_transaction(request, RequestOptions::default()) + .begin_transaction(request, RequestOptions::default(), channel_hint) .await?; let transaction_selector = @@ -129,6 +130,7 @@ impl ReadWriteTransactionBuilder { transaction_selector, precommit_token_tracker: PrecommitTokenTracker::new(), transaction_tag: self.transaction_tag.clone(), + channel_hint, }, seqno: Arc::new(AtomicI64::new(1)), max_commit_delay: self.max_commit_delay, @@ -176,7 +178,11 @@ impl ReadWriteTransaction { .context .client .spanner - .execute_sql(request, RequestOptions::default()) + .execute_sql( + request, + RequestOptions::default(), + self.context.channel_hint, + ) .await?; self.context .precommit_token_tracker @@ -280,7 +286,11 @@ impl ReadWriteTransaction { .context .client .spanner - .execute_batch_dml(request, RequestOptions::default()) + .execute_batch_dml( + request, + RequestOptions::default(), + self.context.channel_hint, + ) .await; match response_result { @@ -316,7 +326,11 @@ impl ReadWriteTransaction { .context .client .spanner - .commit(request, RequestOptions::default()) + .commit( + request, + RequestOptions::default(), + self.context.channel_hint, + ) .await?; let response = @@ -330,7 +344,11 @@ impl ReadWriteTransaction { self.context .client .spanner - .commit(retry_commit_req, RequestOptions::default()) + .commit( + retry_commit_req, + RequestOptions::default(), + self.context.channel_hint, + ) .await? } else { response @@ -353,7 +371,11 @@ impl ReadWriteTransaction { self.context .client .spanner - .rollback(request, RequestOptions::default()) + .rollback( + request, + RequestOptions::default(), + self.context.channel_hint, + ) .await?; Ok(()) @@ -368,6 +390,7 @@ mod tests { use gaxi::grpc::tonic; use spanner_grpc_mock::google::spanner::v1; use std::fmt::Debug; + use std::sync::Mutex; #[test] fn auto_traits() { @@ -376,23 +399,36 @@ mod tests { } #[tokio::test] - async fn read_write_transaction_commit_retry() { + async fn read_write_transaction_commit_retry() -> anyhow::Result<()> { let mut mock = create_session_mock(); - - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![0, 0, 7], - ..Default::default() - })) - }); + let remotes = Arc::new(Mutex::new(Vec::new())); + + let remotes_clone = remotes.clone(); + mock.expect_begin_transaction() + .once() + .returning(move |req| { + remotes_clone + .lock() + .unwrap() + .push(req.remote_addr().expect("remote_addr should be available")); + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + Ok(tonic::Response::new(v1::Transaction { + id: vec![0, 0, 7], + ..Default::default() + })) + }); // execute_update returns a precommit token. - mock.expect_execute_sql().once().returning(|req| { + let remotes_clone = remotes.clone(); + mock.expect_execute_sql().once().returning(move |req| { + remotes_clone + .lock() + .unwrap() + .push(req.remote_addr().expect("remote_addr should be available")); let req = req.into_inner(); assert_eq!(req.sql, "UPDATE Users SET Name = 'Bob' WHERE Id = 1"); Ok(tonic::Response::new(v1::ResultSet { @@ -411,7 +447,12 @@ mod tests { // Simulate that commit returns a precommit token in the response. // This would normally not happen, but we test it here to verify // that the commit is retried. - mock.expect_commit().once().returning(|req| { + let remotes_clone = remotes.clone(); + mock.expect_commit().once().returning(move |req| { + remotes_clone + .lock() + .unwrap() + .push(req.remote_addr().expect("remote_addr should be available")); let req = req.into_inner(); assert_eq!( req.precommit_token, @@ -438,7 +479,12 @@ mod tests { }); // Second commit retry is automatically issued with the new token - mock.expect_commit().once().returning(|req| { + let remotes_clone = remotes.clone(); + mock.expect_commit().once().returning(move |req| { + remotes_clone + .lock() + .unwrap() + .push(req.remote_addr().expect("remote_addr should be available")); let req = req.into_inner(); assert_eq!( req.precommit_token, @@ -460,17 +506,25 @@ mod tests { let tx = ReadWriteTransactionBuilder::new(db_client.clone()) .begin_transaction() - .await - .expect("Failed to build transaction"); + .await?; let count = tx .execute_update("UPDATE Users SET Name = 'Bob' WHERE Id = 1") - .await - .unwrap(); + .await?; assert_eq!(count, 1); - let timestamp = tx.commit().await.unwrap(); + let timestamp = tx.commit().await?; assert_eq!(timestamp.seconds(), 1001); + + // Verify that all RPCs used the same channel (same remote address) + let remotes = remotes.lock().unwrap(); + assert_eq!(remotes.len(), 4, "Expected exactly 4 RPCs"); + let first = remotes[0]; + for addr in remotes.iter() { + assert_eq!(*addr, first, "All RPCs should use the same gRPC channel"); + } + + Ok(()) } #[tokio::test] diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index 8832fb6209..6e88dc9746 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -64,6 +64,7 @@ pub struct ResultSet { max_buffered_partial_result_sets: usize, retry_count: usize, transaction_selector: Option, + channel_hint: usize, } /// Errors that can occur when interacting with a [`ResultSet`]. @@ -95,6 +96,7 @@ impl ResultSet { client: DatabaseClient, session_name: String, operation: StreamOperation, + channel_hint: usize, ) -> Self { Self { stream: Some(stream), @@ -113,6 +115,7 @@ impl ResultSet { max_buffered_partial_result_sets: MAX_BUFFERED_PARTIAL_RESULT_SETS, retry_count: 0, transaction_selector, + channel_hint, stats: None, } } @@ -311,7 +314,7 @@ impl ResultSet { self.transaction_selector .as_ref() .unwrap() - .begin_explicitly(&self.client, self.session_name.clone()) + .begin_explicitly(&self.client, self.session_name.clone(), self.channel_hint) .await?; self.partial_result_sets_buffer.clear(); @@ -461,7 +464,11 @@ impl ResultSet { let stream = self .client .spanner - .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .execute_streaming_sql( + req.clone(), + crate::RequestOptions::default(), + self.channel_hint, + ) .send() .await?; self.stream = Some(stream); @@ -474,7 +481,11 @@ impl ResultSet { let stream = self .client .spanner - .streaming_read(req.clone(), crate::RequestOptions::default()) + .streaming_read( + req.clone(), + crate::RequestOptions::default(), + self.channel_hint, + ) .send() .await?; self.stream = Some(stream); diff --git a/src/spanner/src/session_maintainer.rs b/src/spanner/src/session_maintainer.rs index ff32116662..d03ca001c2 100644 --- a/src/spanner/src/session_maintainer.rs +++ b/src/spanner/src/session_maintainer.rs @@ -134,7 +134,9 @@ impl ManagedSessionMaintainer { .set_creator_role(database_role), ); - spanner.create_session(request, options.clone()).await + spanner + .create_session(request, options.clone(), spanner.next_channel_hint()) + .await } async fn maintenance_loop( diff --git a/src/spanner/src/write_only_transaction.rs b/src/spanner/src/write_only_transaction.rs index 5615fec052..8a42ad78fe 100644 --- a/src/spanner/src/write_only_transaction.rs +++ b/src/spanner/src/write_only_transaction.rs @@ -221,6 +221,7 @@ impl WriteOnlyTransaction { let client = self.client; let session_name = self.session_name.clone(); let previous_transaction_id = Arc::new(Mutex::new(Bytes::new())); + let channel_hint = client.spanner.next_channel_hint(); retry_aborted(&*self.retry_policy, || { let client = client.clone(); @@ -250,7 +251,7 @@ impl WriteOnlyTransaction { let tx = client .spanner - .begin_transaction(begin_req, crate::RequestOptions::default()) + .begin_transaction(begin_req, crate::RequestOptions::default(), channel_hint) .await?; *previous_transaction_id.lock().unwrap() = tx.id.clone(); @@ -264,7 +265,7 @@ impl WriteOnlyTransaction { let response = client .spanner - .commit(commit_req, crate::RequestOptions::default()) + .commit(commit_req, crate::RequestOptions::default(), channel_hint) .await?; // If a commit_response with a precommit_token is returned, then we need to @@ -277,7 +278,11 @@ impl WriteOnlyTransaction { .set_precommit_token(new_token); client .spanner - .commit(retry_commit_req, crate::RequestOptions::default()) + .commit( + retry_commit_req, + crate::RequestOptions::default(), + channel_hint, + ) .await } else { Ok(response) @@ -336,6 +341,7 @@ impl WriteOnlyTransaction { .set_request_options(req_options) .set_or_clear_max_commit_delay(self.max_commit_delay); let client = self.client; + let channel_hint = client.spanner.next_channel_hint(); retry_aborted(&*self.retry_policy, || { let client = client.clone(); @@ -344,7 +350,7 @@ impl WriteOnlyTransaction { async move { client .spanner - .commit(request, crate::RequestOptions::default()) + .commit(request, crate::RequestOptions::default(), channel_hint) .await } })