Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions src/database/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ pub struct TransactionOptions {
}

/// Spawn database transaction
#[async_trait::async_trait]
#[allow(async_fn_in_trait)]
pub trait TransactionTrait {
/// The concrete type for the transaction
type Transaction: ConnectionTrait + TransactionTrait + TransactionSession;
Expand Down Expand Up @@ -193,10 +193,7 @@ pub trait TransactionTrait {
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c Self::Transaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
F: for<'c> AsyncFnOnce(&'c Self::Transaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send;

Expand All @@ -209,10 +206,7 @@ pub trait TransactionTrait {
access_mode: Option<AccessMode>,
) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c Self::Transaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
F: for<'c> AsyncFnOnce(&'c Self::Transaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send;
}
Expand Down
11 changes: 2 additions & 9 deletions src/database/db_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,6 @@ impl StreamTrait for DatabaseConnection {
}
}

#[async_trait::async_trait]
impl TransactionTrait for DatabaseConnection {
type Transaction = DatabaseTransaction;

Expand Down Expand Up @@ -453,10 +452,7 @@ impl TransactionTrait for DatabaseConnection {
#[instrument(level = "trace", skip(_callback))]
async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand Down Expand Up @@ -505,10 +501,7 @@ impl TransactionTrait for DatabaseConnection {
_access_mode: Option<AccessMode>,
) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand Down
13 changes: 2 additions & 11 deletions src/database/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ use crate::{
TransactionTrait,
};
use crate::{Schema, SchemaBuilder};
use std::future::Future;
use std::pin::Pin;

/// A wrapper that holds either a reference to a [`DatabaseConnection`] or [`DatabaseTransaction`],
/// or an owned [`DatabaseTransaction`].
Expand Down Expand Up @@ -74,7 +72,6 @@ impl ConnectionTrait for DatabaseExecutor<'_> {
}
}

#[async_trait::async_trait]
impl TransactionTrait for DatabaseExecutor<'_> {
type Transaction = DatabaseTransaction;

Expand Down Expand Up @@ -117,10 +114,7 @@ impl TransactionTrait for DatabaseExecutor<'_> {

async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand All @@ -138,10 +132,7 @@ impl TransactionTrait for DatabaseExecutor<'_> {
access_mode: Option<AccessMode>,
) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand Down
66 changes: 26 additions & 40 deletions src/database/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,13 +450,11 @@ mod tests {
async fn test_transaction_1() {
let db = MockDatabase::new(DbBackend::Postgres).into_connection();

db.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _1 = cake::Entity::find().one(txn).await;
let _2 = fruit::Entity::find().all(txn).await;
db.transaction::<_, (), DbErr>(async |txn| {
let _1 = cake::Entity::find().one(txn).await;
let _2 = fruit::Entity::find().all(txn).await;

Ok(())
})
Ok(())
})
.await
.unwrap();
Expand Down Expand Up @@ -494,11 +492,9 @@ mod tests {
let db = MockDatabase::new(DbBackend::Postgres).into_connection();

let result = db
.transaction::<_, (), MyErr>(|txn| {
Box::pin(async move {
let _ = cake::Entity::find().one(txn).await;
Err(MyErr("test".to_owned()))
})
.transaction::<_, (), MyErr>(async |txn| {
let _ = cake::Entity::find().one(txn).await;
Err(MyErr("test".to_owned()))
})
.await;

Expand Down Expand Up @@ -527,22 +523,18 @@ mod tests {
async fn test_nested_transaction_1() {
let db = MockDatabase::new(DbBackend::Postgres).into_connection();

db.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = cake::Entity::find().one(txn).await;
db.transaction::<_, (), DbErr>(async |txn| {
let _ = cake::Entity::find().one(txn).await;

txn.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = fruit::Entity::find().all(txn).await;

Ok(())
})
})
.await
.unwrap();
txn.transaction::<_, (), DbErr>(async |txn| {
let _ = fruit::Entity::find().all(txn).await;

Ok(())
})
.await
.unwrap();

Ok(())
})
.await
.unwrap();
Expand Down Expand Up @@ -572,32 +564,26 @@ mod tests {
async fn test_nested_transaction_2() {
let db = MockDatabase::new(DbBackend::Postgres).into_connection();

db.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = cake::Entity::find().one(txn).await;
db.transaction::<_, (), DbErr>(async |txn| {
let _ = cake::Entity::find().one(txn).await;

txn.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = fruit::Entity::find().all(txn).await;
txn.transaction::<_, (), DbErr>(async |txn| {
let _ = fruit::Entity::find().all(txn).await;

txn.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
let _ = cake::Entity::find().all(txn).await;
txn.transaction::<_, (), DbErr>(async |txn| {
let _ = cake::Entity::find().all(txn).await;

Ok(())
})
})
.await
.unwrap();

Ok(())
})
Ok(())
})
.await
.unwrap();

Ok(())
})
.await
.unwrap();

Ok(())
})
.await
.unwrap();
Expand Down
27 changes: 5 additions & 22 deletions src/database/restricted_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ impl RestrictedTransaction {
}
}

#[async_trait::async_trait]
impl TransactionTrait for RestrictedConnection {
type Transaction = RestrictedTransaction;

Expand Down Expand Up @@ -256,10 +255,7 @@ impl TransactionTrait for RestrictedConnection {
#[instrument(level = "trace", skip(callback))]
async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c RestrictedTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand All @@ -277,10 +273,7 @@ impl TransactionTrait for RestrictedConnection {
access_mode: Option<AccessMode>,
) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c RestrictedTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand All @@ -292,7 +285,6 @@ impl TransactionTrait for RestrictedConnection {
}
}

#[async_trait::async_trait]
impl TransactionTrait for RestrictedTransaction {
type Transaction = RestrictedTransaction;

Expand Down Expand Up @@ -338,10 +330,7 @@ impl TransactionTrait for RestrictedTransaction {
#[instrument(level = "trace", skip(callback))]
async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c RestrictedTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand All @@ -359,10 +348,7 @@ impl TransactionTrait for RestrictedTransaction {
access_mode: Option<AccessMode>,
) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c RestrictedTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
F: for<'c> AsyncFnOnce(&'c RestrictedTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand Down Expand Up @@ -391,10 +377,7 @@ impl RestrictedTransaction {
#[instrument(level = "trace", skip(callback))]
async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(
&'b RestrictedTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
+ Send,
F: for<'b> AsyncFnOnce(&'b RestrictedTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand Down
16 changes: 3 additions & 13 deletions src/database/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,7 @@ impl DatabaseTransaction {
#[instrument(level = "trace", skip(callback))]
pub(crate) async fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(
&'b DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
+ Send,
F: for<'b> AsyncFnOnce(&'b DatabaseTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand Down Expand Up @@ -607,7 +604,6 @@ impl StreamTrait for DatabaseTransaction {
}
}

#[async_trait::async_trait]
impl TransactionTrait for DatabaseTransaction {
type Transaction = DatabaseTransaction;

Expand Down Expand Up @@ -666,10 +662,7 @@ impl TransactionTrait for DatabaseTransaction {
#[instrument(level = "trace", skip(_callback))]
async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand All @@ -688,10 +681,7 @@ impl TransactionTrait for DatabaseTransaction {
access_mode: Option<AccessMode>,
) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
&'c DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
+ Send,
F: for<'c> AsyncFnOnce(&'c DatabaseTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand Down
5 changes: 1 addition & 4 deletions src/driver/sqlx_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,7 @@ impl SqlxMySqlPoolConnection {
access_mode: Option<AccessMode>,
) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(
&'b DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
+ Send,
F: for<'b> AsyncFnOnce(&'b DatabaseTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand Down
5 changes: 1 addition & 4 deletions src/driver/sqlx_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,7 @@ impl SqlxPostgresPoolConnection {
access_mode: Option<AccessMode>,
) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(
&'b DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
+ Send,
F: for<'b> AsyncFnOnce(&'b DatabaseTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand Down
5 changes: 1 addition & 4 deletions src/driver/sqlx_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,7 @@ impl SqlxSqlitePoolConnection {
access_mode: Option<AccessMode>,
) -> Result<T, TransactionError<E>>
where
F: for<'b> FnOnce(
&'b DatabaseTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
+ Send,
F: for<'b> AsyncFnOnce(&'b DatabaseTransaction) -> Result<T, E> + Send,
T: Send,
E: std::fmt::Display + std::fmt::Debug + Send,
{
Expand Down
Loading