11from decimal import Decimal
2- from typing import Callable , Dict
2+ from typing import Callable , Dict , List
33from google .auth .credentials import Credentials
44from google .cloud import firestore
55from google .cloud .firestore_v1 .base_query import FieldFilter
6- from google .oauth2 import service_account
76from .firebase_client import FirebaseClient
87
98import json
109import datetime
1110import logging
1211import re
13- import os
1412import uuid
1513
1614# Improvement ideas:
@@ -130,13 +128,18 @@ def as_json(self, obj):
130128 @staticmethod
131129 def _matches_filters (value , filters ):
132130 for f in filters or []:
131+ if "field" not in f or "op" not in f or "value" not in f :
132+ raise ValueError (f"Invalid filter format: { f } " )
133133 field = f .get ("field" )
134- if field == "labels" :
135- continue
136134 op = f .get ("op" )
137135 filter_value = f .get ("value" )
138136 v = value .get (field )
139- if op == ">" :
137+ if op == "array-contains" :
138+ if field != "labels" :
139+ raise ValueError ("array-contains operator is only supported for 'labels' field" )
140+ if filter_value not in v :
141+ return False
142+ elif op == ">" :
140143 if not (v is not None and float (v ) > float (filter_value )):
141144 return False
142145 elif op == ">=" :
@@ -217,7 +220,7 @@ def _list_raw_labels(self, as_json: bool = False):
217220
218221
219222 @mcp_tool
220- def list_labels (self , as_json : bool = False ):
223+ def list_labels (self , as_json : bool = False ) -> list [ Dict [ str , str ]] :
221224 """
222225 Returns the list of labels used by the user.
223226 If as_json is True, returns the data as a JSON string.
@@ -401,7 +404,7 @@ def _list_raw_transactions(self, wallet_id: str, start: str, end: str = None, fi
401404 end (str, optional): End date (ISO 8601).
402405 filters (list, optional): List of filter dicts, e.g. [{"field": "amount", "op": ">=", "value": 100}].
403406 The 'field' and 'op' values must be strings.
404- Supported operators: "=", "~=", ">", ">=", "<", "<=", where "~=" is regex match.
407+ Supported operators: "=", "~=", ">", ">=", "<", "<=", "array-contains", where "~=" is regex match and "array-contains" is for labels only .
405408 limit (int, optional): Max number of transactions to return (default 20).
406409 as_json (bool, optional): Return as JSON string if True.
407410 Returns:
@@ -417,26 +420,36 @@ def _list_raw_transactions(self, wallet_id: str, start: str, end: str = None, fi
417420
418421 # Only order by 'madeAt' (descending), fetch all for post-filtering and limiting
419422 query = query .order_by ("madeAt" , direction = firestore .Query .DESCENDING )
423+ #query._all_descendants = True
420424 transactions = query .stream ()
421425
422426 results = []
427+ stored_transactions = []
423428 for transaction in transactions :
424429 value = transaction .to_dict ()
425430 # Add the transaction ID to the raw data for consistency
426431 value ["id" ] = value .get ("path" , {}).get ("transaction" , "" )
427- if resolve_category :
432+ category_id = value .get ("category" , None )
433+ if resolve_category and category_id is not None :
428434 # Resolve category ID to category name
429- category_id = value .get ("category" , None )
430- category_name = self .category_name_map .get (category_id , None )
431- if category_name is None :
432- logger .warning (f"Category ID { category_id } not found in category_name_map, using None." )
433-
434- if resolve_labels :
435- value ["labels" ] = self ._get_transation_labels (wallet_id , value ["id" ], resolve_names = True )
436- if not self ._matches_filters (value , filters ):
435+ value ["category" ] = self .category_name_map .get (category_id , None )
436+ stored_transactions .append (value )
437+
438+ labels_query = self .client .collection_group ('transactionLabels' )
439+ labels_query = labels_query .where (filter = FieldFilter ("path.user" , "==" , self .user_id ))
440+ transactionLabels = [x .to_dict () for x in labels_query .stream ()]
441+
442+ for transaction in stored_transactions :
443+ transaction ["labels" ] = [
444+ self .label_name_map .get (x .get ('label' ), None ) if resolve_labels else x .get ('label' )
445+ for x in transactionLabels
446+ if x .get ('path' , {}).get ('transaction' ) == transaction .get ('id' )
447+ ]
448+
449+ if not self ._matches_filters (transaction , filters ):
437450 continue
438- results .append (value )
439- if len (results ) >= limit :
451+ results .append (transaction )
452+ if limit and len (results ) >= limit :
440453 break
441454
442455 logger .info (f"Found { len (results )} raw transactions." )
@@ -455,18 +468,28 @@ def list_transactions(self,
455468 List transactions for a wallet, filtered by date range and dynamic filters.
456469
457470 Results are always sorted by 'madeAt' in descending order (most recent transactions first).
458- Each returned item has the same fields as get_transaction: id, note, madeAt, category (name), type, isPending , amount.
471+ Each returned item has the same fields as get_transaction: id, note, madeAt, category (name), isPending, type , amount, labels .
459472 Optionally, the returned fields can be limited by the 'fields' parameter, default is ["note", "madeAt", "category", "amount"].
460473 Category IDs are automatically resolved to category names.
461474
475+ Fields of a transaction (no other fields are supported!):
476+ - id (str): UUID of the transaction
477+ - note (str): Transaction note/description
478+ - madeAt (str): ISO 8601 timestamp of when the transaction was made
479+ - category (str): Category name
480+ - isPending (bool): Whether the transaction is pending
481+ - type (str): Wether the transactions is one of "regular" or "transfer" (across wallets, or out-of-spendee).
482+ - amount (int): Amount of the transaction in the wallet's currency
483+ - labels (list): List of label names associated with the transaction. Not supported in filters.
484+
462485 Args:
463486 wallet_id (str): UUID of the wallet.
464487 start (str): Start date (ISO 8601, required).
465488 end (str, optional): End date (ISO 8601).
466489 filters (list, optional): List of filter dicts, e.g. [{"field": "amount", "op": ">=", "value": 100}].
467490 The 'field' and 'op' values must be strings.
468- Supported operators: "=", "~=", ">", ">=", "<", "<=", where "~=" is regex match.
469- Does not support filtering by labels .
491+ Supported operators: "=", "~=", ">", ">=", "<", "<=", "array-contains", where "~=" is regex match and "array-contains" is for labels only .
492+ If not provided, use an empty list .
470493 limit (int, optional): Max number of transactions to return (default 20).
471494 fields (list, optional): List of field names to include in the result. Only supported fields are allowed.
472495 as_json (bool, optional): Return as JSON string if True.
@@ -510,6 +533,34 @@ def list_transactions(self,
510533 logger .info (f"Found { len (results )} transactions." )
511534 return self .as_json (results ) if as_json else results
512535
536+ @mcp_tool
537+ def aggregate_transactions (self , wallet_id : str , start : str , end : str = None , filters : list = []) -> float :
538+ """
539+ Aggregate transactions for a wallet, filtered by date range and dynamic filters.
540+ Returns the total sum of amounts of the matching transactions.
541+
542+ Args:
543+ wallet_id (str): UUID of the wallet.
544+ start (str): Start date (ISO 8601, required).
545+ end (str, optional): End date (ISO 8601).
546+ filters (list, optional): List of filter dicts, e.g. [{"field": "amount", "op": ">=", "value": 100}].
547+ The 'field' and 'op' values must be strings.
548+ Supported operators: "=", "~=", ">", ">=", "<", "<=", "array-contains", where "~=" is regex match and "array-contains" is for labels only.
549+ If not provided, use an empty list.
550+ Returns:
551+ int: The total sum of amounts of the matching transactions.
552+ """
553+ logger .info (f"Aggregating transactions for wallet_id={ wallet_id } , start={ start } , end={ end } , filters={ filters } " )
554+
555+ # Get raw transactions first
556+ raw_transactions = self .list_transactions (wallet_id , start , end , filters , limit = None , fields = ["amount" ], as_json = False )
557+
558+ # Sum the amounts
559+ total_amount = sum (float (tx .get ("amount" , 0 )) for tx in raw_transactions )
560+
561+ logger .info (f"Total aggregated amount: { total_amount } " )
562+ return total_amount
563+
513564 def _get_transation_labels (self , wallet_id : str , transaction_id : str , resolve_names : bool = True , as_json : bool = False ):
514565 """
515566 Get the labels of a specific transaction by its ID from a wallet.
0 commit comments