diff --git a/quantara/web_app/db/crud/airdrop.py b/quantara/web_app/db/crud/airdrop.py index 0473cfb08..30c77c35a 100644 --- a/quantara/web_app/db/crud/airdrop.py +++ b/quantara/web_app/db/crud/airdrop.py @@ -18,23 +18,28 @@ ModelType = TypeVar("ModelType", bound=Base) -class AirDropDBConnector(DBConnector): +class AirDropDBConnector: """ Provides database connection and operations management for the AirDrop model. """ + def __init__(self, db_connector: DBConnector = None): + from web_app.db.database import db_connector as default_db_connector + + self.db_connector = db_connector or default_db_connector + def save_claim_data(self, airdrop_id: uuid.UUID, amount: Decimal) -> None: """ Updates the AirDrop instance with claim data. :param airdrop_id: uuid.UUID :param amount: Decimal """ - airdrop = self.get_object(AirDrop, airdrop_id) + airdrop = self.db_connector.get_object(AirDrop, airdrop_id) if airdrop: airdrop.amount = amount airdrop.is_claimed = True airdrop.claimed_at = datetime.now() - self.write_to_db(airdrop) + self.db_connector.write_to_db(airdrop) else: logger.error(f"AirDrop with ID {airdrop_id} not found") @@ -44,7 +49,7 @@ def get_all_unclaimed(self) -> list[AirDrop]: :return: List of unclaimed AirDrop instances """ - with self.Session() as db: + with self.db_connector.Session() as db: try: unclaimed_instances = ( db.query(AirDrop).filter_by(is_claimed=False).all() @@ -61,7 +66,7 @@ def delete_all_users_airdrop(self, user_id: uuid.UUID) -> None: Delete all airdrops for a user. :param user_id: User ID """ - with self.Session() as db: + with self.db_connector.Session() as db: try: db.query(AirDrop).filter_by(user_id=user_id).delete( synchronize_session=False diff --git a/quantara/web_app/db/crud/base.py b/quantara/web_app/db/crud/base.py index 705852b9b..c50873dcb 100644 --- a/quantara/web_app/db/crud/base.py +++ b/quantara/web_app/db/crud/base.py @@ -28,12 +28,13 @@ class DBConnector: - remove_object: Removes an object by its ID from the database. """ - def __init__(self, db_url: str = SQLALCHEMY_DATABASE_URL): + def __init__(self, db_url: str = SQLALCHEMY_DATABASE_URL, engine=None): """ Initialize the database connection and session factory. :param db_url: str = None + :param engine: sqlalchemy.engine.Engine = None """ - self.engine = create_engine(db_url) + self.engine = engine or create_engine(db_url) self.session_factory = sessionmaker(bind=self.engine) self.Session = scoped_session(self.session_factory) diff --git a/quantara/web_app/db/crud/deposit.py b/quantara/web_app/db/crud/deposit.py index c8cb361ae..309004ef6 100644 --- a/quantara/web_app/db/crud/deposit.py +++ b/quantara/web_app/db/crud/deposit.py @@ -14,11 +14,16 @@ ModelType = TypeVar("ModelType", bound=Base) -class DepositDBConnector(DBConnector): +class DepositDBConnector: """ Provides database connection and operations management for the Vault model. """ + def __init__(self, db_connector: DBConnector = None): + from web_app.db.database import db_connector as default_db_connector + + self.db_connector = db_connector or default_db_connector + def create_vault(self, user: User, symbol: str, amount: str) -> Vault: """ Creates a new vault instance @@ -30,7 +35,7 @@ def create_vault(self, user: User, symbol: str, amount: str) -> Vault: :return: Vault """ vault = Vault(user_id=user.id, symbol=symbol, amount=amount) - self.write_to_db(vault) + self.db_connector.write_to_db(vault) return vault def get_vault(self, wallet_id: str, symbol: str) -> Vault | None: @@ -42,8 +47,8 @@ def get_vault(self, wallet_id: str, symbol: str) -> Vault | None: :return: Vault or None """ - with self.Session() as db: - user = self.get_object_by_field(User, "wallet_id", wallet_id) + with self.db_connector.Session() as db: + user = self.db_connector.get_object_by_field(User, "wallet_id", wallet_id) if not user: logger.error(f"User with wallet id {wallet_id} not found") return None @@ -63,7 +68,7 @@ def add_vault_balance(self, wallet_id: str, symbol: str, amount: str) -> Vault: vault = self.get_vault(wallet_id, symbol) if not vault: raise ValueError("Vault not found") - with self.Session() as db: + with self.db_connector.Session() as db: new_amount = Decimal(vault.amount) + Decimal(amount) db.query(Vault).filter_by(id=vault.id).update(amount=str(new_amount)) db.commit() diff --git a/quantara/web_app/db/crud/leaderboard.py b/quantara/web_app/db/crud/leaderboard.py index d12cdce0c..b27b2bf92 100644 --- a/quantara/web_app/db/crud/leaderboard.py +++ b/quantara/web_app/db/crud/leaderboard.py @@ -13,18 +13,23 @@ logger = logging.getLogger(__name__) -class LeaderboardDBConnector(DBConnector): +class LeaderboardDBConnector: """ Provides database connection and operations management using SQLAlchemy in a FastAPI application context. """ + def __init__(self, db_connector: DBConnector = None): + from web_app.db.database import db_connector as default_db_connector + + self.db_connector = db_connector or default_db_connector + def get_top_users_by_positions(self) -> list[dict]: """ Retrieves the top 10 users ordered by closed/opened positions. :return: List of dictionaries containing wallet_id and positions_number. """ - with self.Session() as db: + with self.db_connector.Session() as db: try: results = ( db.query( @@ -53,7 +58,7 @@ def get_position_token_statistics(self) -> list[dict]: Retrieves closed/opened positions groupped by token_symbol. :return: List of dictionaries containing token_symbol and total_positions. """ - with self.Session() as db: + with self.db_connector.Session() as db: try: results = ( db.query( diff --git a/quantara/web_app/db/crud/position.py b/quantara/web_app/db/crud/position.py index cc10277a7..368f09c4b 100644 --- a/quantara/web_app/db/crud/position.py +++ b/quantara/web_app/db/crud/position.py @@ -15,17 +15,28 @@ from web_app.db.models import Base, ExtraDeposit, Position, Status, Transaction, User +from .base import DBConnector from .user import UserDBConnector logger = logging.getLogger(__name__) ModelType = TypeVar("ModelType", bound=Base) -class PositionDBConnector(UserDBConnector): +class PositionDBConnector: """ Provides database connection and operations management for the Position model. """ + def __init__( + self, + db_connector: DBConnector = None, + user_db_connector: UserDBConnector = None, + ): + from web_app.db.database import db_connector as default_db_connector + + self.db_connector = db_connector or default_db_connector + self.user_db_connector = user_db_connector or UserDBConnector(self.db_connector) + START_PRICE = 0.0 @staticmethod @@ -63,7 +74,7 @@ def _get_user_by_wallet_id(self, wallet_id: str) -> User | None: :param wallet_id: str :return: User | None """ - return self.get_user_by_wallet_id(wallet_id) + return self.user_db_connector.get_user_by_wallet_id(wallet_id) def get_positions_by_wallet_id( self, wallet_id: str, start: int = 0, limit: int = 10 @@ -76,7 +87,7 @@ def get_positions_by_wallet_id( :param limit: number of records to return :return: list of dict """ - with self.Session() as db: + with self.db_connector.Session() as db: user = self._get_user_by_wallet_id(wallet_id) if not user: return [] @@ -115,7 +126,7 @@ def get_all_positions_by_wallet_id( :param limit: number of records to return :return: list of dict """ - with self.Session() as db: + with self.db_connector.Session() as db: user = self._get_user_by_wallet_id(wallet_id) if not user: return [] @@ -146,7 +157,7 @@ def get_count_positions_by_wallet_id(self, wallet_id: str) -> int: :param wallet_id: Wallet ID of the user :return: Total number of positions """ - with self.Session() as db: + with self.db_connector.Session() as db: user = self._get_user_by_wallet_id(wallet_id) if not user: return 0 @@ -169,7 +180,7 @@ def has_opened_position(self, wallet_id: str) -> bool: :param wallet_id: str :return: bool """ - with self.Session() as db: + with self.db_connector.Session() as db: user = self._get_user_by_wallet_id(wallet_id) if not user: return False @@ -207,7 +218,7 @@ def create_position( return None # Check if a position with status 'pending' already exists for this user - with self.Session() as session: + with self.db_connector.Session() as session: existing_position = ( session.query(Position) .filter( @@ -235,7 +246,7 @@ def create_position( start_price=PositionDBConnector.START_PRICE, ) - position = self.write_to_db(position) + position = self.db_connector.write_to_db(position) return position def get_position_id_by_wallet_id(self, wallet_id: str) -> str | None: @@ -259,7 +270,7 @@ def update_position(self, position: Position, amount: str, multiplier: int) -> N """ position.amount = amount position.multiplier = multiplier - self.write_to_db(position) + self.db_connector.write_to_db(position) def delete_position(self, position: Position) -> None: """ @@ -267,7 +278,7 @@ def delete_position(self, position: Position) -> None: :param position: Position :return: None """ - self.delete_object_by_id(Position, position.id) + self.db_connector.delete_object_by_id(Position, position.id) def close_position(self, position_id: uuid) -> Position | None: """ @@ -275,11 +286,11 @@ def close_position(self, position_id: uuid) -> Position | None: :param position_id: str :return: Position | None """ - position = self.get_object(Position, position_id) + position = self.db_connector.get_object(Position, position_id) if position: position.status = Status.CLOSED.value position.closed_at = datetime.now() - self.write_to_db(position) + self.db_connector.write_to_db(position) return position.status def open_position(self, position_id: uuid.UUID, current_prices: dict) -> str | None: @@ -289,11 +300,11 @@ def open_position(self, position_id: uuid.UUID, current_prices: dict) -> str | N :param current_prices: dict :return: str | None """ - position = self.get_object(Position, position_id) + position = self.db_connector.get_object(Position, position_id) if position: position.status = Status.OPENED.value - self.write_to_db(position) - self.create_empty_claim(position.user_id) + self.db_connector.write_to_db(position) + self.db_connector.create_empty_claim(position.user_id) self.save_current_price(position, current_prices) return position.status else: @@ -306,7 +317,7 @@ def get_repay_data(self, wallet_id: str) -> tuple: :param wallet_id: :return: contract_address, position_id, token_symbol """ - with self.Session() as db: + with self.db_connector.Session() as db: result = ( db.query(User.contract_address, Position.id, Position.token_symbol) .join(Position, Position.user_id == User.id) @@ -329,7 +340,7 @@ def get_total_amounts_for_open_positions(self) -> dict[str, Decimal]: :return: Dictionary of total amounts for each token in opened positions """ - with self.Session() as db: + with self.db_connector.Session() as db: try: token_amounts = ( db.query( @@ -355,7 +366,7 @@ def save_current_price(self, position: Position, price_dict: dict) -> None: start_price = price_dict.get(position.token_symbol) try: position.start_price = start_price - self.write_to_db(position) + self.db_connector.write_to_db(position) except SQLAlchemyError as e: logger.error(f"Error while saving current_price for position: {e}") @@ -379,7 +390,7 @@ def save_transaction( status=status, transaction_hash=transaction_hash, ) - return self.write_to_db(transaction) + return self.db_connector.write_to_db(transaction) except SQLAlchemyError as e: logger.error(f"Failed to save transaction: {str(e)}") return None @@ -392,7 +403,7 @@ def liquidate_position(self, position_id: UUID) -> bool: :param position_id: UUID of the position to be liquidated. :return: True if the update was successful, False otherwise. """ - with self.Session() as db: + with self.db_connector.Session() as db: try: position = db.query(Position).filter(Position.id == position_id).first() @@ -404,7 +415,7 @@ def liquidate_position(self, position_id: UUID) -> bool: position.is_liquidated = True position.datetime_liquidation = datetime.now() - self.write_to_db(position) + self.db_connector.write_to_db(position) logger.info(f"Position {position_id} successfully liquidated.") return True @@ -419,7 +430,7 @@ def get_all_liquidated_positions(self) -> list[dict]: :return: A list of dictionaries containing the liquidated positions. """ - with self.Session() as db: + with self.db_connector.Session() as db: try: liquidated_positions = ( db.query(Position) @@ -454,14 +465,14 @@ def get_position_by_id(self, position_id: UUID) -> Position | None: :param position_id: Position UUID :return: Position | None """ - return self.get_object(Position, position_id) + return self.db_connector.get_object(Position, position_id) def delete_all_user_positions(self, user_id: uuid.UUID) -> None: """ Deletes all positions for a user. :param user_id: User ID """ - with self.Session() as db: + with self.db_connector.Session() as db: try: positions = db.query(Position).filter_by(user_id=user_id).all() for position in positions: @@ -478,7 +489,7 @@ def add_extra_deposit_to_position( If the token already exists for this position, update its amount. Otherwise, create a new extra deposit entry. """ - with self.Session() as session: + with self.db_connector.Session() as session: session.execute( insert(ExtraDeposit) .values( @@ -501,7 +512,7 @@ def get_extra_deposits_data(self, position_id: UUID) -> dict[str, str]: :param position_id: UUID of the position :return: a dictionary of token_symbol: amount pairs. """ - with self.Session() as db: + with self.db_connector.Session() as db: deposits = ( db.query(ExtraDeposit) .filter(ExtraDeposit.position_id == position_id) @@ -516,7 +527,7 @@ def get_extra_deposits_by_position_id(self, position_id: UUID) -> list[ExtraDepo :param position_id: UUID of the position :return: list of extra deposits """ - with self.Session() as db: + with self.db_connector.Session() as db: extra_deposits = ( db.query(ExtraDeposit) .filter(ExtraDeposit.position_id == position_id) @@ -529,6 +540,6 @@ def delete_all_extra_deposits(self, position_id: UUID) -> None: """ Delete all extra deposits for a position. """ - with self.Session() as db: + with self.db_connector.Session() as db: db.query(ExtraDeposit).filter(ExtraDeposit.position_id == position_id).delete() db.commit() diff --git a/quantara/web_app/db/crud/telegram.py b/quantara/web_app/db/crud/telegram.py index b62ab2e9c..b5e580b3a 100644 --- a/quantara/web_app/db/crud/telegram.py +++ b/quantara/web_app/db/crud/telegram.py @@ -15,18 +15,23 @@ ModelType = TypeVar("ModelType", bound=Base) -class TelegramUserDBConnector(DBConnector): +class TelegramUserDBConnector: """ Provides database connection and operations management for the TelegramUser model. """ + def __init__(self, db_connector: DBConnector = None): + from web_app.db.database import db_connector as default_db_connector + + self.db_connector = db_connector or default_db_connector + def get_telegram_user_by_wallet_id(self, wallet_id: str) -> TelegramUser | None: """ Retrieves a TelegramUser by their wallet ID. :param wallet_id: str :return: TelegramUser | None """ - return self.get_object_by_field(TelegramUser, "wallet_id", wallet_id) + return self.db_connector.get_object_by_field(TelegramUser, "wallet_id", wallet_id) def get_user_by_telegram_id(self, telegram_id: str) -> TelegramUser | None: """ @@ -34,7 +39,7 @@ def get_user_by_telegram_id(self, telegram_id: str) -> TelegramUser | None: :param telegram_id: str :return: TelegramUser | None """ - return self.get_object_by_field(TelegramUser, "telegram_id", telegram_id) + return self.db_connector.get_object_by_field(TelegramUser, "telegram_id", telegram_id) def get_wallet_id_by_telegram_id(self, telegram_id: str) -> str | None: """ @@ -52,7 +57,7 @@ def create_telegram_user(self, user_data: dict) -> TelegramUser: :return: TelegramUser """ telegram_user = TelegramUser(**user_data) - return self.write_to_db(telegram_user) + return self.db_connector.write_to_db(telegram_user) def update_telegram_user(self, telegram_id: str, user_data: dict) -> None: """ @@ -61,7 +66,7 @@ def update_telegram_user(self, telegram_id: str, user_data: dict) -> None: :param user_data: dict :return: None """ - with self.Session() as session: + with self.db_connector.Session() as session: stmt = ( update(TelegramUser) .where(TelegramUser.telegram_id == telegram_id) @@ -93,7 +98,7 @@ def delete_telegram_user(self, telegram_id: str) -> None: """ user = self.get_user_by_telegram_id(telegram_id) if user: - self.delete_object_by_id(user, user.id) + self.db_connector.delete_object_by_id(TelegramUser, user.id) def set_allow_notification(self, telegram_id: str, wallet_id: str) -> bool: """ diff --git a/quantara/web_app/db/crud/transaction.py b/quantara/web_app/db/crud/transaction.py index f07b33470..18e819346 100644 --- a/quantara/web_app/db/crud/transaction.py +++ b/quantara/web_app/db/crud/transaction.py @@ -12,11 +12,16 @@ ModelType = TypeVar("ModelType", bound=Base) -class TransactionDBConnector(DBConnector): +class TransactionDBConnector: """ Provides database connection and operations management for the Transaction model. """ + def __init__(self, db_connector: DBConnector = None): + from web_app.db.database import db_connector as default_db_connector + + self.db_connector = db_connector or default_db_connector + def create_transaction( self, position_id: uuid.UUID, transaction_hash: str, status: TransactionStatus ) -> Transaction: @@ -28,5 +33,5 @@ def create_transaction( transaction_hash=transaction_hash, status=status, ) - transaction = self.write_to_db(transaction) + transaction = self.db_connector.write_to_db(transaction) return transaction diff --git a/quantara/web_app/db/crud/user.py b/quantara/web_app/db/crud/user.py index 52e503d86..74ec83b3f 100644 --- a/quantara/web_app/db/crud/user.py +++ b/quantara/web_app/db/crud/user.py @@ -15,11 +15,16 @@ ModelType = TypeVar("ModelType", bound=Base) -class UserDBConnector(DBConnector): +class UserDBConnector: """ Provides database connection and operations management for the User model. """ + def __init__(self, db_connector: DBConnector = None): + from web_app.db.database import db_connector as default_db_connector + + self.db_connector = db_connector or default_db_connector + def get_users_for_notifications(self) -> list[tuple[str, str]]: """ Retrieves the contract_address of users with an OPENED position status and @@ -27,7 +32,7 @@ def get_users_for_notifications(self) -> list[tuple[str, str]]: :return: List of tuples (contract_address, telegram_id) """ - with self.Session() as db: + with self.db_connector.Session() as db: try: results = ( db.query(User.contract_address, TelegramUser.telegram_id) @@ -51,7 +56,7 @@ def get_user_by_wallet_id(self, wallet_id: str) -> User | None: :param wallet_id: str :return: User | None """ - return self.get_object_by_field(User, "wallet_id", wallet_id) + return self.db_connector.get_object_by_field(User, "wallet_id", wallet_id) def get_contract_address_by_wallet_id(self, wallet_id: str) -> str: """ @@ -69,7 +74,7 @@ def create_user(self, wallet_id: str) -> User: :return: User """ user = User(wallet_id=wallet_id) - self.write_to_db(user) + self.db_connector.write_to_db(user) return user def update_user_contract(self, user: User, contract_address: str) -> None: @@ -81,14 +86,14 @@ def update_user_contract(self, user: User, contract_address: str) -> None: """ user.is_contract_deployed = not user.is_contract_deployed user.contract_address = contract_address - self.write_to_db(user) + self.db_connector.write_to_db(user) def get_unique_users_count(self) -> int: """ Retrieves the number of unique users in the database. :return: The count of unique users. """ - with self.Session() as db: + with self.db_connector.Session() as db: try: # Query to count distinct users based on wallet ID unique_users_count = db.query(User.wallet_id).distinct().count() @@ -113,7 +118,7 @@ def fetch_user_history(self, user_id: int) -> list[dict]: ### Returns: - A list of dictionaries containing position details. """ - with self.Session() as db: + with self.db_connector.Session() as db: try: # Query positions matching the user_id positions = ( @@ -155,7 +160,7 @@ def delete_user_by_wallet_id(self, wallet_id: str) -> None: :return: None :raises SQLAlchemyError: If the operation fails """ - with self.Session() as session: + with self.db_connector.Session() as session: try: user = session.query(User).filter(User.wallet_id == wallet_id).first() if user: diff --git a/quantara/web_app/db/database.py b/quantara/web_app/db/database.py index e3f704394..d3f31ab49 100644 --- a/quantara/web_app/db/database.py +++ b/quantara/web_app/db/database.py @@ -32,6 +32,12 @@ Base = declarative_base() + +# Create a shared DBConnector instance +from web_app.db.crud.base import DBConnector +db_connector = DBConnector(engine=engine) + + def get_database() -> Generator[Session, None, None]: """ FastAPI dependency that yields a database session and ensures cleanup. diff --git a/quantara/web_app/test_integration/utils.py b/quantara/web_app/test_integration/utils.py index a04461a3f..b8684f0b6 100644 --- a/quantara/web_app/test_integration/utils.py +++ b/quantara/web_app/test_integration/utils.py @@ -27,4 +27,4 @@ def with_temp_user(wallet_id: str) -> User: for position in position_db.get_all_positions_by_wallet_id(wallet_id, 0, 1000): position_db.delete_all_extra_deposits(position["id"]) position_db.delete_all_user_positions(user.id) - position_db.delete_user_by_wallet_id(wallet_id) + user_db.delete_user_by_wallet_id(wallet_id)