Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions actual/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,11 @@ def export_data(self, output_file: str | PathLike[str] | IO[bytes] | None = None
z.write(self.data_dir / "metadata.json", "metadata.json")
content = temp_file.getvalue()
if output_file:
with open(output_file, "wb") as f:
f.write(content)
if isinstance(output_file, (str, PathLike)):
with open(output_file, "wb") as f:
f.write(content)
else:
output_file.write(content)
return content

def encrypt(self, encryption_password: str):
Expand Down
4 changes: 2 additions & 2 deletions actual/budgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sqlmodel import Session, col, select

from actual.database import Categories, CategoryGroups, ReflectBudgets, Transactions, ZeroBudgets
from actual.database import BaseBudgets, Categories, CategoryGroups, ReflectBudgets, Transactions
from actual.queries import (
_balance_base_query,
_get_budget_table,
Expand Down Expand Up @@ -64,7 +64,7 @@ class BudgetCategory(_HasDatabaseObject):
This reflects the values displayed on the Actual frontend.
"""

budget: ReflectBudgets | ZeroBudgets | None = None
budget: BaseBudgets | None = None
"""
The underlying budget database record, if it exists.

Expand Down
10 changes: 7 additions & 3 deletions actual/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,19 @@ def get_attribute_from_reflected_table_name(metadata: MetaData, table_name: str,
return table.columns.get(column_name, None)


def get_class_by_table_name(table_name: str) -> type["BaseModel"] | None:
def get_class_by_table_name(table_name: str) -> type["BaseModel"]:
"""
Returns, based on the defined tables `__tablename__` the corresponding SQLModel object.

If not found, returns `None`.
If not found, raises `ValueError`.

:param table_name: SQL table name.
:return SQLModel: SQLAlchemy object.
:raises ValueError: Raises `ValueError` if the table name is not existing.
"""
entry = __TABLE_COLUMNS_MAP__.get(table_name)
if entry is None:
return None
raise ValueError(f"Could not find table '{table_name}' on the database model.")
return entry["entity"]


Expand Down
59 changes: 46 additions & 13 deletions actual/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
from sqlalchemy.orm import joinedload
from sqlalchemy.sql.expression import Select
from sqlmodel import Session, col, select
from sqlmodel.sql.expression import SelectOfScalar

from actual.crypto import is_uuid
from actual.database import (
Accounts,
BaseBudgets,
Categories,
CategoryGroups,
CategoryMapping,
Expand Down Expand Up @@ -56,8 +58,8 @@ def _transactions_base_query(
account: Accounts | str | None | None = None,
category: Categories | str | None = None,
include_deleted: bool = False,
) -> Select:
query = (
) -> SelectOfScalar[Transactions]:
query: SelectOfScalar[Transactions] = (
select(Transactions)
.options(
joinedload(Transactions.account),
Expand Down Expand Up @@ -232,7 +234,7 @@ def match_transaction(
query = _transactions_base_query(
s, date - datetime.timedelta(days=7), date + datetime.timedelta(days=8), account=account
).where(col(Transactions.amount) == round(amount * 100))
results: list[Transactions] = s.exec(query).all() # noqa
results: list[Transactions] = list(s.exec(query).all())
# filter out the ones that were already matched
if already_matched:
matched = {t.id for t in already_matched}
Expand Down Expand Up @@ -535,13 +537,17 @@ def create_split(s: Session, transaction: Transactions, amount: float | decimal.
return split


def _base_query(instance: type[T], name: str | None = None, include_deleted: bool = False) -> Select:
def _base_query(instance: type[T], name: str | None = None, include_deleted: bool = False) -> SelectOfScalar[T]:
"""Internal method to reduce querying complexity on sub-functions."""
query = select(instance)
if not include_deleted:
query = query.where(sqlalchemy.func.coalesce(instance.tombstone, 0) == 0)
tombstone_col = getattr(instance, "tombstone", None)
if tombstone_col is not None:
query = query.where(sqlalchemy.func.coalesce(tombstone_col, 0) == 0)
if name:
query = query.where(instance.name.ilike(f"%{sqlalchemy.text(name).compile()}%"))
name_col = getattr(instance, "name", None)
if name_col is not None:
query = query.where(name_col.ilike(f"%{sqlalchemy.text(name).compile()}%"))
return query


Expand Down Expand Up @@ -839,7 +845,7 @@ def get_or_create_account(s: Session, name: str | Accounts) -> Accounts:
return account


def _get_budget_table(s: Session) -> type[ReflectBudgets | ZeroBudgets]:
def _get_budget_table(s: Session) -> type[ZeroBudgets] | type[ReflectBudgets]:
"""
Finds out which type of budget the user uses. The types are:

Expand All @@ -858,7 +864,7 @@ def _get_budget_table(s: Session) -> type[ReflectBudgets | ZeroBudgets]:

def get_budgets(
s: Session, month: datetime.date | None = None, category: str | Categories | None = None
) -> typing.Sequence[ZeroBudgets | ReflectBudgets]:
) -> typing.Sequence[BaseBudgets]:
"""
Returns a list of all available budgets.

Expand Down Expand Up @@ -889,7 +895,7 @@ def get_budgets(
return s.exec(query).unique().all()


def get_budget(s: Session, month: datetime.date, category: str | Categories) -> ZeroBudgets | ReflectBudgets | None:
def get_budget(s: Session, month: datetime.date, category: str | Categories) -> BaseBudgets | None:
"""
Gets an existing budget by category name, returns `None` if not found.

Expand All @@ -910,7 +916,7 @@ def create_budget(
category: str | Categories,
amount: decimal.Decimal | float | int = 0.0,
carryover: bool | None = None,
) -> ZeroBudgets | ReflectBudgets:
) -> BaseBudgets:
"""
Gets an existing budget based on the month and category. If it already exists, the amount will be replaced by
the new amount.
Expand Down Expand Up @@ -1140,6 +1146,32 @@ def get_schedules(
return s.exec(query).all()


@typing.overload
def create_schedule(
s: Session,
date: datetime.date | datetime.datetime | Schedule,
amount: tuple[decimal.Decimal, decimal.Decimal] | tuple[float, float],
amount_operation: typing.Literal["isbetween"],
name: str | None,
payee: str | Payees | None,
account: str | Accounts | None,
posts_transaction: bool,
) -> Schedules: ...


@typing.overload
def create_schedule(
s: Session,
date: datetime.date | datetime.datetime | Schedule,
amount: decimal.Decimal | float,
amount_operation: typing.Literal["is", "isapprox"],
name: str | None,
payee: str | Payees | None,
account: str | Accounts | None,
posts_transaction: bool,
) -> Schedules: ...


def create_schedule(
s: Session,
date: datetime.date | datetime.datetime | Schedule,
Expand Down Expand Up @@ -1171,9 +1203,6 @@ def create_schedule(
:param posts_transaction: Whether the schedule should auto-post transactions on your behalf. Defaults to false.
:return: Rule database object created.
"""
if amount_operation == "isbetween" and not isinstance(amount, tuple):
raise ActualError("When using 'isbetween', amount must be a tuple (num1, num2), where num1 < num2.")

schedule_id = str(uuid.uuid4())
conditions = []
# Handle the payee condition
Expand All @@ -1197,6 +1226,8 @@ def create_schedule(
)
# Handle the amount condition
if amount_operation == "isbetween":
if not isinstance(amount, tuple):
raise ActualError("When using 'isbetween', amount must be a tuple (num1, num2), where num1 < num2.")
conditions.append(
Condition(
field="amount",
Expand All @@ -1205,6 +1236,8 @@ def create_schedule(
)
)
else:
if isinstance(amount, tuple):
raise ActualError(f"When using '{amount_operation}', amount must be a single decimal number.")
conditions.append(Condition(field="amount", op=ConditionType(amount_operation), value=decimal_to_cents(amount)))

actions = [
Expand Down
16 changes: 12 additions & 4 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,10 +508,6 @@ def test_schedule_is_betweeen(session):
payee = get_or_create_payee(session, "Insurance company")
# should always be paid on the first working day of the month
config = create_schedule_config(expected_date, patterns=[Pattern(1, "day")], skip_weekend=True)
# if the amount_operation="isbetween", the schedule needs two amounts
with pytest.raises(ActualError, match="amount must be a tuple"):
create_schedule(session, config, 100.0, "isbetween", "Insurance", payee, account)

schedule = create_schedule(session, config, (100.0, 110.0), "isbetween", "Insurance", payee, account)
assert json.loads(schedule.rule.conditions) == [
{"field": "description", "type": "id", "op": "is", "value": payee.id},
Expand Down Expand Up @@ -585,6 +581,18 @@ def test_schedule_populates_next_date_simple_date(session, start_date, expected_
assert rows[0].base_next_date == expected_next_date


def test_schedule_exceptions(session):
expected_date = datetime.date(2025, 10, 11)
account = create_account(session, "Bank")
payee = get_or_create_payee(session, "Insurance company")
# should always be paid on the first working day of the month
config = create_schedule_config(expected_date, patterns=[Pattern(1, "day")], skip_weekend=True)
with pytest.raises(ActualError, match="amount must be a tuple"):
create_schedule(session, config, 100.0, "isbetween", "Insurance", payee, account)
with pytest.raises(ActualError, match="amount must be a single decimal number"):
create_schedule(session, config, (100.0, 110.0), "isapprox", "Insurance", payee, account)


def test_get_transactions_with_cleared_filter(session):
acct = create_account(session, "ClearedTxs")
create_transaction(session, date=today, account=acct, amount=10, cleared=False)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

def test_get_class_by_table_name():
assert get_class_by_table_name("transactions") == Transactions
assert get_class_by_table_name("foo") is None
with pytest.raises(ValueError, match="Could not find table 'foo'"):
get_class_by_table_name("foo")


def test_get_attribute_by_table_name():
Expand Down
Loading