Skip to content
15 changes: 10 additions & 5 deletions quantara/web_app/db/crud/airdrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,28 @@
ModelType = TypeVar("ModelType", bound=Base)


class AirDropDBConnector(DBConnector):
class AirDropDBConnector:
"""
Provides database connection and operations management for the AirDrop model.
"""

def __init__(self, db_connector: DBConnector = None):
from web_app.db.database import db_connector as default_db_connector

self.db_connector = db_connector or default_db_connector

def save_claim_data(self, airdrop_id: uuid.UUID, amount: Decimal) -> None:
"""
Updates the AirDrop instance with claim data.
:param airdrop_id: uuid.UUID
:param amount: Decimal
"""
airdrop = self.get_object(AirDrop, airdrop_id)
airdrop = self.db_connector.get_object(AirDrop, airdrop_id)
if airdrop:
airdrop.amount = amount
airdrop.is_claimed = True
airdrop.claimed_at = datetime.now()
self.write_to_db(airdrop)
self.db_connector.write_to_db(airdrop)
else:
logger.error(f"AirDrop with ID {airdrop_id} not found")

Expand All @@ -44,7 +49,7 @@ def get_all_unclaimed(self) -> list[AirDrop]:

:return: List of unclaimed AirDrop instances
"""
with self.Session() as db:
with self.db_connector.Session() as db:
try:
unclaimed_instances = (
db.query(AirDrop).filter_by(is_claimed=False).all()
Expand All @@ -61,7 +66,7 @@ def delete_all_users_airdrop(self, user_id: uuid.UUID) -> None:
Delete all airdrops for a user.
:param user_id: User ID
"""
with self.Session() as db:
with self.db_connector.Session() as db:
try:
db.query(AirDrop).filter_by(user_id=user_id).delete(
synchronize_session=False
Expand Down
5 changes: 3 additions & 2 deletions quantara/web_app/db/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ class DBConnector:
- remove_object: Removes an object by its ID from the database.
"""

def __init__(self, db_url: str = SQLALCHEMY_DATABASE_URL):
def __init__(self, db_url: str = SQLALCHEMY_DATABASE_URL, engine=None):
"""
Initialize the database connection and session factory.
:param db_url: str = None
:param engine: sqlalchemy.engine.Engine = None
"""
self.engine = create_engine(db_url)
self.engine = engine or create_engine(db_url)
self.session_factory = sessionmaker(bind=self.engine)
self.Session = scoped_session(self.session_factory)

Expand Down
15 changes: 10 additions & 5 deletions quantara/web_app/db/crud/deposit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@
ModelType = TypeVar("ModelType", bound=Base)


class DepositDBConnector(DBConnector):
class DepositDBConnector:
"""
Provides database connection and operations management for the Vault model.
"""

def __init__(self, db_connector: DBConnector = None):
from web_app.db.database import db_connector as default_db_connector

self.db_connector = db_connector or default_db_connector

def create_vault(self, user: User, symbol: str, amount: str) -> Vault:
"""
Creates a new vault instance
Expand All @@ -30,7 +35,7 @@ def create_vault(self, user: User, symbol: str, amount: str) -> Vault:
:return: Vault
"""
vault = Vault(user_id=user.id, symbol=symbol, amount=amount)
self.write_to_db(vault)
self.db_connector.write_to_db(vault)
return vault

def get_vault(self, wallet_id: str, symbol: str) -> Vault | None:
Expand All @@ -42,8 +47,8 @@ def get_vault(self, wallet_id: str, symbol: str) -> Vault | None:

:return: Vault or None
"""
with self.Session() as db:
user = self.get_object_by_field(User, "wallet_id", wallet_id)
with self.db_connector.Session() as db:
user = self.db_connector.get_object_by_field(User, "wallet_id", wallet_id)
if not user:
logger.error(f"User with wallet id {wallet_id} not found")
return None
Expand All @@ -63,7 +68,7 @@ def add_vault_balance(self, wallet_id: str, symbol: str, amount: str) -> Vault:
vault = self.get_vault(wallet_id, symbol)
if not vault:
raise ValueError("Vault not found")
with self.Session() as db:
with self.db_connector.Session() as db:
new_amount = Decimal(vault.amount) + Decimal(amount)
db.query(Vault).filter_by(id=vault.id).update(amount=str(new_amount))
db.commit()
Expand Down
11 changes: 8 additions & 3 deletions quantara/web_app/db/crud/leaderboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,23 @@

logger = logging.getLogger(__name__)

class LeaderboardDBConnector(DBConnector):
class LeaderboardDBConnector:
"""
Provides database connection and operations management using SQLAlchemy
in a FastAPI application context.
"""

def __init__(self, db_connector: DBConnector = None):
from web_app.db.database import db_connector as default_db_connector

self.db_connector = db_connector or default_db_connector

def get_top_users_by_positions(self) -> list[dict]:
"""
Retrieves the top 10 users ordered by closed/opened positions.
:return: List of dictionaries containing wallet_id and positions_number.
"""
with self.Session() as db:
with self.db_connector.Session() as db:
try:
results = (
db.query(
Expand Down Expand Up @@ -53,7 +58,7 @@ def get_position_token_statistics(self) -> list[dict]:
Retrieves closed/opened positions groupped by token_symbol.
:return: List of dictionaries containing token_symbol and total_positions.
"""
with self.Session() as db:
with self.db_connector.Session() as db:
try:
results = (
db.query(
Expand Down
Loading
Loading