From 726eb82e7874dde6d0b851034727d84fffe18394 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Wed, 22 Apr 2026 16:28:26 +0200 Subject: [PATCH 1/2] chore(spanner): foundation for batch write transactions Adds the foundations for batch write transactions. The public API will be added in a follow-up pull request. --- src/spanner/src/batch_write_transaction.rs | 146 +++++++++++++++++++++ src/spanner/src/database_client.rs | 8 ++ src/spanner/src/lib.rs | 1 + src/spanner/src/mutation.rs | 88 +++++++++++++ 4 files changed, 243 insertions(+) create mode 100644 src/spanner/src/batch_write_transaction.rs diff --git a/src/spanner/src/batch_write_transaction.rs b/src/spanner/src/batch_write_transaction.rs new file mode 100644 index 0000000000..ed6973b943 --- /dev/null +++ b/src/spanner/src/batch_write_transaction.rs @@ -0,0 +1,146 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::client::DatabaseClient; +use crate::model::BatchWriteRequest; +use crate::mutation::MutationGroup; +use crate::server_streaming::stream::BatchWriteStream; + +/// A builder for [BatchWriteTransaction]. +#[allow(dead_code)] +pub struct BatchWriteTransactionBuilder { + client: DatabaseClient, +} + +impl BatchWriteTransactionBuilder { + pub(crate) fn new(client: DatabaseClient) -> Self { + Self { client } + } + + /// Builds the [BatchWriteTransaction]. + #[allow(dead_code)] + pub fn build(self) -> BatchWriteTransaction { + let session_name = self.client.session_name(); + BatchWriteTransaction { + session_name, + client: self.client, + } + } +} + +/// A transaction for executing batch writes. +/// +/// Batch writes are not guaranteed to be atomic across mutation groups. +/// All mutations within a group are applied atomically. +#[allow(dead_code)] +pub struct BatchWriteTransaction { + session_name: String, + client: DatabaseClient, +} + +impl BatchWriteTransaction { + /// Executes the batch write and returns a stream of responses. + #[allow(dead_code)] + pub(crate) async fn execute_streaming(self, groups: I) -> crate::Result + where + I: IntoIterator, + { + let req = BatchWriteRequest::new() + .set_session(self.session_name.clone()) + .set_mutation_groups(groups.into_iter().map(|g| g.build_proto())); + + self.client + .spanner + .batch_write(req, crate::RequestOptions::default()) + .send() + .await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::{Mutation, Spanner}; + use crate::result_set::tests::adapt; + use gaxi::grpc::tonic::Response; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + + pub(crate) async fn setup_db_client( + mock: MockSpanner, + ) -> (DatabaseClient, tokio::task::JoinHandle<()>) { + use google_cloud_auth::credentials::anonymous::Builder as Anonymous; + let (address, server) = spanner_grpc_mock::start("0.0.0.0:0", mock) + .await + .expect("Failed to start mock server"); + let spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await + .expect("Failed to build client"); + + let db_client = spanner + .database_client("projects/p/instances/i/databases/d") + .build() + .await + .expect("Failed to create DatabaseClient"); + + (db_client, server) + } + + #[tokio::test] + async fn test_execute_streaming() { + let mut mock = MockSpanner::new(); + mock.expect_create_session().returning(|_| { + Ok(Response::new(mock_v1::Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + mock.expect_batch_write().once().returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + assert_eq!(req.mutation_groups.len(), 1); + + let response = mock_v1::BatchWriteResponse { + indexes: vec![0], + status: None, + commit_timestamp: None, + }; + + Ok(Response::from(adapt([Ok(response)]))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let mutation = Mutation::new_insert_builder("Users") + .set("UserId") + .to(&1) + .build(); + let group = MutationGroup::new(vec![mutation]); + + let tx = db_client.batch_write_transaction().build(); + let mut stream = tx.execute_streaming(vec![group]).await.unwrap(); + + let result = stream.next_message().await; + assert!(result.is_some()); + let result = result.unwrap().unwrap(); + assert_eq!(result.indexes, vec![0]); + } +} diff --git a/src/spanner/src/database_client.rs b/src/spanner/src/database_client.rs index 92a6b50de8..c3467c3213 100644 --- a/src/spanner/src/database_client.rs +++ b/src/spanner/src/database_client.rs @@ -193,6 +193,14 @@ impl DatabaseClient { crate::write_only_transaction::WriteOnlyTransactionBuilder::new(self.clone()) } + /// Returns a builder for a batch write transaction. + #[allow(dead_code)] + pub(crate) fn batch_write_transaction( + &self, + ) -> crate::batch_write_transaction::BatchWriteTransactionBuilder { + crate::batch_write_transaction::BatchWriteTransactionBuilder::new(self.clone()) + } + pub(crate) fn session_name(&self) -> String { self.session_maintainer.session_name() } diff --git a/src/spanner/src/lib.rs b/src/spanner/src/lib.rs index e9ff6f4937..3541a0651e 100644 --- a/src/spanner/src/lib.rs +++ b/src/spanner/src/lib.rs @@ -44,6 +44,7 @@ pub mod batch_read_only_transaction; pub mod model { pub use crate::generated::gapic_dataplane::model::*; } +pub(crate) mod batch_write_transaction; pub(crate) mod database_client; pub(crate) mod from_value; pub(crate) mod key; diff --git a/src/spanner/src/mutation.rs b/src/spanner/src/mutation.rs index 9267772d5c..8df1bc093e 100644 --- a/src/spanner/src/mutation.rs +++ b/src/spanner/src/mutation.rs @@ -13,9 +13,12 @@ // limitations under the License. use crate::key::KeySet; +use crate::model::batch_write_request::MutationGroup as ProtoMutationGroup; use crate::model::mutation::Operation; use crate::to_value::ToValue; use crate::value::Value; +use std::slice::Iter; +use std::vec::IntoIter; /// Represents an individual table modification to be applied to Cloud Spanner. /// @@ -266,6 +269,42 @@ impl ValueBinder { } } +/// A group of mutations that are applied atomically in a [BatchWriteTransaction]. +#[derive(Clone, Debug, PartialEq)] +pub struct MutationGroup { + pub mutations: Vec, +} + +impl MutationGroup { + /// Creates a new mutation group from a list of mutations. + pub fn new(mutations: Vec) -> Self { + Self { mutations } + } + + #[allow(dead_code)] + pub(crate) fn build_proto(self) -> ProtoMutationGroup { + ProtoMutationGroup::new().set_mutations(self.mutations.into_iter().map(|m| m.build_proto())) + } +} + +impl IntoIterator for MutationGroup { + type Item = Mutation; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.mutations.into_iter() + } +} + +impl<'a> IntoIterator for &'a MutationGroup { + type Item = &'a Mutation; + type IntoIter = Iter<'a, Mutation>; + + fn into_iter(self) -> Self::IntoIter { + self.mutations.iter() + } +} + #[cfg(test)] mod tests { use super::*; @@ -277,6 +316,55 @@ mod tests { static_assertions::assert_impl_all!(Delete: Send, Sync, Clone, std::fmt::Debug); static_assertions::assert_impl_all!(WriteBuilder: Send, Sync); static_assertions::assert_impl_all!(ValueBinder: Send, Sync); + static_assertions::assert_impl_all!(MutationGroup: Send, Sync, Clone, std::fmt::Debug); + } + + #[test] + fn mutation_group() { + let mutation1 = Mutation::new_insert_builder("Users") + .set("UserId") + .to(&1) + .build(); + let mutation2 = Mutation::new_insert_builder("Users") + .set("UserId") + .to(&2) + .build(); + let group = MutationGroup::new(vec![mutation1.clone(), mutation2.clone()]); + assert_eq!(group.mutations.len(), 2); + assert_eq!(group.mutations[0], mutation1); + assert_eq!(group.mutations[1], mutation2); + } + + #[test] + fn mutation_group_into_iter() { + let mutation1 = Mutation::new_insert_builder("Users") + .set("UserId") + .to(&1) + .build(); + let mutation2 = Mutation::new_insert_builder("Users") + .set("UserId") + .to(&2) + .build(); + let group = MutationGroup::new(vec![mutation1.clone(), mutation2.clone()]); + + let mutations: Vec<_> = group.into_iter().collect(); + assert_eq!(mutations, vec![mutation1, mutation2]); + } + + #[test] + fn mutation_group_iter_ref() { + let mutation1 = Mutation::new_insert_builder("Users") + .set("UserId") + .to(&1) + .build(); + let mutation2 = Mutation::new_insert_builder("Users") + .set("UserId") + .to(&2) + .build(); + let group = MutationGroup::new(vec![mutation1.clone(), mutation2.clone()]); + + let mutations: Vec<_> = (&group).into_iter().collect(); + assert_eq!(mutations, vec![&mutation1, &mutation2]); } #[test] From e79d1e7414df900eb781cd1bd8a431097754c763 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Wed, 22 Apr 2026 18:00:09 +0200 Subject: [PATCH 2/2] feat(spanner): add BatchWrite transaction Adds support for BatchWrite transactions. BatchWrite can be used to write large numbers of mutations to Spanner. A BatchWrite transaction is not guaranteed to be atomic. However, each MutationGroup in a BatchWrite transaction is guaranteed to be atomic. This change adds a method that returns a stream of responses and delegates all error handling to the application. A follow-up pull request will add a managed execution method that handles errors automatically. --- Cargo.lock | 1 + src/spanner/src/batch_write_transaction.rs | 71 +++++++++-- src/spanner/src/client.rs | 2 +- src/spanner/src/database_client.rs | 51 +++++++- src/spanner/src/lib.rs | 3 + src/spanner/src/mutation.rs | 2 +- tests/spanner/Cargo.toml | 1 + tests/spanner/src/batch_write.rs | 132 +++++++++++++++++++++ tests/spanner/src/lib.rs | 1 + tests/spanner/tests/driver.rs | 16 ++- 10 files changed, 265 insertions(+), 15 deletions(-) create mode 100644 tests/spanner/src/batch_write.rs diff --git a/Cargo.lock b/Cargo.lock index d154a0f9c5..f8231028c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5965,6 +5965,7 @@ dependencies = [ "google-cloud-spanner", "google-cloud-test-utils", "prost-types", + "rand 0.10.1", "reqwest 0.13.2", "serde_json", "spanner-grpc-mock", diff --git a/src/spanner/src/batch_write_transaction.rs b/src/spanner/src/batch_write_transaction.rs index fd25bb6a59..6ac314d64e 100644 --- a/src/spanner/src/batch_write_transaction.rs +++ b/src/spanner/src/batch_write_transaction.rs @@ -14,11 +14,12 @@ use crate::client::DatabaseClient; use crate::model::BatchWriteRequest; +use crate::model::BatchWriteResponse; use crate::mutation::MutationGroup; use crate::server_streaming::stream::BatchWriteStream; +use gaxi::prost::FromProto; /// A builder for [BatchWriteTransaction]. -#[allow(dead_code)] pub struct BatchWriteTransactionBuilder { client: DatabaseClient, } @@ -29,7 +30,6 @@ impl BatchWriteTransactionBuilder { } /// Builds the [BatchWriteTransaction]. - #[allow(dead_code)] pub fn build(self) -> BatchWriteTransaction { let session_name = self.client.session_name(); BatchWriteTransaction { @@ -43,7 +43,6 @@ impl BatchWriteTransactionBuilder { /// /// Batch writes are not guaranteed to be atomic across mutation groups. /// All mutations within a group are applied atomically. -#[allow(dead_code)] pub struct BatchWriteTransaction { session_name: String, client: DatabaseClient, @@ -51,8 +50,42 @@ pub struct BatchWriteTransaction { impl BatchWriteTransaction { /// Executes the batch write and returns a stream of responses. - #[allow(dead_code)] - pub(crate) async fn execute_streaming(self, groups: I) -> crate::Result + /// + /// # Example + /// ``` + /// # use google_cloud_spanner::client::{Mutation, Spanner, MutationGroup}; + /// # use google_cloud_gax::error::rpc::Code; + /// # async fn sample() -> Result<(), Box> { + /// let client = Spanner::builder().build().await?; + /// let db = client.database_client("projects/p/instances/i/databases/d").build().await?; + /// + /// let mutation = Mutation::new_insert_builder("Users") + /// .set("UserId").to(&1) + /// .build(); + /// let group = MutationGroup::new(vec![mutation]); + /// + /// let tx = db.batch_write_transaction().build(); + /// let mut stream = tx.execute_streaming(vec![group]).await?; + /// + /// while let Some(response) = stream.next_message().await { + /// let response = response?; + /// if let Some(status) = response.status.as_ref().filter(|s| s.code != Code::Ok as i32) { + /// eprintln!("Error applying groups {:?}: {}", response.indexes, status.message); + /// } else { + /// println!("Applied groups: {:?}", response.indexes); + /// } + /// } + /// # Ok(()) + /// # } + /// ``` + /// + /// This method sends the mutation groups to Spanner and returns the responses as a stream. + /// Each response includes a status code that indicates whether the mutation groups that + /// it references were applied successfully. + /// The method does not handle any errors, including retryable errors like Aborted. + /// The caller is responsible for handling any errors and for retrying the transaction in + /// case it is aborted by Spanner. + pub async fn execute_streaming(self, groups: I) -> crate::Result where I: IntoIterator, { @@ -60,11 +93,35 @@ impl BatchWriteTransaction { .set_session(self.session_name.clone()) .set_mutation_groups(groups.into_iter().map(|g| g.build_proto())); - self.client + let stream = self + .client .spanner .batch_write(req, crate::RequestOptions::default()) .send() - .await + .await?; + Ok(BatchWriteResponseStream { inner: stream }) + } +} + +/// A stream of [BatchWriteResponse] messages. +pub struct BatchWriteResponseStream { + pub(crate) inner: BatchWriteStream, +} + +impl BatchWriteResponseStream { + /// Fetches the next [BatchWriteResponse] from the stream. + /// + /// Returns `Some(Ok(BatchWriteResponse))` when a message is successfully received, + /// `None` when the stream concludes naturally, or `Some(Err(_))` on RPC errors. + pub async fn next_message(&mut self) -> Option> { + let proto_opt = self.inner.next_message().await?; + match proto_opt { + Ok(proto) => match proto.cnv() { + Ok(model) => Some(Ok(model)), + Err(e) => Some(Err(crate::Error::deser(e))), + }, + Err(e) => Some(Err(e)), + } } } diff --git a/src/spanner/src/client.rs b/src/spanner/src/client.rs index f4ff2d4f82..e267af399a 100644 --- a/src/spanner/src/client.rs +++ b/src/spanner/src/client.rs @@ -20,7 +20,7 @@ pub use crate::database_client::DatabaseClient; pub use crate::error::SpannerInternalError; pub use crate::from_value::{ConvertError, FromValue}; pub use crate::key::{Key, KeyRange, KeySet, KeySetBuilder}; -pub use crate::mutation::{Mutation, ValueBinder, WriteBuilder}; +pub use crate::mutation::{Mutation, MutationGroup, ValueBinder, WriteBuilder}; pub use crate::read::ConfiguredReadRequestBuilder; pub use crate::read::ReadRequest; pub use crate::read::ReadRequestBuilder; diff --git a/src/spanner/src/database_client.rs b/src/spanner/src/database_client.rs index c3467c3213..02f7e104cd 100644 --- a/src/spanner/src/database_client.rs +++ b/src/spanner/src/database_client.rs @@ -13,6 +13,7 @@ // limitations under the License. use crate::batch_read_only_transaction::BatchReadOnlyTransactionBuilder; +use crate::batch_write_transaction::BatchWriteTransactionBuilder; use crate::client::Spanner; use crate::partitioned_dml_transaction::PartitionedDmlTransactionBuilder; use crate::read_only_transaction::{ @@ -194,11 +195,51 @@ impl DatabaseClient { } /// Returns a builder for a batch write transaction. - #[allow(dead_code)] - pub(crate) fn batch_write_transaction( - &self, - ) -> crate::batch_write_transaction::BatchWriteTransactionBuilder { - crate::batch_write_transaction::BatchWriteTransactionBuilder::new(self.clone()) + /// + /// # Example + /// ``` + /// # use google_cloud_spanner::client::{Spanner, Mutation, MutationGroup}; + /// # use google_cloud_gax::error::rpc::Code; + /// # async fn sample() -> Result<(), Box> { + /// let client = Spanner::builder().build().await?; + /// let db = client.database_client("projects/p/instances/i/databases/d").build().await?; + /// + /// let mutation1a = Mutation::new_insert_builder("Users") + /// .set("UserId").to(&1) + /// .build(); + /// let mutation1b = Mutation::new_insert_builder("UserRoles") + /// .set("UserId").to(&1) + /// .set("Role").to(&"Admin") + /// .build(); + /// let group1 = MutationGroup::new(vec![mutation1a, mutation1b]); + /// + /// let mutation2 = Mutation::new_insert_builder("Users") + /// .set("UserId").to(&2) + /// .build(); + /// let group2 = MutationGroup::new(vec![mutation2]); + /// + /// let transaction = db.batch_write_transaction().build(); + /// let mut stream = transaction.execute_streaming(vec![group1, group2]).await?; + /// + /// while let Some(response) = stream.next_message().await { + /// let response = response?; + /// if let Some(status) = response.status.as_ref().filter(|s| s.code != Code::Ok as i32) { + /// eprintln!("Error applying groups {:?}: {}", response.indexes, status.message); + /// } else { + /// println!("Applied groups: {:?}", response.indexes); + /// } + /// } + /// # Ok(()) + /// # } + /// ``` + /// + /// A batch write transaction is used to execute non-atomic writes using mutations. + /// Related mutations should be placed in a group. For example, two mutations inserting + /// rows with the same primary key prefix in both parent and child tables are related. + /// All mutations within a group are applied atomically, but the entire batch is not + /// guaranteed to be atomic. + pub fn batch_write_transaction(&self) -> BatchWriteTransactionBuilder { + BatchWriteTransactionBuilder::new(self.clone()) } pub(crate) fn session_name(&self) -> String { diff --git a/src/spanner/src/lib.rs b/src/spanner/src/lib.rs index 3541a0651e..329f7f9df1 100644 --- a/src/spanner/src/lib.rs +++ b/src/spanner/src/lib.rs @@ -25,6 +25,9 @@ pub use batch_dml::BatchDmlBuilder; pub use batch_read_only_transaction::{ BatchReadOnlyTransaction, BatchReadOnlyTransactionBuilder, Partition, }; +pub use batch_write_transaction::{ + BatchWriteResponseStream, BatchWriteTransaction, BatchWriteTransactionBuilder, +}; pub use error::BatchUpdateError; pub use google_cloud_gax::Result; pub use google_cloud_gax::error::Error; diff --git a/src/spanner/src/mutation.rs b/src/spanner/src/mutation.rs index 8df1bc093e..940512652b 100644 --- a/src/spanner/src/mutation.rs +++ b/src/spanner/src/mutation.rs @@ -269,7 +269,7 @@ impl ValueBinder { } } -/// A group of mutations that are applied atomically in a [BatchWriteTransaction]. +/// A group of mutations that are applied atomically in a [crate::BatchWriteTransaction]. #[derive(Clone, Debug, PartialEq)] pub struct MutationGroup { pub mutations: Vec, diff --git a/tests/spanner/Cargo.toml b/tests/spanner/Cargo.toml index 30b6625f58..75cdbe4f22 100644 --- a/tests/spanner/Cargo.toml +++ b/tests/spanner/Cargo.toml @@ -34,6 +34,7 @@ google-cloud-lro = { workspace = true } google-cloud-spanner = { workspace = true, features = ["unstable-stream"] } google-cloud-test-utils = { workspace = true } prost-types.workspace = true +rand = { workspace = true } reqwest = { workspace = true, features = ["json"] } serde_json = { workspace = true } spanner-grpc-mock = { path = "../../src/spanner/grpc-mock" } diff --git a/tests/spanner/src/batch_write.rs b/tests/spanner/src/batch_write.rs new file mode 100644 index 0000000000..8fc252d52d --- /dev/null +++ b/tests/spanner/src/batch_write.rs @@ -0,0 +1,132 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::Result; +use google_cloud_gax::error::rpc::Code; +use google_cloud_spanner::client::MutationGroup; +use google_cloud_spanner::client::{DatabaseClient, Mutation, Statement}; +use google_cloud_test_utils::resource_names::LowercaseAlphanumeric; +use std::time::Duration; +use tokio::time::sleep; + +pub async fn batch_write(db_client: &DatabaseClient) -> Result<()> { + let id1 = format!("batch-write1-{}", LowercaseAlphanumeric.random_string(10)); + let id2 = format!("batch-write2-{}", LowercaseAlphanumeric.random_string(10)); + + let m1 = Mutation::new_insert_or_update_builder("AllTypes") + .set("Id") + .to(&id1) + .set("ColString") + .to(&"batch-write-1".to_string()) + .build(); + + let m2 = Mutation::new_insert_or_update_builder("AllTypes") + .set("Id") + .to(&id2) + .set("ColString") + .to(&"batch-write-2".to_string()) + .build(); + + let group1 = MutationGroup::new(vec![m1]); + let group2 = MutationGroup::new(vec![m2]); + + let mut attempts = 0; + const MAX_ATTEMPTS: u32 = 5; + + let mut seen_indexes = Vec::new(); + loop { + attempts += 1; + let transaction = db_client.batch_write_transaction().build(); + let mut stream = match transaction + .execute_streaming(vec![group1.clone(), group2.clone()]) + .await + { + Ok(s) => s, + Err(e) if e.status().map(|s| s.code) == Some(Code::Aborted) => { + if attempts >= MAX_ATTEMPTS { + anyhow::bail!( + "BatchWrite failed after {} attempts due to Aborted", + attempts + ); + } + sleep(Duration::from_millis(rand::random_range(10_u64..=50_u64))).await; + continue; + } + Err(e) => return Err(e.into()), + }; + + seen_indexes.clear(); + let mut aborted = false; + while let Some(response) = stream.next_message().await { + match response { + Ok(resp) => { + if let Some(status) = &resp.status { + if status.code == Code::Aborted as i32 { + aborted = true; + break; + } + assert_eq!( + status.code, + Code::Ok as i32, + "BatchWriteResponse status was not OK: {}", + status.message + ); + } + seen_indexes.extend(resp.indexes); + } + Err(e) if e.status().map(|s| s.code) == Some(Code::Aborted) => { + aborted = true; + break; + } + Err(e) => return Err(e.into()), + } + } + + if aborted { + if attempts >= MAX_ATTEMPTS { + anyhow::bail!( + "BatchWrite failed after {} attempts due to Aborted in stream", + attempts + ); + } + sleep(Duration::from_millis(rand::random_range(10_u64..=50_u64))).await; + continue; + } + + break; + } + + // Verify that all groups were applied. + assert!(seen_indexes.contains(&0)); + assert!(seen_indexes.contains(&1)); + + // Read back to verify + let read_tx = db_client.single_use().build(); + let stmt = + Statement::builder("SELECT ColString FROM AllTypes WHERE Id IN (@id1, @id2) ORDER BY Id") + .add_param("id1", &id1) + .add_param("id2", &id2) + .build(); + let mut rs = read_tx.execute_query(stmt).await?; + + let mut rows = Vec::new(); + while let Some(row) = rs.next().await { + rows.push(row?); + } + assert_eq!(rows.len(), 2, "Expected precisely 2 rows inserted/updated"); + assert_eq!(rows[0].get::("ColString"), "batch-write-1"); + assert_eq!(rows[1].get::("ColString"), "batch-write-2"); + + Ok(()) +} diff --git a/tests/spanner/src/lib.rs b/tests/spanner/src/lib.rs index 6e8370cc76..e6df86c948 100644 --- a/tests/spanner/src/lib.rs +++ b/tests/spanner/src/lib.rs @@ -13,6 +13,7 @@ // limitations under the License. pub mod batch_read_only_transaction; +pub mod batch_write; pub mod client; pub mod directed_read; pub mod partitioned_dml; diff --git a/tests/spanner/tests/driver.rs b/tests/spanner/tests/driver.rs index f56f3e8a96..8e1535e9a2 100644 --- a/tests/spanner/tests/driver.rs +++ b/tests/spanner/tests/driver.rs @@ -14,6 +14,8 @@ #[cfg(all(test, feature = "run-integration-tests"))] mod spanner { + use integration_tests_spanner::batch_write; + use integration_tests_spanner::client; #[tokio::test] async fn run_query_tests() -> anyhow::Result<()> { @@ -38,7 +40,7 @@ mod spanner { #[tokio::test] async fn run_write_tests() -> anyhow::Result<()> { - let db_client = match integration_tests_spanner::client::create_database_client().await { + let db_client = match client::create_database_client().await { Some(c) => c, None => return Ok(()), }; @@ -49,6 +51,18 @@ mod spanner { Ok(()) } + #[tokio::test] + async fn run_batch_write_tests() -> anyhow::Result<()> { + let db_client = match client::create_database_client().await { + Some(c) => c, + None => return Ok(()), + }; + + batch_write::batch_write(&db_client).await?; + + Ok(()) + } + #[tokio::test] async fn run_read_write_tests() -> anyhow::Result<()> { let db_client = match integration_tests_spanner::client::create_database_client().await {