diff --git a/src/database/connection.rs b/src/database/connection.rs index 096fc77c3..52736aa14 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -165,7 +165,7 @@ pub struct TransactionOptions { } /// Spawn database transaction -#[async_trait::async_trait] +#[allow(async_fn_in_trait)] pub trait TransactionTrait { /// The concrete type for the transaction type Transaction: ConnectionTrait + TransactionTrait + TransactionSession; @@ -193,10 +193,7 @@ pub trait TransactionTrait { /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&self, callback: F) -> Result> where - F: for<'c> FnOnce( - &'c Self::Transaction, - ) -> Pin> + Send + 'c>> - + Send, + F: for<'c> AsyncFnOnce(&'c Self::Transaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send; @@ -209,10 +206,7 @@ pub trait TransactionTrait { access_mode: Option, ) -> Result> where - F: for<'c> FnOnce( - &'c Self::Transaction, - ) -> Pin> + Send + 'c>> - + Send, + F: for<'c> AsyncFnOnce(&'c Self::Transaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send; } diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index f83fccd63..d6da98b83 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -343,7 +343,6 @@ impl StreamTrait for DatabaseConnection { } } -#[async_trait::async_trait] impl TransactionTrait for DatabaseConnection { type Transaction = DatabaseTransaction; @@ -453,10 +452,7 @@ impl TransactionTrait for DatabaseConnection { #[instrument(level = "trace", skip(_callback))] async fn transaction(&self, _callback: F) -> Result> where - F: for<'c> FnOnce( - &'c DatabaseTransaction, - ) -> Pin> + Send + 'c>> - + Send, + F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { @@ -505,10 +501,7 @@ impl TransactionTrait for DatabaseConnection { _access_mode: Option, ) -> Result> where - F: for<'c> FnOnce( - &'c DatabaseTransaction, - ) -> Pin> + Send + 'c>> - + Send, + F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { diff --git a/src/database/executor.rs b/src/database/executor.rs index 3505d84a3..705f45e74 100644 --- a/src/database/executor.rs +++ b/src/database/executor.rs @@ -4,8 +4,6 @@ use crate::{ TransactionTrait, }; use crate::{Schema, SchemaBuilder}; -use std::future::Future; -use std::pin::Pin; /// A wrapper that holds either a reference to a [`DatabaseConnection`] or [`DatabaseTransaction`], /// or an owned [`DatabaseTransaction`]. @@ -74,7 +72,6 @@ impl ConnectionTrait for DatabaseExecutor<'_> { } } -#[async_trait::async_trait] impl TransactionTrait for DatabaseExecutor<'_> { type Transaction = DatabaseTransaction; @@ -117,10 +114,7 @@ impl TransactionTrait for DatabaseExecutor<'_> { async fn transaction(&self, callback: F) -> Result> where - F: for<'c> FnOnce( - &'c DatabaseTransaction, - ) -> Pin> + Send + 'c>> - + Send, + F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { @@ -138,10 +132,7 @@ impl TransactionTrait for DatabaseExecutor<'_> { access_mode: Option, ) -> Result> where - F: for<'c> FnOnce( - &'c DatabaseTransaction, - ) -> Pin> + Send + 'c>> - + Send, + F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { diff --git a/src/database/mock.rs b/src/database/mock.rs index 263000bf2..7f9a874bd 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -450,13 +450,11 @@ mod tests { async fn test_transaction_1() { let db = MockDatabase::new(DbBackend::Postgres).into_connection(); - db.transaction::<_, (), DbErr>(|txn| { - Box::pin(async move { - let _1 = cake::Entity::find().one(txn).await; - let _2 = fruit::Entity::find().all(txn).await; + db.transaction::<_, (), DbErr>(async |txn| { + let _1 = cake::Entity::find().one(txn).await; + let _2 = fruit::Entity::find().all(txn).await; - Ok(()) - }) + Ok(()) }) .await .unwrap(); @@ -494,11 +492,9 @@ mod tests { let db = MockDatabase::new(DbBackend::Postgres).into_connection(); let result = db - .transaction::<_, (), MyErr>(|txn| { - Box::pin(async move { - let _ = cake::Entity::find().one(txn).await; - Err(MyErr("test".to_owned())) - }) + .transaction::<_, (), MyErr>(async |txn| { + let _ = cake::Entity::find().one(txn).await; + Err(MyErr("test".to_owned())) }) .await; @@ -527,22 +523,18 @@ mod tests { async fn test_nested_transaction_1() { let db = MockDatabase::new(DbBackend::Postgres).into_connection(); - db.transaction::<_, (), DbErr>(|txn| { - Box::pin(async move { - let _ = cake::Entity::find().one(txn).await; + db.transaction::<_, (), DbErr>(async |txn| { + let _ = cake::Entity::find().one(txn).await; - txn.transaction::<_, (), DbErr>(|txn| { - Box::pin(async move { - let _ = fruit::Entity::find().all(txn).await; - - Ok(()) - }) - }) - .await - .unwrap(); + txn.transaction::<_, (), DbErr>(async |txn| { + let _ = fruit::Entity::find().all(txn).await; Ok(()) }) + .await + .unwrap(); + + Ok(()) }) .await .unwrap(); @@ -572,32 +564,26 @@ mod tests { async fn test_nested_transaction_2() { let db = MockDatabase::new(DbBackend::Postgres).into_connection(); - db.transaction::<_, (), DbErr>(|txn| { - Box::pin(async move { - let _ = cake::Entity::find().one(txn).await; + db.transaction::<_, (), DbErr>(async |txn| { + let _ = cake::Entity::find().one(txn).await; - txn.transaction::<_, (), DbErr>(|txn| { - Box::pin(async move { - let _ = fruit::Entity::find().all(txn).await; + txn.transaction::<_, (), DbErr>(async |txn| { + let _ = fruit::Entity::find().all(txn).await; - txn.transaction::<_, (), DbErr>(|txn| { - Box::pin(async move { - let _ = cake::Entity::find().all(txn).await; + txn.transaction::<_, (), DbErr>(async |txn| { + let _ = cake::Entity::find().all(txn).await; - Ok(()) - }) - }) - .await - .unwrap(); - - Ok(()) - }) + Ok(()) }) .await .unwrap(); Ok(()) }) + .await + .unwrap(); + + Ok(()) }) .await .unwrap(); diff --git a/src/database/restricted_connection.rs b/src/database/restricted_connection.rs index 5d4cb5457..13e7c9da4 100644 --- a/src/database/restricted_connection.rs +++ b/src/database/restricted_connection.rs @@ -210,7 +210,6 @@ impl RestrictedTransaction { } } -#[async_trait::async_trait] impl TransactionTrait for RestrictedConnection { type Transaction = RestrictedTransaction; @@ -256,10 +255,7 @@ impl TransactionTrait for RestrictedConnection { #[instrument(level = "trace", skip(callback))] async fn transaction(&self, callback: F) -> Result> where - F: for<'c> FnOnce( - &'c RestrictedTransaction, - ) -> Pin> + Send + 'c>> - + Send, + F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { @@ -277,10 +273,7 @@ impl TransactionTrait for RestrictedConnection { access_mode: Option, ) -> Result> where - F: for<'c> FnOnce( - &'c RestrictedTransaction, - ) -> Pin> + Send + 'c>> - + Send, + F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { @@ -292,7 +285,6 @@ impl TransactionTrait for RestrictedConnection { } } -#[async_trait::async_trait] impl TransactionTrait for RestrictedTransaction { type Transaction = RestrictedTransaction; @@ -338,10 +330,7 @@ impl TransactionTrait for RestrictedTransaction { #[instrument(level = "trace", skip(callback))] async fn transaction(&self, callback: F) -> Result> where - F: for<'c> FnOnce( - &'c RestrictedTransaction, - ) -> Pin> + Send + 'c>> - + Send, + F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { @@ -359,10 +348,7 @@ impl TransactionTrait for RestrictedTransaction { access_mode: Option, ) -> Result> where - F: for<'c> FnOnce( - &'c RestrictedTransaction, - ) -> Pin> + Send + 'c>> - + Send, + F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { @@ -391,10 +377,7 @@ impl RestrictedTransaction { #[instrument(level = "trace", skip(callback))] async fn run(self, callback: F) -> Result> where - F: for<'b> FnOnce( - &'b RestrictedTransaction, - ) -> Pin> + Send + 'b>> - + Send, + F: for<'b> AsyncFnOnce(&'b RestrictedTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { diff --git a/src/database/transaction.rs b/src/database/transaction.rs index d1be4128b..4b68f4400 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -137,10 +137,7 @@ impl DatabaseTransaction { #[instrument(level = "trace", skip(callback))] pub(crate) async fn run(self, callback: F) -> Result> where - F: for<'b> FnOnce( - &'b DatabaseTransaction, - ) -> Pin> + Send + 'b>> - + Send, + F: for<'b> AsyncFnOnce(&'b DatabaseTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { @@ -607,7 +604,6 @@ impl StreamTrait for DatabaseTransaction { } } -#[async_trait::async_trait] impl TransactionTrait for DatabaseTransaction { type Transaction = DatabaseTransaction; @@ -666,10 +662,7 @@ impl TransactionTrait for DatabaseTransaction { #[instrument(level = "trace", skip(_callback))] async fn transaction(&self, _callback: F) -> Result> where - F: for<'c> FnOnce( - &'c DatabaseTransaction, - ) -> Pin> + Send + 'c>> - + Send, + F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { @@ -688,10 +681,7 @@ impl TransactionTrait for DatabaseTransaction { access_mode: Option, ) -> Result> where - F: for<'c> FnOnce( - &'c DatabaseTransaction, - ) -> Pin> + Send + 'c>> - + Send, + F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index ea9c5340d..7d3982c17 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -229,10 +229,7 @@ impl SqlxMySqlPoolConnection { access_mode: Option, ) -> Result> where - F: for<'b> FnOnce( - &'b DatabaseTransaction, - ) -> Pin> + Send + 'b>> - + Send, + F: for<'b> AsyncFnOnce(&'b DatabaseTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index e0f3e6d62..7d28bb565 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -267,10 +267,7 @@ impl SqlxPostgresPoolConnection { access_mode: Option, ) -> Result> where - F: for<'b> FnOnce( - &'b DatabaseTransaction, - ) -> Pin> + Send + 'b>> - + Send, + F: for<'b> AsyncFnOnce(&'b DatabaseTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index b0fc51008..b2b7bbafb 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -249,10 +249,7 @@ impl SqlxSqlitePoolConnection { access_mode: Option, ) -> Result> where - F: for<'b> FnOnce( - &'b DatabaseTransaction, - ) -> Pin> + Send + 'b>> - + Send, + F: for<'b> AsyncFnOnce(&'b DatabaseTransaction) -> Result + Send, T: Send, E: std::fmt::Display + std::fmt::Debug + Send, { diff --git a/tests/rbac_tests.rs b/tests/rbac_tests.rs index 58b0f1e5b..969f0f2de 100644 --- a/tests/rbac_tests.rs +++ b/tests/rbac_tests.rs @@ -169,20 +169,18 @@ async fn crud_tests(db: &DbConn) -> Result<(), DbErr> { .await .expect("insert succeeds"); - db.transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - cake::Entity::insert(cake::ActiveModel { - name: Set("Chocolate".to_owned()), - price: Set(3.into()), - bakery_id: Set(Some(1)), - gluten_free: Set(true), - ..Default::default() - }) - .exec(txn) - .await?; - - Ok(()) + db.transaction::<_, _, DbErr>(async |txn| { + cake::Entity::insert(cake::ActiveModel { + name: Set("Chocolate".to_owned()), + price: Set(3.into()), + bakery_id: Set(Some(1)), + gluten_free: Set(true), + ..Default::default() }) + .exec(txn) + .await?; + + Ok(()) }) .await .expect("insert succeeds"); diff --git a/tests/transaction_tests.rs b/tests/transaction_tests.rs index b90bf4c7d..f02944e0b 100644 --- a/tests/transaction_tests.rs +++ b/tests/transaction_tests.rs @@ -9,12 +9,6 @@ use sea_orm::{ TransactionOptions, TransactionTrait, prelude::*, }; -#[cfg(not(feature = "sync"))] -type FutureResult<'a> = - std::pin::Pin> + Send + 'a>>; -#[cfg(feature = "sync")] -type FutureResult<'a> = Result<(), DbErr>; - fn seaside_bakery() -> bakery::ActiveModel { bakery::ActiveModel { name: Set("SeaSide Bakery".to_owned()), @@ -37,20 +31,18 @@ pub async fn transaction() { create_bakery_table(&ctx.db).await.unwrap(); ctx.db - .transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - let _ = seaside_bakery().save(txn).await?; - let _ = top_bakery().save(txn).await?; + .transaction::<_, _, DbErr>(async |txn| { + let _ = seaside_bakery().save(txn).await?; + let _ = top_bakery().save(txn).await?; - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; - assert_eq!(bakeries.len(), 2); + assert_eq!(bakeries.len(), 2); - Ok(()) - }) + Ok(()) }) .await .unwrap(); @@ -71,20 +63,18 @@ pub async fn rbac_transaction() { )); let db = ctx.db.restricted_for(RbacUserId(0)).unwrap(); - db.transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - let _ = seaside_bakery().save(txn).await?; - let _ = top_bakery().save(txn).await?; + db.transaction::<_, _, DbErr>(async |txn| { + let _ = seaside_bakery().save(txn).await?; + let _ = top_bakery().save(txn).await?; - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; - assert_eq!(bakeries.len(), 2); + assert_eq!(bakeries.len(), 2); - Ok(()) - }) + Ok(()) }) .await .unwrap(); @@ -101,45 +91,43 @@ pub async fn transaction_with_reference() { let name2 = "Top Bakery"; let search_name = "Bakery"; ctx.db - .transaction(|txn| _transaction_with_reference(txn, name1, name2, search_name)) + .transaction(async |txn| _transaction_with_reference(txn, name1, name2, search_name).await) .await .unwrap(); ctx.delete().await; } -fn _transaction_with_reference<'a>( - txn: &'a DatabaseTransaction, - name1: &'a str, - name2: &'a str, - search_name: &'a str, -) -> FutureResult<'a> { - Box::pin(async move { - let _ = bakery::ActiveModel { - name: Set(name1.to_owned()), - profit_margin: Set(10.4), - ..Default::default() - } - .save(txn) - .await?; +async fn _transaction_with_reference( + txn: &DatabaseTransaction, + name1: &str, + name2: &str, + search_name: &str, +) -> Result<(), DbErr> { + let _ = bakery::ActiveModel { + name: Set(name1.to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(txn) + .await?; - let _ = bakery::ActiveModel { - name: Set(name2.to_owned()), - profit_margin: Set(15.0), - ..Default::default() - } - .save(txn) - .await?; + let _ = bakery::ActiveModel { + name: Set(name2.to_owned()), + profit_margin: Set(15.0), + ..Default::default() + } + .save(txn) + .await?; - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains(search_name)) - .all(txn) - .await?; + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains(search_name)) + .all(txn) + .await?; - assert_eq!(bakeries.len(), 2); + assert_eq!(bakeries.len(), 2); - Ok(()) - }) + Ok(()) } #[sea_orm_macros::test] @@ -310,18 +298,16 @@ pub async fn transaction_closure_commit() -> Result<(), DbErr> { let res = ctx .db - .transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - seaside_bakery().save(txn).await?; + .transaction::<_, _, DbErr>(async |txn| { + seaside_bakery().save(txn).await?; - assert_eq!(bakery::Entity::find().all(txn).await?.len(), 1); + assert_eq!(bakery::Entity::find().all(txn).await?.len(), 1); - top_bakery().save(txn).await?; + top_bakery().save(txn).await?; - assert_eq!(bakery::Entity::find().all(txn).await?.len(), 2); + assert_eq!(bakery::Entity::find().all(txn).await?.len(), 2); - Ok(()) - }) + Ok(()) }) .await; @@ -342,30 +328,28 @@ pub async fn transaction_closure_rollback() -> Result<(), DbErr> { let res = ctx .db - .transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - seaside_bakery().save(txn).await?; + .transaction::<_, _, DbErr>(async |txn| { + seaside_bakery().save(txn).await?; - assert_eq!(bakery::Entity::find().all(txn).await?.len(), 1); + assert_eq!(bakery::Entity::find().all(txn).await?.len(), 1); - top_bakery().save(txn).await?; + top_bakery().save(txn).await?; - assert_eq!(bakery::Entity::find().all(txn).await?.len(), 2); + assert_eq!(bakery::Entity::find().all(txn).await?.len(), 2); - bakery::ActiveModel { - id: Set(1), - name: Set("Duplicated primary key".to_owned()), - profit_margin: Set(20.0), - } - .insert(txn) - .await?; // Throw error and rollback + bakery::ActiveModel { + id: Set(1), + name: Set("Duplicated primary key".to_owned()), + profit_margin: Set(20.0), + } + .insert(txn) + .await?; // Throw error and rollback - // This line won't be reached - unreachable!(); + // This line won't be reached + unreachable!(); - #[allow(unreachable_code)] - Ok(()) - }) + #[allow(unreachable_code)] + Ok(()) }) .await; @@ -467,18 +451,34 @@ pub async fn transaction_nested() { create_bakery_table(&ctx.db).await.unwrap(); ctx.db - .transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - let _ = seaside_bakery().save(txn).await?; + .transaction::<_, _, DbErr>(async |txn| { + let _ = seaside_bakery().save(txn).await?; + + let _ = top_bakery().save(txn).await?; + + // Try nested transaction committed + txn.transaction::<_, _, DbErr>(async |txn| { + let _ = bakery::ActiveModel { + name: Set("Nested Bakery".to_owned()), + profit_margin: Set(88.88), + ..Default::default() + } + .save(txn) + .await?; - let _ = top_bakery().save(txn).await?; + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 3); - // Try nested transaction committed - txn.transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { + // Try nested-nested transaction rollbacked + let is_err = txn + .transaction::<_, _, DbErr>(async |txn| { let _ = bakery::ActiveModel { - name: Set("Nested Bakery".to_owned()), - profit_margin: Set(88.88), + name: Set("Rock n Roll Bakery".to_owned()), + profit_margin: Set(28.8), ..Default::default() } .save(txn) @@ -489,89 +489,112 @@ pub async fn transaction_nested() { .all(txn) .await?; - assert_eq!(bakeries.len(), 3); - - // Try nested-nested transaction rollbacked - let is_err = txn - .transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - let _ = bakery::ActiveModel { - name: Set("Rock n Roll Bakery".to_owned()), - profit_margin: Set(28.8), - ..Default::default() - } - .save(txn) - .await?; - - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; - - assert_eq!(bakeries.len(), 4); - - if true { - Err(DbErr::Query(RuntimeErr::Internal( - "Force Rollback!".to_owned(), - ))) - } else { - Ok(()) - } - }) - }) - .await - .is_err(); - - assert!(is_err); + assert_eq!(bakeries.len(), 4); + + if true { + Err(DbErr::Query(RuntimeErr::Internal( + "Force Rollback!".to_owned(), + ))) + } else { + Ok(()) + } + }) + .await + .is_err(); - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; + assert!(is_err); - assert_eq!(bakeries.len(), 3); - - // Try nested-nested transaction committed - txn.transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - let _ = bakery::ActiveModel { - name: Set("Rock n Roll Bakery".to_owned()), - profit_margin: Set(28.8), - ..Default::default() - } - .save(txn) - .await?; + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 3); + + // Try nested-nested transaction committed + txn.transaction::<_, _, DbErr>(async |txn| { + let _ = bakery::ActiveModel { + name: Set("Rock n Roll Bakery".to_owned()), + profit_margin: Set(28.8), + ..Default::default() + } + .save(txn) + .await?; - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; - assert_eq!(bakeries.len(), 4); + assert_eq!(bakeries.len(), 4); - Ok(()) - }) - }) - .await - .unwrap(); + Ok(()) + }) + .await + .unwrap(); + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 4); + + Ok(()) + }) + .await + .unwrap(); + + // Try nested transaction rollbacked + let is_err = txn + .transaction::<_, _, DbErr>(async |txn| { + let _ = bakery::ActiveModel { + name: Set("Rock n Roll Bakery".to_owned()), + profit_margin: Set(28.8), + ..Default::default() + } + .save(txn) + .await?; + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 5); + + // Try nested-nested transaction committed + txn.transaction::<_, _, DbErr>(async |txn| { + let _ = bakery::ActiveModel { + name: Set("Rock n Roll Bakery".to_owned()), + profit_margin: Set(28.8), + ..Default::default() + } + .save(txn) + .await?; let bakeries = Bakery::find() .filter(bakery::Column::Name.contains("Bakery")) .all(txn) .await?; - assert_eq!(bakeries.len(), 4); + assert_eq!(bakeries.len(), 6); Ok(()) }) - }) - .await - .unwrap(); + .await + .unwrap(); - // Try nested transaction rollbacked - let is_err = txn - .transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 6); + + // Try nested-nested transaction rollbacked + let is_err = txn + .transaction::<_, _, DbErr>(async |txn| { let _ = bakery::ActiveModel { name: Set("Rock n Roll Bakery".to_owned()), profit_margin: Set(28.8), @@ -585,78 +608,7 @@ pub async fn transaction_nested() { .all(txn) .await?; - assert_eq!(bakeries.len(), 5); - - // Try nested-nested transaction committed - txn.transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - let _ = bakery::ActiveModel { - name: Set("Rock n Roll Bakery".to_owned()), - profit_margin: Set(28.8), - ..Default::default() - } - .save(txn) - .await?; - - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; - - assert_eq!(bakeries.len(), 6); - - Ok(()) - }) - }) - .await - .unwrap(); - - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; - - assert_eq!(bakeries.len(), 6); - - // Try nested-nested transaction rollbacked - let is_err = txn - .transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - let _ = bakery::ActiveModel { - name: Set("Rock n Roll Bakery".to_owned()), - profit_margin: Set(28.8), - ..Default::default() - } - .save(txn) - .await?; - - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; - - assert_eq!(bakeries.len(), 7); - - if true { - Err(DbErr::Query(RuntimeErr::Internal( - "Force Rollback!".to_owned(), - ))) - } else { - Ok(()) - } - }) - }) - .await - .is_err(); - - assert!(is_err); - - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; - - assert_eq!(bakeries.len(), 6); + assert_eq!(bakeries.len(), 7); if true { Err(DbErr::Query(RuntimeErr::Internal( @@ -666,21 +618,39 @@ pub async fn transaction_nested() { Ok(()) } }) - }) - .await - .is_err(); + .await + .is_err(); - assert!(is_err); + assert!(is_err); - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; - assert_eq!(bakeries.len(), 4); + assert_eq!(bakeries.len(), 6); - Ok(()) - }) + if true { + Err(DbErr::Query(RuntimeErr::Internal( + "Force Rollback!".to_owned(), + ))) + } else { + Ok(()) + } + }) + .await + .is_err(); + + assert!(is_err); + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 4); + + Ok(()) }) .await .unwrap(); @@ -853,63 +823,46 @@ pub async fn rbac_transaction_nested() { )); let db = ctx.db.restricted_for(RbacUserId(0)).unwrap(); - db.transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - let _ = seaside_bakery().save(txn).await?; + db.transaction::<_, _, DbErr>(async |txn| { + let _ = seaside_bakery().save(txn).await?; - let _ = top_bakery().save(txn).await?; + let _ = top_bakery().save(txn).await?; - // Try nested transaction committed - txn.transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - let _ = bakery::ActiveModel { - name: Set("Nested Bakery".to_owned()), - profit_margin: Set(88.88), - ..Default::default() - } - .save(txn) - .await?; - - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; - - assert_eq!(bakeries.len(), 3); - - // Try nested-nested transaction committed - txn.transaction::<_, _, DbErr>(|txn| { - Box::pin(async move { - let _ = bakery::ActiveModel { - name: Set("Rock n Roll Bakery".to_owned()), - profit_margin: Set(28.8), - ..Default::default() - } - .save(txn) - .await?; + // Try nested transaction committed + txn.transaction::<_, _, DbErr>(async |txn| { + let _ = bakery::ActiveModel { + name: Set("Nested Bakery".to_owned()), + profit_margin: Set(88.88), + ..Default::default() + } + .save(txn) + .await?; - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; - assert_eq!(bakeries.len(), 4); + assert_eq!(bakeries.len(), 3); - Ok(()) - }) - }) - .await - .unwrap(); + // Try nested-nested transaction committed + txn.transaction::<_, _, DbErr>(async |txn| { + let _ = bakery::ActiveModel { + name: Set("Rock n Roll Bakery".to_owned()), + profit_margin: Set(28.8), + ..Default::default() + } + .save(txn) + .await?; - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; - assert_eq!(bakeries.len(), 4); + assert_eq!(bakeries.len(), 4); - Ok(()) - }) + Ok(()) }) .await .unwrap(); @@ -923,6 +876,17 @@ pub async fn rbac_transaction_nested() { Ok(()) }) + .await + .unwrap(); + + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; + + assert_eq!(bakeries.len(), 4); + + Ok(()) }) .await .unwrap(); @@ -957,7 +921,7 @@ pub async fn transaction_with_config() { let search_name = format!("Bakery {}", i); ctx.db .transaction_with_config( - |txn| _transaction_with_config(txn, name1, name2, search_name), + async |txn| _transaction_with_config(txn, name1, name2, search_name).await, Some(isolation_level), access_mode, ) @@ -967,17 +931,15 @@ pub async fn transaction_with_config() { ctx.db .transaction_with_config::<_, _, DbErr>( - |txn| { - Box::pin(async move { - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains("Bakery")) - .all(txn) - .await?; + async |txn| { + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains("Bakery")) + .all(txn) + .await?; - assert_eq!(bakeries.len(), 8); + assert_eq!(bakeries.len(), 8); - Ok(()) - }) + Ok(()) }, None, Some(AccessMode::ReadOnly), @@ -988,38 +950,36 @@ pub async fn transaction_with_config() { ctx.delete().await; } -fn _transaction_with_config<'a>( - txn: &'a DatabaseTransaction, +async fn _transaction_with_config( + txn: &DatabaseTransaction, name1: String, name2: String, search_name: String, -) -> FutureResult<'a> { - Box::pin(async move { - let _ = bakery::ActiveModel { - name: Set(name1), - profit_margin: Set(10.4), - ..Default::default() - } - .save(txn) - .await?; +) -> Result<(), DbErr> { + let _ = bakery::ActiveModel { + name: Set(name1), + profit_margin: Set(10.4), + ..Default::default() + } + .save(txn) + .await?; - let _ = bakery::ActiveModel { - name: Set(name2), - profit_margin: Set(15.0), - ..Default::default() - } - .save(txn) - .await?; + let _ = bakery::ActiveModel { + name: Set(name2), + profit_margin: Set(15.0), + ..Default::default() + } + .save(txn) + .await?; - let bakeries = Bakery::find() - .filter(bakery::Column::Name.contains(&search_name)) - .all(txn) - .await?; + let bakeries = Bakery::find() + .filter(bakery::Column::Name.contains(&search_name)) + .all(txn) + .await?; - assert_eq!(bakeries.len(), 2); + assert_eq!(bakeries.len(), 2); - Ok(()) - }) + Ok(()) } #[sea_orm_macros::test]