diff --git a/packages/backend/app/models.py b/packages/backend/app/models.py index 64d44810..5a866518 100644 --- a/packages/backend/app/models.py +++ b/packages/backend/app/models.py @@ -133,3 +133,58 @@ class AuditLog(db.Model): user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True) action = db.Column(db.String(100), nullable=False) created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) + + +class BankConnectionStatus(str, Enum): + """Status of a bank connection.""" + PENDING = "PENDING" + ACTIVE = "ACTIVE" + ERROR = "ERROR" + DISCONNECTED = "DISCONNECTED" + + +class BankConnection(db.Model): + """ + Model for storing user's bank connections. + + This allows users to connect their bank accounts through + various connectors (Plaid, Mock, etc.) and sync transactions. + """ + __tablename__ = "bank_connections" + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False) + connector_name = db.Column(db.String(50), nullable=False) + account_name = db.Column(db.String(200), nullable=False) + account_mask = db.Column(db.String(10), nullable=True) # Last 4 digits + status = db.Column( + SAEnum(BankConnectionStatus), + default=BankConnectionStatus.PENDING, + nullable=False + ) + access_token = db.Column(db.Text, nullable=False) + refresh_token = db.Column(db.Text, nullable=True) + token_expires_at = db.Column(db.DateTime, nullable=True) + institution_id = db.Column(db.String(100), nullable=True) + institution_name = db.Column(db.String(200), nullable=True) + last_sync_at = db.Column(db.DateTime, nullable=True) + created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) + error_message = db.Column(db.Text, nullable=True) + + # Relationship to user + user = db.relationship("User", backref=db.backref("bank_connections", lazy="dynamic")) + + def to_dict(self) -> dict: + return { + "id": self.id, + "user_id": self.user_id, + "connector_name": self.connector_name, + "account_name": self.account_name, + "account_mask": self.account_mask, + "status": self.status.value if self.status else None, + "institution_id": self.institution_id, + "institution_name": self.institution_name, + "last_sync_at": self.last_sync_at.isoformat() if self.last_sync_at else None, + "created_at": self.created_at.isoformat() if self.created_at else None, + "error_message": self.error_message, + } diff --git a/packages/backend/app/routes/__init__.py b/packages/backend/app/routes/__init__.py index f13b0f89..5190586b 100644 --- a/packages/backend/app/routes/__init__.py +++ b/packages/backend/app/routes/__init__.py @@ -7,6 +7,7 @@ from .categories import bp as categories_bp from .docs import bp as docs_bp from .dashboard import bp as dashboard_bp +from .bank_connections import bp as bank_connections_bp def register_routes(app: Flask): @@ -18,3 +19,4 @@ def register_routes(app: Flask): app.register_blueprint(categories_bp, url_prefix="/categories") app.register_blueprint(docs_bp, url_prefix="/docs") app.register_blueprint(dashboard_bp, url_prefix="/dashboard") + app.register_blueprint(bank_connections_bp, url_prefix="/bank") diff --git a/packages/backend/app/routes/bank_connections.py b/packages/backend/app/routes/bank_connections.py new file mode 100644 index 00000000..9c4de229 --- /dev/null +++ b/packages/backend/app/routes/bank_connections.py @@ -0,0 +1,340 @@ +""" +Bank Connections API + +Routes for managing bank connections and importing transactions. +""" + +from datetime import datetime, date, timedelta +from decimal import Decimal + +from flask import Blueprint, jsonify, request +from flask_jwt_extended import jwt_required, get_jwt_identity + +from ..extensions import db +from ..models import BankConnection, BankConnectionStatus, Expense +from ..services.bank_connectors import get_connector, connector_registry + +bp = Blueprint("bank_connections", __name__) + + +@bp.get("/connectors") +@jwt_required() +def list_connectors(): + """List all available bank connectors.""" + connectors = connector_registry.list_connectors() + return jsonify(connectors) + + +@bp.post("/connectors//authorize") +@jwt_required() +def get_authorization_url(connector_name: str): + """ + Get the authorization URL for a bank connector. + + Returns the OAuth URL the user should be redirected to. + """ + uid = int(get_jwt_identity()) + + connector = get_connector(connector_name) + if not connector: + return jsonify(error=f"Connector '{connector_name}' not found"), 404 + + data = request.get_json() or {} + redirect_uri = data.get("redirect_uri") + if not redirect_uri: + return jsonify(error="redirect_uri is required"), 400 + + # Generate state for CSRF protection + state = data.get("state", f"{uid}_{datetime.utcnow().timestamp()}") + + try: + auth_url = connector.get_authorization_url(redirect_uri, state=state) + return jsonify({"authorization_url": auth_url, "state": state}) + except Exception as e: + return jsonify(error=str(e)), 500 + + +@bp.post("/connectors//callback") +@jwt_required() +def handle_oauth_callback(connector_name: str): + """ + Handle the OAuth callback from a bank connector. + + Exchanges the authorization code for access tokens and + creates a new bank connection record. + """ + uid = int(get_jwt_identity()) + + connector = get_connector(connector_name) + if not connector: + return jsonify(error=f"Connector '{connector_name}' not found"), 404 + + data = request.get_json() or {} + code = data.get("code") + if not code: + return jsonify(error="code is required"), 400 + + redirect_uri = data.get("redirect_uri", "") + + try: + # Exchange code for tokens + token_data = connector.exchange_token(code, redirect_uri) + + # Get account info + account_info = connector.get_account_info(token_data["access_token"]) + + # Create bank connection record + expires_at = None + if token_data.get("expires_in"): + expires_at = datetime.utcnow() + timedelta(seconds=token_data["expires_in"]) + + connection = BankConnection( + user_id=uid, + connector_name=connector_name, + account_name=account_info.get("account_name", "Unknown"), + account_mask=account_info.get("account_mask"), + status=BankConnectionStatus.ACTIVE, + access_token=token_data["access_token"], + refresh_token=token_data.get("refresh_token"), + token_expires_at=expires_at, + institution_id=account_info.get("institution_id"), + institution_name=account_info.get("institution_name"), + ) + + db.session.add(connection) + db.session.commit() + + return jsonify(connection.to_dict()), 201 + + except Exception as e: + db.session.rollback() + return jsonify(error=str(e)), 500 + + +@bp.get("") +@jwt_required() +def list_connections(): + """List all bank connections for the current user.""" + uid = int(get_jwt_identity()) + + connections = ( + db.session.query(BankConnection) + .filter_by(user_id=uid) + .order_by(BankConnection.created_at.desc()) + .all() + ) + + return jsonify([c.to_dict() for c in connections]) + + +@bp.get("/") +@jwt_required() +def get_connection(connection_id: int): + """Get a specific bank connection.""" + uid = int(get_jwt_identity()) + + connection = db.session.get(BankConnection, connection_id) + if not connection or connection.user_id != uid: + return jsonify(error="not found"), 404 + + return jsonify(connection.to_dict()) + + +@bp.delete("/") +@jwt_required() +def delete_connection(connection_id: int): + """Delete a bank connection.""" + uid = int(get_jwt_identity()) + + connection = db.session.get(BankConnection, connection_id) + if not connection or connection.user_id != uid: + return jsonify(error="not found"), 404 + + # Try to disconnect via the connector + connector = get_connector(connection.connector_name) + if connector: + try: + connector.disconnect(connection.access_token) + except Exception: + pass # Ignore errors during disconnect + + db.session.delete(connection) + db.session.commit() + + return jsonify({"deleted": True}) + + +@bp.post("//refresh") +@jwt_required() +def refresh_connection(connection_id: int): + """ + Refresh a bank connection's access token. + + Uses the refresh token to get a new access token. + """ + uid = int(get_jwt_identity()) + + connection = db.session.get(BankConnection, connection_id) + if not connection or connection.user_id != uid: + return jsonify(error="not found"), 404 + + if not connection.refresh_token: + return jsonify(error="no refresh token available"), 400 + + connector = get_connector(connection.connector_name) + if not connector: + return jsonify(error=f"Connector '{connection.connector_name}' not found"), 404 + + try: + token_data = connector.refresh_access_token(connection.refresh_token) + + # Update connection with new tokens + connection.access_token = token_data["access_token"] + if token_data.get("refresh_token"): + connection.refresh_token = token_data["refresh_token"] + if token_data.get("expires_in"): + connection.token_expires_at = datetime.utcnow() + timedelta( + seconds=token_data["expires_in"] + ) + connection.error_message = None + + db.session.commit() + + return jsonify(connection.to_dict()) + + except Exception as e: + connection.status = BankConnectionStatus.ERROR + connection.error_message = str(e) + db.session.commit() + return jsonify(error=str(e)), 500 + + +@bp.post("//import") +@jwt_required() +def import_transactions(connection_id: int): + """ + Import transactions from a bank connection. + + Fetches transactions from the connector and creates + expense records in the database. + """ + uid = int(get_jwt_identity()) + + connection = db.session.get(BankConnection, connection_id) + if not connection or connection.user_id != uid: + return jsonify(error="not found"), 404 + + if connection.status != BankConnectionStatus.ACTIVE: + return jsonify(error="connection is not active"), 400 + + connector = get_connector(connection.connector_name) + if not connector: + return jsonify(error=f"Connector '{connection.connector_name}' not found"), 404 + + # Parse date range from request + data = request.get_json() or {} + start_date = None + end_date = None + + if data.get("start_date"): + try: + start_date = date.fromisoformat(data["start_date"]) + except ValueError: + return jsonify(error="invalid start_date format"), 400 + + if data.get("end_date"): + try: + end_date = date.fromisoformat(data["end_date"]) + except ValueError: + return jsonify(error="invalid end_date format"), 400 + + try: + # Import transactions from connector + transactions = connector.import_transactions( + connection.access_token, + start_date=start_date, + end_date=end_date, + ) + + # Create expense records + created_expenses = [] + for tx in transactions: + # Check if transaction already exists (by external_id) + existing = ( + db.session.query(Expense) + .filter_by( + user_id=uid, + notes=tx.description, + spent_at=tx.date, + ) + .first() + ) + + if not existing: + expense = Expense( + user_id=uid, + amount=abs(tx.amount), + currency=tx.currency, + expense_type=tx.expense_type, + notes=tx.description, + spent_at=tx.date, + category_id=tx.category_id, + ) + db.session.add(expense) + created_expenses.append(expense) + + # Update last sync time + connection.last_sync_at = datetime.utcnow() + + db.session.commit() + + return jsonify({ + "imported": len(created_expenses), + "skipped": len(transactions) - len(created_expenses), + "total": len(transactions), + }) + + except Exception as e: + db.session.rollback() + connection.status = BankConnectionStatus.ERROR + connection.error_message = str(e) + db.session.commit() + return jsonify(error=str(e)), 500 + + +@bp.post("//validate") +@jwt_required() +def validate_connection(connection_id: int): + """Validate that a bank connection is still active.""" + uid = int(get_jwt_identity()) + + connection = db.session.get(BankConnection, connection_id) + if not connection or connection.user_id != uid: + return jsonify(error="not found"), 404 + + connector = get_connector(connection.connector_name) + if not connector: + return jsonify(error=f"Connector '{connection.connector_name}' not found"), 404 + + try: + is_valid = connector.validate_connection(connection.access_token) + + if is_valid: + connection.status = BankConnectionStatus.ACTIVE + connection.error_message = None + else: + connection.status = BankConnectionStatus.ERROR + connection.error_message = "Connection validation failed" + + db.session.commit() + + return jsonify({ + "valid": is_valid, + "status": connection.status.value, + }) + + except Exception as e: + connection.status = BankConnectionStatus.ERROR + connection.error_message = str(e) + db.session.commit() + return jsonify(error=str(e)), 500 \ No newline at end of file diff --git a/packages/backend/app/services/bank_connectors/__init__.py b/packages/backend/app/services/bank_connectors/__init__.py new file mode 100644 index 00000000..cde136ab --- /dev/null +++ b/packages/backend/app/services/bank_connectors/__init__.py @@ -0,0 +1,21 @@ +""" +Bank Connectors Package + +This package provides a pluggable architecture for bank integrations. +Each connector implements the BankConnector interface to provide +import and refresh functionality for bank transaction data. +""" + +from .base import BankConnector, BankConnection, Transaction +from .registry import connector_registry, get_connector + +# Import mock connector to register it +from . import mock # noqa: F401 - registers the mock connector + +__all__ = [ + "BankConnector", + "BankConnection", + "Transaction", + "connector_registry", + "get_connector", +] \ No newline at end of file diff --git a/packages/backend/app/services/bank_connectors/base.py b/packages/backend/app/services/bank_connectors/base.py new file mode 100644 index 00000000..2a9ea215 --- /dev/null +++ b/packages/backend/app/services/bank_connectors/base.py @@ -0,0 +1,200 @@ +""" +Bank Connector Base Classes + +Defines the interface and common types for bank integrations. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime, date +from decimal import Decimal +from enum import Enum +from typing import Any + + +class ConnectionStatus(str, Enum): + """Status of a bank connection.""" + PENDING = "PENDING" + ACTIVE = "ACTIVE" + ERROR = "ERROR" + DISCONNECTED = "DISCONNECTED" + + +@dataclass +class Transaction: + """Represents a single bank transaction.""" + date: date + amount: Decimal + description: str + currency: str = "USD" + expense_type: str = "EXPENSE" # EXPENSE or INCOME + category_id: int | None = None + notes: str = "" + external_id: str = "" # Bank's transaction ID + + def to_dict(self) -> dict[str, Any]: + return { + "date": self.date.isoformat(), + "amount": float(self.amount), + "description": self.description, + "currency": self.currency, + "expense_type": self.expense_type, + "category_id": self.category_id, + "notes": self.notes, + "external_id": self.external_id, + } + + +@dataclass +class BankConnection: + """Represents a user's bank connection.""" + id: int | None = None + user_id: int = 0 + connector_name: str = "" + account_name: str = "" + account_mask: str = "" # Last 4 digits + status: ConnectionStatus = ConnectionStatus.PENDING + access_token: str = "" + refresh_token: str = "" + token_expires_at: datetime | None = None + institution_id: str = "" + institution_name: str = "" + last_sync_at: datetime | None = None + created_at: datetime | None = None + error_message: str = "" + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "user_id": self.user_id, + "connector_name": self.connector_name, + "account_name": self.account_name, + "account_mask": self.account_mask, + "status": self.status.value if self.status else None, + "institution_id": self.institution_id, + "institution_name": self.institution_name, + "last_sync_at": self.last_sync_at.isoformat() if self.last_sync_at else None, + "created_at": self.created_at.isoformat() if self.created_at else None, + "error_message": self.error_message, + } + + +class BankConnector(ABC): + """ + Abstract base class for bank connectors. + + All bank integrations must implement this interface to provide + import and refresh functionality for transaction data. + """ + + # Connector identifier (e.g., "plaid", "mock", "chase") + name: str = "" + + # Human-readable display name + display_name: str = "" + + # Whether this connector supports OAuth flow + supports_oauth: bool = False + + @abstractmethod + def get_authorization_url(self, redirect_uri: str, **kwargs) -> str: + """ + Get the OAuth authorization URL for connecting the bank. + + Args: + redirect_uri: The URI to redirect after authentication + **kwargs: Additional connector-specific parameters + + Returns: + The authorization URL to redirect the user to + """ + pass + + @abstractmethod + def exchange_token(self, code: str, redirect_uri: str) -> dict[str, Any]: + """ + Exchange authorization code for access tokens. + + Args: + code: The authorization code from the OAuth callback + redirect_uri: The redirect URI used in the initial request + + Returns: + Dictionary containing access_token, refresh_token, and expires_in + """ + pass + + @abstractmethod + def refresh_access_token(self, refresh_token: str) -> dict[str, Any]: + """ + Refresh an expired access token. + + Args: + refresh_token: The refresh token from a previous connection + + Returns: + Dictionary containing new access_token, refresh_token, and expires_in + """ + pass + + @abstractmethod + def get_account_info(self, access_token: str) -> dict[str, Any]: + """ + Get account information for the connected bank. + + Args: + access_token: The access token for API requests + + Returns: + Dictionary with account details (name, mask, institution info) + """ + pass + + @abstractmethod + def import_transactions( + self, + access_token: str, + start_date: date | None = None, + end_date: date | None = None, + ) -> list[Transaction]: + """ + Import transactions from the bank. + + Args: + access_token: The access token for API requests + start_date: Optional start date for transaction import + end_date: Optional end date for transaction import + + Returns: + List of Transaction objects + """ + pass + + @abstractmethod + def disconnect(self, access_token: str) -> bool: + """ + Disconnect the bank connection. + + Args: + access_token: The access token to revoke + + Returns: + True if disconnection was successful + """ + pass + + def validate_connection(self, access_token: str) -> bool: + """ + Validate that the connection is still active. + + Args: + access_token: The access token to validate + + Returns: + True if the connection is valid + """ + try: + self.get_account_info(access_token) + return True + except Exception: + return False \ No newline at end of file diff --git a/packages/backend/app/services/bank_connectors/mock.py b/packages/backend/app/services/bank_connectors/mock.py new file mode 100644 index 00000000..0003f8d5 --- /dev/null +++ b/packages/backend/app/services/bank_connectors/mock.py @@ -0,0 +1,135 @@ +""" +Mock Bank Connector + +A mock connector for testing and development purposes. +""" + +from datetime import date, datetime, timedelta +from decimal import Decimal +from typing import Any + +from .base import BankConnector, ConnectionStatus, Transaction + + +class MockBankConnector(BankConnector): + """ + Mock bank connector for testing and development. + + This connector simulates a bank connection without making + any real API calls. It generates sample transactions for + testing the import functionality. + """ + + name = "mock" + display_name = "Mock Bank (Test)" + supports_oauth = True + + # Sample transaction data for simulation + SAMPLE_TRANSACTIONS = [ + {"date_offset": 0, "amount": -45.99, "description": "Grocery Store", "expense_type": "EXPENSE"}, + {"date_offset": 1, "amount": -12.50, "description": "Coffee Shop", "expense_type": "EXPENSE"}, + {"date_offset": 2, "amount": -150.00, "description": "Electric Bill", "expense_type": "EXPENSE"}, + {"date_offset": 3, "amount": 3500.00, "description": "Salary Deposit", "expense_type": "INCOME"}, + {"date_offset": 4, "amount": -89.99, "description": "Gas Station", "expense_type": "EXPENSE"}, + {"date_offset": 5, "amount": -24.99, "description": "Streaming Service", "expense_type": "EXPENSE"}, + {"date_offset": 6, "amount": -65.00, "description": "Restaurant", "expense_type": "EXPENSE"}, + {"date_offset": 7, "amount": -199.00, "description": "Internet Bill", "expense_type": "EXPENSE"}, + {"date_offset": 8, "amount": -55.00, "description": "Gym Membership", "expense_type": "EXPENSE"}, + {"date_offset": 9, "amount": 500.00, "description": "Freelance Payment", "expense_type": "INCOME"}, + {"date_offset": 10, "amount": -35.00, "description": "Pharmacy", "expense_type": "EXPENSE"}, + {"date_offset": 11, "amount": -125.00, "description": "Car Insurance", "expense_type": "EXPENSE"}, + {"date_offset": 12, "amount": -42.50, "description": "Coffee Shop", "expense_type": "EXPENSE"}, + {"date_offset": 13, "amount": -78.00, "description": "Online Shopping", "expense_type": "EXPENSE"}, + {"date_offset": 14, "amount": 100.00, "description": "Refund", "expense_type": "INCOME"}, + ] + + def get_authorization_url(self, redirect_uri: str, **kwargs) -> str: + """Get mock authorization URL.""" + # In a real implementation, this would redirect to the bank's OAuth page + # For mock, we simulate the flow + state = kwargs.get("state", "mock_state") + return f"{redirect_uri}?code=mock_auth_code&state={state}" + + def exchange_token(self, code: str, redirect_uri: str) -> dict[str, Any]: + """Exchange mock authorization code for tokens.""" + # Simulate token exchange + return { + "access_token": f"mock_access_{code[:8]}", + "refresh_token": f"mock_refresh_{code[:8]}", + "expires_in": 3600, + "token_type": "Bearer", + } + + def refresh_access_token(self, refresh_token: str) -> dict[str, Any]: + """Refresh mock access token.""" + return { + "access_token": f"mock_access_refreshed_{refresh_token[:8]}", + "refresh_token": f"mock_refresh_refreshed_{refresh_token[:8]}", + "expires_in": 3600, + "token_type": "Bearer", + } + + def get_account_info(self, access_token: str) -> dict[str, Any]: + """Get mock account information.""" + # Simulate account info based on token + return { + "account_id": "mock_account_123", + "account_name": "Mock Checking Account", + "account_mask": "1234", + "account_type": "checking", + "institution_id": "mock_institution", + "institution_name": "Mock Bank", + "balance": 5432.10, + "currency": "USD", + } + + def import_transactions( + self, + access_token: str, + start_date: date | None = None, + end_date: date | None = None, + ) -> list[Transaction]: + """Import mock transactions.""" + today = date.today() + + # Default to last 30 days if no dates specified + if end_date is None: + end_date = today + if start_date is None: + start_date = today - timedelta(days=30) + + transactions = [] + for i, sample in enumerate(self.SAMPLE_TRANSACTIONS): + tx_date = today - timedelta(days=sample["date_offset"]) + + # Filter by date range + if tx_date < start_date or tx_date > end_date: + continue + + transactions.append( + Transaction( + date=tx_date, + amount=Decimal(str(sample["amount"])), + description=sample["description"], + currency="USD", + expense_type=sample["expense_type"], + external_id=f"mock_tx_{i}_{tx_date.isoformat()}", + ) + ) + + return transactions + + def disconnect(self, access_token: str) -> bool: + """Disconnect mock connection.""" + # Always succeed for mock + return True + + def validate_connection(self, access_token: str) -> bool: + """Validate mock connection.""" + # Mock always validates successfully + return True + + +# Auto-register the mock connector +from .registry import register_connector +register_connector(MockBankConnector) \ No newline at end of file diff --git a/packages/backend/app/services/bank_connectors/registry.py b/packages/backend/app/services/bank_connectors/registry.py new file mode 100644 index 00000000..154592e4 --- /dev/null +++ b/packages/backend/app/services/bank_connectors/registry.py @@ -0,0 +1,71 @@ +""" +Bank Connector Registry + +Provides a registry for managing available bank connectors. +""" + +from typing import Any + +from .base import BankConnector + + +class ConnectorRegistry: + """Registry for bank connectors.""" + + def __init__(self): + self._connectors: dict[str, type[BankConnector]] = {} + + def register(self, connector_class: type[BankConnector]) -> None: + """Register a connector class.""" + instance = connector_class() + self._connectors[instance.name] = connector_class + + def get(self, name: str) -> type[BankConnector] | None: + """Get a connector class by name.""" + return self._connectors.get(name) + + def list_connectors(self) -> list[dict[str, Any]]: + """List all registered connectors.""" + result = [] + for name, connector_class in self._connectors.items(): + instance = connector_class() + result.append({ + "name": instance.name, + "display_name": instance.display_name, + "supports_oauth": instance.supports_oauth, + }) + return result + + def is_registered(self, name: str) -> bool: + """Check if a connector is registered.""" + return name in self._connectors + + +# Global registry instance +connector_registry = ConnectorRegistry() + + +def get_connector(name: str) -> BankConnector | None: + """ + Get a connector instance by name. + + Args: + name: The connector name + + Returns: + An instance of the connector, or None if not found + """ + connector_class = connector_registry.get(name) + if connector_class: + return connector_class() + return None + + +def register_connector(connector_class: type[BankConnector]) -> None: + """ + Register a connector class. + + Args: + connector_class: The connector class to register + """ + connector_registry.register(connector_class) \ No newline at end of file diff --git a/packages/backend/tests/test_bank_connectors.py b/packages/backend/tests/test_bank_connectors.py new file mode 100644 index 00000000..8fee8211 --- /dev/null +++ b/packages/backend/tests/test_bank_connectors.py @@ -0,0 +1,164 @@ +""" +Tests for bank connector architecture. +""" + +import pytest +from datetime import date, timedelta +from decimal import Decimal + +from app.services.bank_connectors import ( + get_connector, + connector_registry, + BankConnector, + Transaction, +) + + +class TestConnectorRegistry: + """Tests for the connector registry.""" + + def test_list_connectors(self): + """Test listing available connectors.""" + connectors = connector_registry.list_connectors() + assert len(connectors) >= 1 + assert any(c["name"] == "mock" for c in connectors) + + def test_get_connector(self): + """Test getting a connector by name.""" + connector = get_connector("mock") + assert connector is not None + assert connector.name == "mock" + + def test_get_connector_not_found(self): + """Test getting a non-existent connector.""" + connector = get_connector("nonexistent") + assert connector is None + + def test_is_registered(self): + """Test checking if a connector is registered.""" + assert connector_registry.is_registered("mock") + assert not connector_registry.is_registered("nonexistent") + + +class TestMockConnector: + """Tests for the mock bank connector.""" + + @pytest.fixture + def connector(self): + return get_connector("mock") + + def test_connector_properties(self, connector): + """Test connector basic properties.""" + assert connector.name == "mock" + assert connector.display_name == "Mock Bank (Test)" + assert connector.supports_oauth is True + + def test_get_authorization_url(self, connector): + """Test getting authorization URL.""" + url = connector.get_authorization_url( + "http://localhost/callback", + state="test_state" + ) + assert "code=mock_auth_code" in url + assert "state=test_state" in url + + def test_exchange_token(self, connector): + """Test token exchange.""" + tokens = connector.exchange_token("test_code", "http://localhost/callback") + assert "access_token" in tokens + assert "refresh_token" in tokens + assert tokens["expires_in"] == 3600 + + def test_refresh_access_token(self, connector): + """Test token refresh.""" + tokens = connector.refresh_access_token("test_refresh_token") + assert "access_token" in tokens + assert "refresh_token" in tokens + + def test_get_account_info(self, connector): + """Test getting account info.""" + account = connector.get_account_info("test_token") + assert account["account_name"] == "Mock Checking Account" + assert account["account_mask"] == "1234" + assert account["institution_name"] == "Mock Bank" + + def test_import_transactions(self, connector): + """Test importing transactions.""" + today = date.today() + txs = connector.import_transactions( + "test_token", + start_date=today - timedelta(days=7), + end_date=today + ) + assert len(txs) > 0 + assert all(isinstance(tx, Transaction) for tx in txs) + + def test_import_transactions_with_defaults(self, connector): + """Test importing transactions with default date range.""" + txs = connector.import_transactions("test_token") + assert len(txs) > 0 + + def test_disconnect(self, connector): + """Test disconnecting.""" + result = connector.disconnect("test_token") + assert result is True + + def test_validate_connection(self, connector): + """Test connection validation.""" + result = connector.validate_connection("test_token") + assert result is True + + +class TestTransaction: + """Tests for the Transaction dataclass.""" + + def test_transaction_creation(self): + """Test creating a transaction.""" + tx = Transaction( + date=date.today(), + amount=Decimal("100.50"), + description="Test transaction", + currency="USD", + expense_type="EXPENSE", + ) + assert tx.amount == Decimal("100.50") + assert tx.description == "Test transaction" + + def test_transaction_to_dict(self): + """Test transaction serialization.""" + tx = Transaction( + date=date(2024, 1, 15), + amount=Decimal("50.00"), + description="Test", + currency="USD", + ) + d = tx.to_dict() + assert d["date"] == "2024-01-15" + assert d["amount"] == 50.00 + assert d["description"] == "Test" + + +class TestBankConnectorInterface: + """Tests to verify the BankConnector interface.""" + + def test_mock_implements_interface(self): + """Verify mock connector implements all required methods.""" + connector = get_connector("mock") + + # Check all required methods exist + assert hasattr(connector, "get_authorization_url") + assert hasattr(connector, "exchange_token") + assert hasattr(connector, "refresh_access_token") + assert hasattr(connector, "get_account_info") + assert hasattr(connector, "import_transactions") + assert hasattr(connector, "disconnect") + assert hasattr(connector, "validate_connection") + + # Check all are callable + assert callable(connector.get_authorization_url) + assert callable(connector.exchange_token) + assert callable(connector.refresh_access_token) + assert callable(connector.get_account_info) + assert callable(connector.import_transactions) + assert callable(connector.disconnect) + assert callable(connector.validate_connection) \ No newline at end of file