Skip to content

Commit 4978153

Browse files
committed
implement-transaction-listing-with-smart-filtering-and-enable-auto-tool-exposure-via-decorator
1 parent 8f5b9c2 commit 4978153

2 files changed

Lines changed: 181 additions & 35 deletions

File tree

spendee/spendee_firestore.py

Lines changed: 171 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from decimal import Decimal
2+
from typing import Callable, Dict
23
from google.auth.credentials import Credentials
34
from google.cloud import firestore
45
from google.cloud.firestore_v1.base_query import FieldFilter
@@ -7,6 +8,7 @@
78
import json
89
import datetime
910
import logging
11+
import re
1012

1113
# Improvement ideas:
1214
# - Implement token expiration check in CustomFirebaseCredentials
@@ -17,8 +19,15 @@
1719
# Note: User document has field: firestoreDataExportDone, which is True.
1820
# Which is a trace of the migration from REST API to Firestore.
1921

22+
2023
logger = logging.getLogger(__name__)
2124

25+
# MCP tool registry and decorator
26+
MCP_TOOLS: Dict[str, Callable] = {}
27+
def mcp_tool(func: Callable) -> Callable:
28+
MCP_TOOLS[func.__name__] = func
29+
return func
30+
2231
class CustomFirebaseCredentials(Credentials):
2332
"""Custom credentials that use existing Firebase access token"""
2433

@@ -56,6 +65,8 @@ def __init__(self, email: str, password: str, base_url: str = 'https://api.spend
5665
self.email = access_token_data.get('email', None)
5766
self.user_name = access_token_data.get('name', None)
5867
self.wallet_name_map = { x['name']: x['id'] for x in self.list_wallets()}
68+
self.category_name_map = { x['id']: x['name'] for x in self.list_categories()}
69+
self.label_name_map = { x['id']: x['name'] for x in self.list_labels()}
5970
logger.info(f"SpendeeFirestore initialized for user_id={self.user_id}, email={self.email}")
6071

6172
def _token_refresh(self):
@@ -107,6 +118,36 @@ def _json_serializer(self, obj):
107118
def as_json(self, obj):
108119
return json.dumps(obj, indent=2, default=self._json_serializer, ensure_ascii=False)
109120

121+
@staticmethod
122+
def _matches_filters(value, filters):
123+
for f in filters or []:
124+
field = f.get("field")
125+
op = f.get("op")
126+
filter_value = f.get("value")
127+
v = value.get(field)
128+
if op == ">":
129+
if not (v is not None and float(v) > float(filter_value)):
130+
return False
131+
elif op == ">=":
132+
if not (v is not None and float(v) >= float(filter_value)):
133+
return False
134+
elif op == "<":
135+
if not (v is not None and float(v) < float(filter_value)):
136+
return False
137+
elif op == "<=":
138+
if not (v is not None and float(v) <= float(filter_value)):
139+
return False
140+
elif op == "=":
141+
if not (str(v) == str(filter_value)):
142+
return False
143+
elif op == "~=":
144+
if not (v is not None and re.search(str(filter_value), str(v))):
145+
return False
146+
else:
147+
logger.warning(f"Unsupported filter op: {op}")
148+
return False
149+
return True
150+
110151
# --- Spendee API methods ---
111152

112153
def _get_raw_category(self, category_id: str, as_json: bool = False):
@@ -133,12 +174,16 @@ def _list_raw_categories(self, as_json: bool = False):
133174
return self.as_json(raw_data) if as_json else raw_data
134175

135176

177+
# ...existing code...
136178
def list_categories(self, as_json: bool = False):
137-
logger.info("Listing categories.")
138179
"""
139-
Returns a list of categories for the user.
180+
Returns the list of categories of the user.
140181
If as_json is True, returns the data as a JSON string.
182+
183+
Each category has the fields: id, name, type
184+
where type can be 'income' or 'expense'
141185
"""
186+
logger.info("Listing categories.")
142187
data = []
143188
for raw_data in self.client.collection(f'users/{self.user_id}/categories').get():
144189
data.append({
@@ -160,12 +205,15 @@ def _list_raw_labels(self, as_json: bool = False):
160205
return self.as_json(raw_data) if as_json else raw_data
161206

162207

208+
@mcp_tool
163209
def list_labels(self, as_json: bool = False):
164-
logger.info("Listing labels.")
165210
"""
166-
Returns a list of labels for the user.
211+
Returns the list of labels used by the user.
167212
If as_json is True, returns the data as a JSON string.
213+
214+
Each label item has the fields: id, name
168215
"""
216+
logger.info("Listing labels.")
169217
data = []
170218
for raw_data in self.client.collection(f'users/{self.user_id}/labels').get():
171219
data.append({
@@ -177,11 +225,19 @@ def list_labels(self, as_json: bool = False):
177225
return self.as_json(data) if as_json else data
178226

179227

228+
@mcp_tool
180229
def get_wallet_balance(self, wallet_id: str, start: str = None, end: str = None):
230+
"""Get the balance of a wallet for a specific timeframe.
231+
The start and end parameters should be in ISO 8601 format. If not set,
232+
no filtering is done.
233+
Args:
234+
wallet_id (str): Name of the wallet. (Should be equal to the results of list_wallets call)
235+
start (str, optional): Start date in ISO 8601 format.
236+
end (str, optional): End date in ISO 8601 format.
237+
Returns:
238+
int: The balance of the wallet.
181239
"""
182-
Returns the balance of the wallet with the given wallet_id for a specific timeframe.
183-
The start and end parameters should be in ISO 8601 format. If not set, no filtering is done.
184-
"""
240+
185241
logger.info(f"Calculating balance for wallet_id: {wallet_id}")
186242
query = self.client.collection(f'users/{self.user_id}/wallets/{wallet_id}/transactions')
187243

@@ -213,7 +269,17 @@ def get_wallet_balance(self, wallet_id: str, start: str = None, end: str = None)
213269
return round(total + Decimal(str(starting_balance)))
214270

215271

272+
@mcp_tool
216273
def list_wallets(self):
274+
"""List all wallets for the authenticated user. This is required before wallet related calls, to have exact string for names.
275+
Each wallet is represented by an object with the following fields:
276+
- id: Unique identifier of the wallet
277+
- name: Name of the wallet
278+
- type: Type of the wallet (e.g., cash, bank, etc.)
279+
- currency: Currency of the wallet
280+
- updatedAt: Last updated timestamp of the wallet
281+
"""
282+
217283
logger.info("Listing wallets.")
218284
raw_data = [
219285
x.to_dict()
@@ -242,3 +308,101 @@ def _get_raw_transaction(self, wallet_id: str, transaction_id: str, as_json: boo
242308
obj = self.client.document(f"users/{self.user_id}/wallets/{wallet_id}/transactions/{transaction_id}").get().to_dict()
243309
logger.debug(f"Fetched raw transaction content: {obj}")
244310
return self.as_json(obj) if as_json else obj
311+
312+
313+
@mcp_tool
314+
def get_transaction(self, wallet_id: str, transaction_id: str, as_json: bool = False):
315+
"""Get a specific transaction by its ID from a wallet.
316+
Args:
317+
wallet_id (str): ID of the wallet.
318+
transaction_id (str): ID of the transaction.
319+
as_json (bool, optional): If True, returns the data as a JSON string.
320+
Returns:
321+
dict or str: The transaction data, either as a dictionary or JSON string.
322+
"""
323+
value = self._get_raw_transaction(wallet_id, transaction_id)
324+
# Convert category ID to category name using self.category_name_map
325+
category_id = value.get("category", "")
326+
category_name = self.category_name_map.get(category_id, None)
327+
328+
data = {
329+
"note": value.get("note", ""),
330+
"madeAt": value.get("madeAt", ""),
331+
"category": category_name,
332+
"type": value.get("type", ""),
333+
"isPending": value.get("isPending", ""),
334+
"amount": value.get("amount", ""),
335+
}
336+
337+
logger.info(f"Getting transaction: wallet_id={wallet_id}, transaction_id={transaction_id}")
338+
return data if not as_json else self.as_json(data)
339+
340+
341+
@mcp_tool
342+
def list_transactions(self, wallet_id: str, start: str, end: str = None, filters: list = None, limit: int = 20, fields: list = ["note", "madeAt", "category", "amount"], as_json: bool = False):
343+
"""
344+
List transactions for a wallet, filtered by date range and dynamic filters.
345+
346+
Results are always sorted by 'madeAt' in descending order (most recent transactions first).
347+
Each returned item has the same fields as get_transaction: id, note, madeAt, category (name), type, isPending, amount.
348+
Optionally, the returned fields can be limited by the 'fields' parameter, default is ["note", "madeAt", "category", "amount"].
349+
350+
Args:
351+
wallet_id (str): ID of the wallet.
352+
start (str): Start date (ISO 8601, required).
353+
end (str, optional): End date (ISO 8601).
354+
filters (list, optional): List of filter dicts, e.g. [{"field": "amount", "op": ">=", "value": 100}].
355+
The 'field' and 'op' values must be strings.
356+
Supported operators: "=", "~=", ">", ">=", "<", "<=", where "~=" is regex match.
357+
limit (int, optional): Max number of transactions to return (default 20).
358+
fields (list, optional): List of field names to include in the result. Only supported fields are allowed.
359+
as_json (bool, optional): Return as JSON string if True.
360+
Returns:
361+
list or str: List of transaction dicts or JSON string.
362+
"""
363+
logger.info(f"Listing transactions for wallet_id={wallet_id}, start={start}, end={end}, filters={filters}, limit={limit}")
364+
query = self.client.collection(f'users/{self.user_id}/wallets/{wallet_id}/transactions')
365+
366+
# Required start date
367+
query = query.where(filter=FieldFilter("madeAt", ">=", datetime.datetime.fromisoformat(start)))
368+
if end:
369+
query = query.where(filter=FieldFilter("madeAt", "<=", datetime.datetime.fromisoformat(end)))
370+
371+
# Only order by 'madeAt' (descending), fetch all for post-filtering and limiting
372+
query = query.order_by("madeAt", direction=firestore.Query.DESCENDING)
373+
transactions = query.stream()
374+
375+
allowed_fields = {"id", "note", "madeAt", "category", "type", "isPending", "amount"}
376+
if not isinstance(fields, list) or not all(isinstance(f, str) for f in fields):
377+
raise ValueError("'fields' must be a list of strings.")
378+
unsupported = set(fields) - allowed_fields
379+
if unsupported:
380+
raise ValueError(f"Unsupported fields requested: {unsupported}")
381+
382+
results = []
383+
for transaction in transactions:
384+
value = transaction.to_dict()
385+
# Match get_transaction output fields
386+
category_id = value.get("category", "")
387+
category_name = self.category_name_map.get(category_id, None)
388+
if category_name is None:
389+
logger.warning(f"Category ID {category_id} not found in category_name_map, using None.")
390+
data = {
391+
"id": value.get("path", {}).get("transaction", ""),
392+
"note": value.get("note", ""),
393+
"madeAt": value.get("madeAt", ""),
394+
"category": category_name,
395+
"type": value.get("type", ""),
396+
"isPending": value.get("isPending", ""),
397+
"amount": value.get("amount", ""),
398+
}
399+
if not self._matches_filters(data, filters):
400+
continue
401+
if fields is not None:
402+
data = {k: v for k, v in data.items() if k in fields}
403+
results.append(data)
404+
if len(results) >= limit:
405+
break
406+
407+
logger.info(f"Found {len(results)} transactions.")
408+
return self.as_json(results) if as_json else results

spendee/spendee_mcp.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from starlette.responses import Response
1919
from starlette.types import Scope, Receive, Send
2020

21-
from spendee.spendee_firestore import SpendeeFirestore
21+
from spendee.spendee_firestore import SpendeeFirestore, MCP_TOOLS
2222

2323
# to start (after .venv setup):
2424
# python spendee/spendee_mcp.py
@@ -50,8 +50,15 @@
5050
logger = logging.getLogger(__name__)
5151

5252
mcp = FastMCP("spendee", host="0.0.0.0", port=PORT)
53+
5354
spendee = SpendeeFirestore(EMAIL, PASSWORD)
5455

56+
# Automatically register all MCP tools from the client
57+
for name, func in MCP_TOOLS.items():
58+
# Bind the method to the spendee instance if it's a class method
59+
bound_func = getattr(spendee, name)
60+
mcp.tool()(bound_func)
61+
5562
def main():
5663
#debug_secret(ACCEPTED_TOKEN, "MCP_TOKEN")
5764
#debug_secret(PASSWORD, "PASSWORD")
@@ -65,33 +72,6 @@ def main():
6572
logger.info("Using SSE transport")
6673
sse_server()
6774

68-
69-
@mcp.tool()
70-
def get_wallet_balance(wallet_id: str, start: str = None, end: str = None):
71-
"""Get the balance of a wallet for a specific timeframe.
72-
The start and end parameters should be in ISO 8601 format. If not set,
73-
no filtering is done.
74-
Args:
75-
wallet_id (str): Name of the wallet. (Should be equal to the results of list_wallets call)
76-
start (str, optional): Start date in ISO 8601 format.
77-
end (str, optional): End date in ISO 8601 format.
78-
Returns:
79-
int: The balance of the wallet.
80-
"""
81-
return spendee.get_wallet_balance(wallet_id, start, end)
82-
83-
@mcp.tool()
84-
def list_wallets():
85-
"""List all wallets for the authenticated user. This is required before wallet related calls, to have exact string for names.
86-
Each wallet is represented by an object with the following fields:
87-
- id: Unique identifier of the wallet
88-
- name: Name of the wallet
89-
- type: Type of the wallet (e.g., cash, bank, etc.)
90-
- currency: Currency of the wallet
91-
- updatedAt: Last updated timestamp of the wallet
92-
"""
93-
return spendee.list_wallets()
94-
9575
# Authentication middleware and server setup
9676

9777
def debug_secret(secret, name):
@@ -191,3 +171,5 @@ async def sse_auth_middleware(scope: Scope, receive: Receive, send: Send):
191171

192172
if __name__ == "__main__":
193173
main()
174+
else:
175+
main()

0 commit comments

Comments
 (0)