From ded58f2d21ff79ce8b25e832d15be83b594c9299 Mon Sep 17 00:00:00 2001 From: Brunno Vanelli Date: Thu, 23 Apr 2026 20:20:10 +0200 Subject: [PATCH] refactor: Fix more easy mypy issues. --- actual/__init__.py | 7 +++-- actual/budgets.py | 4 +-- actual/database.py | 10 ++++--- actual/queries.py | 59 ++++++++++++++++++++++++++++++++---------- tests/test_database.py | 16 +++++++++--- tests/test_models.py | 3 ++- 6 files changed, 74 insertions(+), 25 deletions(-) diff --git a/actual/__init__.py b/actual/__init__.py index ef583d8..6c11b73 100644 --- a/actual/__init__.py +++ b/actual/__init__.py @@ -346,8 +346,11 @@ def export_data(self, output_file: str | PathLike[str] | IO[bytes] | None = None z.write(self.data_dir / "metadata.json", "metadata.json") content = temp_file.getvalue() if output_file: - with open(output_file, "wb") as f: - f.write(content) + if isinstance(output_file, (str, PathLike)): + with open(output_file, "wb") as f: + f.write(content) + else: + output_file.write(content) return content def encrypt(self, encryption_password: str): diff --git a/actual/budgets.py b/actual/budgets.py index 49bbb6b..fabba78 100644 --- a/actual/budgets.py +++ b/actual/budgets.py @@ -5,7 +5,7 @@ from sqlmodel import Session, col, select -from actual.database import Categories, CategoryGroups, ReflectBudgets, Transactions, ZeroBudgets +from actual.database import BaseBudgets, Categories, CategoryGroups, ReflectBudgets, Transactions from actual.queries import ( _balance_base_query, _get_budget_table, @@ -64,7 +64,7 @@ class BudgetCategory(_HasDatabaseObject): This reflects the values displayed on the Actual frontend. """ - budget: ReflectBudgets | ZeroBudgets | None = None + budget: BaseBudgets | None = None """ The underlying budget database record, if it exists. diff --git a/actual/database.py b/actual/database.py index de124ab..3bf7041 100644 --- a/actual/database.py +++ b/actual/database.py @@ -104,15 +104,19 @@ def get_attribute_from_reflected_table_name(metadata: MetaData, table_name: str, return table.columns.get(column_name, None) -def get_class_by_table_name(table_name: str) -> type["BaseModel"] | None: +def get_class_by_table_name(table_name: str) -> type["BaseModel"]: """ Returns, based on the defined tables `__tablename__` the corresponding SQLModel object. - If not found, returns `None`. + If not found, raises `ValueError`. + + :param table_name: SQL table name. + :return SQLModel: SQLAlchemy object. + :raises ValueError: Raises `ValueError` if the table name is not existing. """ entry = __TABLE_COLUMNS_MAP__.get(table_name) if entry is None: - return None + raise ValueError(f"Could not find table '{table_name}' on the database model.") return entry["entity"] diff --git a/actual/queries.py b/actual/queries.py index 066206b..afb8162 100644 --- a/actual/queries.py +++ b/actual/queries.py @@ -13,10 +13,12 @@ from sqlalchemy.orm import joinedload from sqlalchemy.sql.expression import Select from sqlmodel import Session, col, select +from sqlmodel.sql.expression import SelectOfScalar from actual.crypto import is_uuid from actual.database import ( Accounts, + BaseBudgets, Categories, CategoryGroups, CategoryMapping, @@ -56,8 +58,8 @@ def _transactions_base_query( account: Accounts | str | None | None = None, category: Categories | str | None = None, include_deleted: bool = False, -) -> Select: - query = ( +) -> SelectOfScalar[Transactions]: + query: SelectOfScalar[Transactions] = ( select(Transactions) .options( joinedload(Transactions.account), @@ -232,7 +234,7 @@ def match_transaction( query = _transactions_base_query( s, date - datetime.timedelta(days=7), date + datetime.timedelta(days=8), account=account ).where(col(Transactions.amount) == round(amount * 100)) - results: list[Transactions] = s.exec(query).all() # noqa + results: list[Transactions] = list(s.exec(query).all()) # filter out the ones that were already matched if already_matched: matched = {t.id for t in already_matched} @@ -535,13 +537,17 @@ def create_split(s: Session, transaction: Transactions, amount: float | decimal. return split -def _base_query(instance: type[T], name: str | None = None, include_deleted: bool = False) -> Select: +def _base_query(instance: type[T], name: str | None = None, include_deleted: bool = False) -> SelectOfScalar[T]: """Internal method to reduce querying complexity on sub-functions.""" query = select(instance) if not include_deleted: - query = query.where(sqlalchemy.func.coalesce(instance.tombstone, 0) == 0) + tombstone_col = getattr(instance, "tombstone", None) + if tombstone_col is not None: + query = query.where(sqlalchemy.func.coalesce(tombstone_col, 0) == 0) if name: - query = query.where(instance.name.ilike(f"%{sqlalchemy.text(name).compile()}%")) + name_col = getattr(instance, "name", None) + if name_col is not None: + query = query.where(name_col.ilike(f"%{sqlalchemy.text(name).compile()}%")) return query @@ -839,7 +845,7 @@ def get_or_create_account(s: Session, name: str | Accounts) -> Accounts: return account -def _get_budget_table(s: Session) -> type[ReflectBudgets | ZeroBudgets]: +def _get_budget_table(s: Session) -> type[ZeroBudgets] | type[ReflectBudgets]: """ Finds out which type of budget the user uses. The types are: @@ -858,7 +864,7 @@ def _get_budget_table(s: Session) -> type[ReflectBudgets | ZeroBudgets]: def get_budgets( s: Session, month: datetime.date | None = None, category: str | Categories | None = None -) -> typing.Sequence[ZeroBudgets | ReflectBudgets]: +) -> typing.Sequence[BaseBudgets]: """ Returns a list of all available budgets. @@ -889,7 +895,7 @@ def get_budgets( return s.exec(query).unique().all() -def get_budget(s: Session, month: datetime.date, category: str | Categories) -> ZeroBudgets | ReflectBudgets | None: +def get_budget(s: Session, month: datetime.date, category: str | Categories) -> BaseBudgets | None: """ Gets an existing budget by category name, returns `None` if not found. @@ -910,7 +916,7 @@ def create_budget( category: str | Categories, amount: decimal.Decimal | float | int = 0.0, carryover: bool | None = None, -) -> ZeroBudgets | ReflectBudgets: +) -> BaseBudgets: """ Gets an existing budget based on the month and category. If it already exists, the amount will be replaced by the new amount. @@ -1140,6 +1146,32 @@ def get_schedules( return s.exec(query).all() +@typing.overload +def create_schedule( + s: Session, + date: datetime.date | datetime.datetime | Schedule, + amount: tuple[decimal.Decimal, decimal.Decimal] | tuple[float, float], + amount_operation: typing.Literal["isbetween"], + name: str | None, + payee: str | Payees | None, + account: str | Accounts | None, + posts_transaction: bool, +) -> Schedules: ... + + +@typing.overload +def create_schedule( + s: Session, + date: datetime.date | datetime.datetime | Schedule, + amount: decimal.Decimal | float, + amount_operation: typing.Literal["is", "isapprox"], + name: str | None, + payee: str | Payees | None, + account: str | Accounts | None, + posts_transaction: bool, +) -> Schedules: ... + + def create_schedule( s: Session, date: datetime.date | datetime.datetime | Schedule, @@ -1171,9 +1203,6 @@ def create_schedule( :param posts_transaction: Whether the schedule should auto-post transactions on your behalf. Defaults to false. :return: Rule database object created. """ - if amount_operation == "isbetween" and not isinstance(amount, tuple): - raise ActualError("When using 'isbetween', amount must be a tuple (num1, num2), where num1 < num2.") - schedule_id = str(uuid.uuid4()) conditions = [] # Handle the payee condition @@ -1197,6 +1226,8 @@ def create_schedule( ) # Handle the amount condition if amount_operation == "isbetween": + if not isinstance(amount, tuple): + raise ActualError("When using 'isbetween', amount must be a tuple (num1, num2), where num1 < num2.") conditions.append( Condition( field="amount", @@ -1205,6 +1236,8 @@ def create_schedule( ) ) else: + if isinstance(amount, tuple): + raise ActualError(f"When using '{amount_operation}', amount must be a single decimal number.") conditions.append(Condition(field="amount", op=ConditionType(amount_operation), value=decimal_to_cents(amount))) actions = [ diff --git a/tests/test_database.py b/tests/test_database.py index 42b83d6..56f9608 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -508,10 +508,6 @@ def test_schedule_is_betweeen(session): payee = get_or_create_payee(session, "Insurance company") # should always be paid on the first working day of the month config = create_schedule_config(expected_date, patterns=[Pattern(1, "day")], skip_weekend=True) - # if the amount_operation="isbetween", the schedule needs two amounts - with pytest.raises(ActualError, match="amount must be a tuple"): - create_schedule(session, config, 100.0, "isbetween", "Insurance", payee, account) - schedule = create_schedule(session, config, (100.0, 110.0), "isbetween", "Insurance", payee, account) assert json.loads(schedule.rule.conditions) == [ {"field": "description", "type": "id", "op": "is", "value": payee.id}, @@ -585,6 +581,18 @@ def test_schedule_populates_next_date_simple_date(session, start_date, expected_ assert rows[0].base_next_date == expected_next_date +def test_schedule_exceptions(session): + expected_date = datetime.date(2025, 10, 11) + account = create_account(session, "Bank") + payee = get_or_create_payee(session, "Insurance company") + # should always be paid on the first working day of the month + config = create_schedule_config(expected_date, patterns=[Pattern(1, "day")], skip_weekend=True) + with pytest.raises(ActualError, match="amount must be a tuple"): + create_schedule(session, config, 100.0, "isbetween", "Insurance", payee, account) + with pytest.raises(ActualError, match="amount must be a single decimal number"): + create_schedule(session, config, (100.0, 110.0), "isapprox", "Insurance", payee, account) + + def test_get_transactions_with_cleared_filter(session): acct = create_account(session, "ClearedTxs") create_transaction(session, date=today, account=acct, amount=10, cleared=False) diff --git a/tests/test_models.py b/tests/test_models.py index 13542d5..1506700 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -16,7 +16,8 @@ def test_get_class_by_table_name(): assert get_class_by_table_name("transactions") == Transactions - assert get_class_by_table_name("foo") is None + with pytest.raises(ValueError, match="Could not find table 'foo'"): + get_class_by_table_name("foo") def test_get_attribute_by_table_name():