Skip to content

Commit 027557e

Browse files
committed
improve-typing-and-simplified-mcp-call-logic
1 parent 025c59e commit 027557e

1 file changed

Lines changed: 152 additions & 80 deletions

File tree

spendee/spendee_firestore.py

Lines changed: 152 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from decimal import Decimal
2-
from typing import Callable, Dict, List, Union, Any
2+
from typing import Callable, Dict, List, Union, Any, Optional, Literal
33
from google.auth.credentials import Credentials
44
from google.cloud import firestore
55
from google.cloud.firestore_v1.base_query import FieldFilter
@@ -10,6 +10,7 @@
1010
import logging
1111
import re
1212
import uuid
13+
from pydantic import BaseModel
1314

1415
# Improvement ideas:
1516
# - Implement token expiration check in CustomFirebaseCredentials
@@ -29,6 +30,17 @@ def mcp_tool(func: Callable) -> Callable:
2930
MCP_TOOLS[func.__name__] = func
3031
return func
3132

33+
34+
class Category(BaseModel):
35+
id: str
36+
name: str
37+
type: Literal['income', 'expense']
38+
39+
class Label(BaseModel):
40+
id: str
41+
name: str
42+
43+
3244
class CustomFirebaseCredentials(Credentials):
3345
"""Custom credentials that use existing Firebase access token"""
3446

@@ -72,9 +84,9 @@ def __init__(self, email: str, password: str, base_url: str = 'https://api.spend
7284
# Initialize mappings
7385
self.wallet_name_map = { x['name']: x['id'] for x in self.list_wallets()}
7486
categories = self.list_categories()
75-
self.category_name_map = { x['id']: x['name'] for x in categories}
76-
self.category_type_map = { x['id']: x['type'] for x in categories}
77-
self.label_name_map = { x['id']: x['name'] for x in self.list_labels()}
87+
self.category_name_map = { x.id: x.name for x in categories}
88+
self.category_type_map = { x.id: x.type for x in categories}
89+
self.label_name_map = { x.id: x.name for x in self.list_labels()}
7890
logger.info(f"SpendeeFirestore initialized for user_id={self.user_id}, email={self.email}")
7991

8092
def _token_refresh(self):
@@ -126,40 +138,97 @@ def as_json(self, obj):
126138
return json.dumps(obj, indent=2, default=self._json_serializer, ensure_ascii=False)
127139

128140
@staticmethod
129-
def _matches_filters(value, filters):
130-
for f in filters or []:
131-
if "field" not in f or "op" not in f or "value" not in f:
132-
raise ValueError(f"Invalid filter format: {f}")
133-
field = f.get("field")
134-
op = f.get("op")
135-
filter_value = f.get("value")
141+
def _matches_filters(value: Dict, filters: Dict[str,Union[str,float]]) -> bool:
142+
"""
143+
Evaluate filters provided as a dict with keys like "field__suffix": value.
144+
145+
Supported suffixes: __eq, __regex, __gt, __gte, __lt, __lte, __contains, __not_contains
146+
Example: {"amount__gte": 100, "note__regex": "[mM][cC][dD]"}
147+
"""
148+
if not filters:
149+
return True
150+
151+
if not isinstance(filters, dict):
152+
raise ValueError("filters must be a dict of the form {'field__op': value}")
153+
154+
supported_suffixes = {
155+
"__eq",
156+
"__regex",
157+
"__gt",
158+
"__gte",
159+
"__lt",
160+
"__lte",
161+
"__contains",
162+
"__not_contains",
163+
}
164+
165+
for key, filter_value in filters.items():
166+
if not isinstance(key, str):
167+
raise ValueError(f"Filter key must be a string: {key}")
168+
169+
# split into field and suffix
170+
if "__" not in key:
171+
raise ValueError(f"Invalid filter key format (expected 'field__op'): {key}")
172+
field, suffix = key.split("__", 1)
173+
suffix = f"__{suffix}"
174+
if suffix not in supported_suffixes:
175+
raise ValueError(f"Unsupported filter operation suffix: {suffix} in {key}")
176+
136177
v = value.get(field)
137-
if op == "array-contains":
138-
if field != "labels":
139-
raise ValueError("array-contains operator is only supported for 'labels' field")
140-
if filter_value not in v:
178+
179+
# Unified "contains" and "not contains" for both lists and strings
180+
if suffix in ("__contains", "__not_contains"):
181+
if not v:
182+
# For __contains, missing value means no match; for __not_contains, treat as not present
183+
if suffix == "__contains":
184+
return False
185+
# Accept str or list-like (tuple, set)
186+
if not isinstance(v, (str, list, tuple, set)):
187+
raise ValueError(f"Field '{field}' value must be str or list-like for '{suffix}' filter, got {type(v)}")
188+
contains = filter_value in v
189+
if suffix == "__contains" and not contains:
190+
return False
191+
if suffix == "__not_contains" and contains:
141192
return False
142-
elif op == ">":
143-
if not (v is not None and float(v) > float(filter_value)):
193+
# Regex match
194+
elif suffix == "__regex":
195+
if not v:
144196
return False
145-
elif op == ">=":
146-
if not (v is not None and float(v) >= float(filter_value)):
197+
if not re.search(str(filter_value), str(v)):
147198
return False
148-
elif op == "<":
149-
if not (v is not None and float(v) < float(filter_value)):
199+
# Equality
200+
elif suffix == "__eq":
201+
# If both values are numbers, allow a small error margin (epsilon)
202+
try:
203+
left = float(v)
204+
right = float(filter_value)
205+
epsilon = 1e-4
206+
if abs(left - right) > epsilon:
207+
return False
208+
except (ValueError, TypeError):
209+
if not (str(v) == str(filter_value)):
210+
return False
211+
# Numeric comparisons - try to cast to float
212+
elif suffix in ("__gt", "__gte", "__lt", "__lte"):
213+
if not v:
150214
return False
151-
elif op == "<=":
152-
if not (v is not None and float(v) <= float(filter_value)):
215+
try:
216+
left = float(v)
217+
right = float(filter_value)
218+
except Exception:
153219
return False
154-
elif op == "=":
155-
if not (str(v) == str(filter_value)):
220+
if suffix == "__gt" and not (left > right):
156221
return False
157-
elif op == "~=":
158-
if not (v is not None and re.search(str(filter_value), str(v))):
222+
if suffix == "__gte" and not (left >= right):
223+
return False
224+
if suffix == "__lt" and not (left < right):
225+
return False
226+
if suffix == "__lte" and not (left <= right):
159227
return False
160228
else:
161-
logger.warning(f"Unsupported filter op: {op}")
162-
return False
229+
# Shouldn't reach here because we validated suffixes above
230+
raise ValueError(f"Unhandled filter suffix: {suffix}")
231+
163232
return True
164233

165234
# --- Spendee API methods ---
@@ -189,7 +258,7 @@ def _list_raw_categories(self, as_json: bool = False):
189258

190259

191260
@mcp_tool
192-
def list_categories(self) -> List[Dict[str, Any]]:
261+
def list_categories(self) -> List[Category]:
193262
"""
194263
Returns the list of categories of the user.
195264
@@ -202,11 +271,11 @@ def list_categories(self) -> List[Dict[str, Any]]:
202271
logger.info("Listing categories.")
203272
data = []
204273
for raw_data in self.client.collection(f'users/{self.user_id}/categories').get():
205-
data.append({
206-
'id': str(raw_data.get('path').get('category')),
207-
'name': str(raw_data.get('name')),
208-
'type': str(raw_data.get('type')),
209-
})
274+
data.append(Category(
275+
id=str(raw_data.get('path').get('category')),
276+
name=str(raw_data.get('name')),
277+
type=str(raw_data.get('type')),
278+
))
210279
logger.debug(f"Fetched categories content: {data}")
211280
return data
212281

@@ -222,7 +291,7 @@ def _list_raw_labels(self, as_json: bool = False):
222291

223292

224293
@mcp_tool
225-
def list_labels(self) -> List[Dict[str, str]]:
294+
def list_labels(self) -> List[Label]:
226295
"""
227296
Returns the list of labels used by the user.
228297
@@ -231,39 +300,32 @@ def list_labels(self) -> List[Dict[str, str]]:
231300
logger.info("Listing labels.")
232301
data = []
233302
for raw_data in self.client.collection(f'users/{self.user_id}/labels').get():
234-
data.append({
235-
'id': raw_data.get('path').get('label'),
236-
# attention: here I switch from fieldName 'text' to 'name', for consistency with categories
237-
'name': raw_data.get('text'),
238-
})
303+
data.append(Label(
304+
id=raw_data.get('path').get('label'),
305+
name=raw_data.get('text'),
306+
))
239307
logger.debug(f"Fetched labels content: {data}")
240308
return data
241309

242310

243311
@mcp_tool
244-
def get_wallet_balance(self, wallet_id: str, start: str = None, end: str = None) -> Decimal:
245-
"""Get the balance of a wallet for a specific timeframe.
246-
The start and end parameters should be in ISO 8601 format. If not set,
247-
no filtering is done.
312+
def get_wallet_balance(self, wallet_id: str, date: str = "") -> Decimal:
313+
"""Get the balance of a wallet on the given date.
314+
The date parameter should be in ISO 8601 format.
248315
Args:
249316
wallet_id (str): Name of the wallet. (Should be equal to the results of list_wallets call)
250-
start (str, optional): Start date in ISO 8601 format.
251-
end (str, optional): End date in ISO 8601 format.
317+
date (str, optional): Date in ISO 8601 format.
252318
Returns:
253319
Decimal: The balance of the wallet.
254320
"""
255321

256322
logger.info(f"Calculating balance for wallet_id: {wallet_id}")
257323
query = self.client.collection(f'users/{self.user_id}/wallets/{wallet_id}/transactions')
258324

259-
if start:
260-
query = query.where(filter=FieldFilter("madeAt", ">=", datetime.datetime.fromisoformat(start)))
261-
starting_balance = 0
262-
else:
263-
starting_balance = self.client.document(f'users/{self.user_id}/wallets/{wallet_id}').get().to_dict()['startingBalance']
325+
starting_balance = self.client.document(f'users/{self.user_id}/wallets/{wallet_id}').get().to_dict()['startingBalance']
264326

265-
if end:
266-
query = query.where(filter=FieldFilter("madeAt", "<=", datetime.datetime.fromisoformat(end)))
327+
if date:
328+
query = query.where(filter=FieldFilter("madeAt", "<=", datetime.datetime.fromisoformat(date)))
267329
query = query.order_by("madeAt")
268330
transactions = query.stream()
269331
# Could not use aggregation queries here, because each transaction has a different exchange rate, which needs multiplication.
@@ -272,13 +334,14 @@ def get_wallet_balance(self, wallet_id: str, start: str = None, end: str = None)
272334
total = 0
273335
for transaction in transactions:
274336
data = transaction.to_dict()
275-
usd_value = data.get('usdValue', {})
337+
# usd_value = data.get('usdValue', {})
276338

277-
amount = Decimal(str(usd_value.get('amount', '0')))
278-
exchange_rate = Decimal(str(usd_value.get('exchangeRate', '0')))
339+
# amount = Decimal(str(usd_value.get('amount', '0')))
340+
# exchange_rate = Decimal(str(usd_value.get('exchangeRate', '0')))
279341

280-
converted_value = amount * exchange_rate
281-
total += converted_value
342+
# converted_value = amount * exchange_rate
343+
#total += converted_value
344+
total += Decimal(str(data.get('amount', '0')))
282345

283346
logger.info(f"Total balance calculated: {total + Decimal(str(starting_balance))}")
284347
return round(total + Decimal(str(starting_balance)))
@@ -394,7 +457,7 @@ def get_transaction(self, wallet_id: str, transaction_id: str, resolve_category:
394457
return data
395458

396459

397-
def _list_raw_transactions(self, wallet_id: str, start: str, end: str = None, filters: list = None, limit: int = 20, resolve_labels: bool = True, resolve_category: bool = True, as_json: bool = False):
460+
def _list_raw_transactions(self, wallet_id: str, start: str, end: str = "", filters: Optional[Dict[str,Union[str,float]]] = {}, limit: int = 20, resolve_labels: bool = True, resolve_category: bool = True, as_json: bool = False):
398461
"""
399462
List raw transactions for a wallet, filtered by date range and dynamic filters.
400463
@@ -405,9 +468,10 @@ def _list_raw_transactions(self, wallet_id: str, start: str, end: str = None, fi
405468
wallet_id (str): UUID of the wallet.
406469
start (str): Start date (ISO 8601, required).
407470
end (str, optional): End date (ISO 8601).
408-
filters (list, optional): List of filter dicts, e.g. [{"field": "amount", "op": ">=", "value": 100}].
409-
The 'field' and 'op' values must be strings.
410-
Supported operators: "=", "~=", ">", ">=", "<", "<=", "array-contains", where "~=" is regex match and "array-contains" is for labels only.
471+
filters (list, optional): List of filter dicts, e.g. [{"amount__gte": 100}].
472+
The field__op keys must be strings, and value can be str or float.
473+
Supported suffixes: "__eq", "__regex", "__gt", "__gte", "__lt", "__lte", "__contains", "__not_contains".
474+
Where "__regex" is regex match and "__contains", "__not_contains" is for string only.
411475
limit (int, optional): Max number of transactions to return (default 20).
412476
as_json (bool, optional): Return as JSON string if True.
413477
Returns:
@@ -452,7 +516,7 @@ def _list_raw_transactions(self, wallet_id: str, start: str, end: str = None, fi
452516
if not self._matches_filters(transaction, filters):
453517
continue
454518
results.append(transaction)
455-
if limit and len(results) >= limit:
519+
if (limit or limit != 0) and len(results) >= limit:
456520
break
457521

458522
logger.info(f"Found {len(results)} raw transactions.")
@@ -462,10 +526,10 @@ def _list_raw_transactions(self, wallet_id: str, start: str, end: str = None, fi
462526
def list_transactions(self,
463527
wallet_id: str,
464528
start: str,
465-
end: str = None,
466-
filters: list = None,
467-
limit: int = 20,
468-
fields: list = ["note", "madeAt", "category", "amount", "labels"]) -> List[Dict[str, Any]]:
529+
end: Optional[str] = "",
530+
filters: Optional[Dict[str,Union[str,float]]] = {},
531+
limit: Optional[int] = 20,
532+
fields: Optional[list[str]] = ["note", "madeAt", "category", "amount", "labels"]) -> List[Dict[str, Any]]:
469533
"""
470534
List transactions for a wallet, filtered by date range and dynamic filters.
471535
@@ -482,16 +546,16 @@ def list_transactions(self,
482546
- isPending (bool): Whether the transaction is pending
483547
- type (str): Wether the transactions is one of "regular" or "transfer" (across wallets, or out-of-spendee).
484548
- amount (int): Amount of the transaction in the wallet's currency
485-
- labels (list): List of label names associated with the transaction. Not supported in filters.
549+
- labels (str): List of label names associated with the transaction joined by commas into a string, orderd alphabetically.
486550
487551
Args:
488552
wallet_id (str): UUID of the wallet.
489553
start (str): Start date (ISO 8601, required).
490554
end (str, optional): End date (ISO 8601).
491-
filters (list, optional): List of filter dicts, e.g. [{"field": "amount", "op": ">=", "value": 100}].
492-
The 'field' and 'op' values must be strings.
493-
Supported operators: "=", "~=", ">", ">=", "<", "<=", "array-contains", where "~=" is regex match and "array-contains" is for labels only.
494-
If not provided, use an empty list.
555+
filters (list, optional): List of filter dicts, e.g. [{"amount__gte": 100}].
556+
The field__op keys must be strings, and value can be str or float.
557+
Supported suffixes: "__eq", "__regex", "__gt", "__gte", "__lt", "__lte", "__contains", "__not_contains".
558+
Where "__regex" is regex match and "__contains", "__not_contains" is for string only.
495559
limit (int, optional): Max number of transactions to return (default 20).
496560
fields (list, optional): List of field names to include in the result. Only supported fields are allowed.
497561
Returns:
@@ -516,7 +580,7 @@ def list_transactions(self,
516580
# Create processed transaction data matching get_transaction output fields
517581
data = {
518582
"id": raw_transaction.get("id", ""),
519-
"labels": raw_transaction.get("labels", []),
583+
"labels": ",".join(sorted(raw_transaction.get("labels", []))),
520584
"note": raw_transaction.get("note", ""),
521585
"madeAt": raw_transaction.get("madeAt", ""),
522586
"category": raw_transaction.get("category", ""),
@@ -535,26 +599,34 @@ def list_transactions(self,
535599
return results
536600

537601
@mcp_tool
538-
def aggregate_transactions(self, wallet_id: str, start: str, end: str = None, filters: list = []) -> float:
602+
def aggregate_transactions(self, wallet_id: str, start: str, end: str = "", filters: Optional[Dict[str,Union[str,float]]] = {}) -> float:
539603
"""
540604
Aggregate transactions for a wallet, filtered by date range and dynamic filters.
541605
Returns the total sum of amounts of the matching transactions.
542606
607+
Example call:
608+
aggregate_transactions(
609+
wallet_id="3fcc0060-d3f2-42fb-9001-9a467f95d1b0",
610+
start="2023-01-01",
611+
end="2023-12-31",
612+
filters={"amount__gte": 100}
613+
)
614+
543615
Args:
544616
wallet_id (str): UUID of the wallet.
545617
start (str): Start date (ISO 8601, required).
546618
end (str, optional): End date (ISO 8601).
547-
filters (list, optional): List of filter dicts, e.g. [{"field": "amount", "op": ">=", "value": 100}].
548-
The 'field' and 'op' values must be strings.
549-
Supported operators: "=", "~=", ">", ">=", "<", "<=", "array-contains", where "~=" is regex match and "array-contains" is for labels only.
550-
If not provided, use an empty list.
619+
filters (list, optional): List of filter dicts, e.g. {"amount__gte": 100}.
620+
The field__op keys must be strings, and value can be str or float.
621+
Supported suffixes: "__eq", "__regex", "__gt", "__gte", "__lt", "__lte", "__contains", "__not_contains".
622+
Where "__regex" is regex match and "__contains", "__not_contains" is for string only.
551623
Returns:
552624
float: The total sum of amounts of the matching transactions.
553625
"""
554626
logger.info(f"Aggregating transactions for wallet_id={wallet_id}, start={start}, end={end}, filters={filters}")
555627

556628
# Get raw transactions first
557-
raw_transactions = self.list_transactions(wallet_id, start, end, filters, limit=None, fields=["amount"])
629+
raw_transactions = self.list_transactions(wallet_id, start, end, filters, limit=0, fields=["amount"])
558630

559631
# Sum the amounts
560632
total_amount = sum(float(tx.get("amount", 0)) for tx in raw_transactions)

0 commit comments

Comments
 (0)