diff --git a/CLAUDE.md b/CLAUDE.md
index 461a04a..9afb8b2 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -2,7 +2,7 @@
 
 ## Project Overview
 
-FastFetchBot is a social media content fetching service built as a **UV workspace monorepo** with four microservices: a FastAPI server (API), a Telegram Bot client, a Celery worker for file operations, and an ARQ-based async worker for off-path scraping. It scrapes and archives content from various social media platforms including Twitter, Weibo, Xiaohongshu, Reddit, Bluesky, Instagram, Zhihu, Douban, YouTube, and Bilibili.
+FastFetchBot is a social media content fetching service built as a **UV workspace monorepo** with four microservices: a FastAPI server (API), a Telegram Bot client, a Celery worker for file operations, and an ARQ-based async worker for off-path scraping and file_id persistence. It scrapes and archives content from various social media platforms including Twitter, Weibo, Xiaohongshu, Reddit, Bluesky, Instagram, Zhihu, Douban, YouTube, and Bilibili.
 
 ## Architecture
 
@@ -36,7 +36,7 @@ FastFetchBot/
 ├── apps/api/            # FastAPI server: enriched service, routing, storage
 ├── apps/telegram-bot/   # Telegram Bot: webhook/polling, message handling
 ├── apps/worker/         # Celery worker: sync file operations (video, PDF, audio)
-├── apps/async-worker/   # ARQ async worker: off-path scraping + enrichment
+├── apps/async-worker/   # ARQ async worker: off-path scraping + enrichment + file_id persistence
 ├── pyproject.toml       # Root workspace configuration
 └── uv.lock              # Lockfile for the entire workspace
 ```
@@ -70,21 +70,27 @@ The Telegram Bot communicates with the API server over HTTP (`API_SERVER_URL`).
 - **`api_client.py`** — HTTP client calling the API server
 - **`queue_client.py`** — ARQ Redis client for enqueuing scrape jobs (queue mode)
 - **`handlers/`** — `messages.py`, `buttons.py`, `url_process.py`, `commands.py` (start, /settings with inline toggles)
-- **`services/`** — `bot_app.py`, `message_sender.py`, `user_settings.py` (get/toggle `auto_fetch_in_dm` and `force_refresh_cache`), `constants.py`
+- **`services/`** — `bot_app.py`, `message_sender.py` (media packaging with file_id shortcut + background file_id capture), `file_id_capture.py` (extracts file_ids from sent messages, pushes to Redis for async worker persistence), `user_settings.py` (get/toggle `auto_fetch_in_dm` and `force_refresh_cache`), `constants.py`
 - **`webhook/server.py`** — Webhook/polling server
 - **`templates/`** — Jinja2 templates for bot messages
+
+### Async Worker (`apps/async-worker/async_worker/`)
+
+- **`main.py`** — ARQ worker entry point with `on_startup`/`on_shutdown` hooks for MongoDB and file_id consumer lifecycle
+- **`config.py`** — `AsyncWorkerSettings` with MongoDB, Redis, and runtime flags (`file_id_consumer_ready`)
+- **`services/file_id_consumer.py`** — Background Redis BRPOP consumer for the `fileid:updates` queue. Receives file_id payloads from the Telegram bot, matches media URLs to the latest `Metadata` document in MongoDB, and persists `telegram_file_id` values. Lifecycle is managed via `start()`/`stop()` during worker startup/shutdown when `DATABASE_ON` is true.
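+  Queue payloads are JSON objects of the form `{"metadata_url": <page url>, "file_id_updates": [{"url", "media_type", "telegram_file_id"}]}` — the shape used by `file_id_capture.py` and the consumer tests below.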
 
 ### Shared Library (`packages/shared/fastfetchbot_shared/`)
 
 - **`config.py`** — URL patterns (SOCIAL_MEDIA_WEBSITE_PATTERNS, VIDEO_WEBSITE_PATTERNS, BANNED_PATTERNS); shared env vars including `SIGN_SERVER_URL` and `XHS_COOKIE_PATH`
-- **`models/`** — `classes.py` (NamedBytesIO), `metadata_item.py` (MediaFile, MetadataItem, MessageType), `telegraph_item.py`, `url_metadata.py`
+- **`models/`** — `classes.py` (NamedBytesIO), `metadata_item.py` (MediaFile with optional `telegram_file_id` for Telegram file_id caching, MetadataItem, MessageType), `telegraph_item.py`, `url_metadata.py`
 - **`utils/`** — `parse.py` (URL parsing, HTML processing, `get_env_bool`), `image.py`, `logger.py`, `network.py`, `cookie.py`
 - **`database/`** — Dual database layer:
   - **SQLAlchemy** (user settings): `base.py`, `engine.py`, `session.py`, `models/user_setting.py` — `UserSetting` model with `auto_fetch_in_dm` and `force_refresh_cache` toggles. Supports SQLite (dev) and PostgreSQL (prod) via `SETTINGS_DATABASE_URL`. Alembic migrations in `packages/shared/alembic/`
   - **`database/mongodb/`** — Beanie ODM for scraped content persistence, shared across API and async worker:
     - **`connection.py`** — `init_mongodb(mongodb_url, db_name)`, `close_mongodb()`, `save_instances()`. Parameterized — each app passes its own config at startup
     - **`cache.py`** — MongoDB-backed URL cache: `find_cached(url, ttl_seconds)` returns the latest versioned document if within TTL (0 = never expire); `save_metadata(metadata_item)` auto-increments `version` for the same URL before inserting
-    - **`models/metadata.py`** — `Metadata(Document)` with fields: title, url, author, content, media_files, telegraph_url, timestamp, version, etc. `DatabaseMediaFile(MediaFile)` extends the scraper `MediaFile` dataclass with `file_key` for S3 storage. Compound index on `(url, version)` for efficient cache lookups. `@before_event(Insert)` hook auto-computes text lengths and converts `MediaFile` → `DatabaseMediaFile`
+    - **`models/metadata.py`** — `Metadata(Document)` with fields: title, url, author, content, media_files, telegraph_url, timestamp, version, etc. `DatabaseMediaFile(MediaFile)` extends the scraper `MediaFile` dataclass with `file_key` for S3 storage and inherits `telegram_file_id` for Telegram file_id caching. Compound index on `(url, version)` for efficient cache lookups. `@before_event(Insert)` hook auto-computes text lengths and converts `MediaFile` → `DatabaseMediaFile`. Custom `bson_encoders = {DatabaseMediaFile: asdict}` ensures proper BSON serialization of pydantic dataclasses
     - **`__init__.py`** — Re-exports: `init_mongodb`, `close_mongodb`, `save_instances`, `find_cached`, `save_metadata`, `Metadata`, `DatabaseMediaFile`
 - **`services/scrapers/`** — All platform scrapers, fully decoupled from FastAPI:
   - **`config.py`** — All scraper env vars: platform credentials (Twitter, Bluesky, Weibo, XHS, Zhihu, Reddit, Instagram), Firecrawl/Zyte config, OpenAI key, Telegraph tokens, `JINJA2_ENV`, cookie file loading. Configurable `CONF_DIR` for cookie/config files
@@ -192,6 +198,13 @@ See `template.env` for a complete reference. Key variables:
 
 MongoDB models and connection logic live in `packages/shared/fastfetchbot_shared/database/mongodb/`. Both the API server and async worker use the same shared ODM layer.
 The `Metadata` Beanie Document stores scraped content with versioning — each re-scrape of the same URL increments the `version` field.
 The cache system (`find_cached` / `save_metadata`) queries the latest version and checks TTL before deciding to re-scrape.
+**Telegram file_id caching** — automatic when `DATABASE_ON` is true and `SCRAPE_MODE` is `queue`:
+- When the bot sends media to users, it extracts Telegram `file_id` values from the `send_media_group` response
+- File_ids are pushed to the Redis queue `fileid:updates` via the bot's `file_id_capture` module (fire-and-forget background task)
+- The async worker's `file_id_consumer` processes the queue and persists file_ids to the corresponding `Metadata.media_files[*].telegram_file_id` in MongoDB
+- On subsequent cache hits, `media_files_packaging` uses the stored file_ids directly via `InputMediaPhoto(file_id)` etc., skipping the HTTP download entirely
+- The bot has no direct MongoDB access — all database writes go through the async worker via Redis
+
 **SQLite/PostgreSQL** (user settings — always enabled for the Telegram bot):
 | Variable | Default | Description |
 |----------|---------|-------------|
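A minimal sketch of the read path described above — assuming the async call signatures of the shared Beanie layer; the `scrape` callable and TTL value are illustrative placeholders, not part of this diff:

```python
from fastfetchbot_shared.database.mongodb import find_cached, save_metadata

async def fetch_with_cache(url: str, scrape, ttl_seconds: int = 86400):
    # Latest version within TTL (0 = never expire) → serve from cache.
    cached = await find_cached(url, ttl_seconds)
    if cached is not None:
        # On a hit, media_files entries may already carry telegram_file_id,
        # so the bot can send them without re-downloading.
        return cached
    metadata_item = await scrape(url)   # hypothetical scraper call
    await save_metadata(metadata_item)  # auto-increments `version` for this URL
    return metadata_item
```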
diff --git a/apps/async-worker/async_worker/config.py b/apps/async-worker/async_worker/config.py
index 63dca52..d98ac5f 100644
--- a/apps/async-worker/async_worker/config.py
+++ b/apps/async-worker/async_worker/config.py
@@ -34,6 +34,9 @@ class AsyncWorkerSettings(BaseSettings):
     # Timeout
     DOWNLOAD_VIDEO_TIMEOUT: int = 600
 
+    # Runtime flag (not loaded from env; set at startup)
+    file_id_consumer_ready: bool = False
+
     @model_validator(mode="after")
     def _resolve_derived(self) -> "AsyncWorkerSettings":
         if not self.MONGODB_URL:
diff --git a/apps/async-worker/async_worker/main.py b/apps/async-worker/async_worker/main.py
index 4f599b4..dcc1cb7 100644
--- a/apps/async-worker/async_worker/main.py
+++ b/apps/async-worker/async_worker/main.py
@@ -64,8 +64,18 @@ async def on_startup(ctx: dict) -> None:
 
             await init_mongodb(settings.MONGODB_URL)
 
+            from async_worker.services import file_id_consumer
+
+            await file_id_consumer.start()
+            settings.file_id_consumer_ready = True
+
     @staticmethod
     async def on_shutdown(ctx: dict) -> None:
+        if settings.file_id_consumer_ready:
+            from async_worker.services import file_id_consumer
+
+            await file_id_consumer.stop()
+
         if settings.DATABASE_ON:
             from fastfetchbot_shared.database.mongodb import close_mongodb
diff --git a/apps/async-worker/async_worker/services/file_id_consumer.py b/apps/async-worker/async_worker/services/file_id_consumer.py
new file mode 100644
index 0000000..60db77e
--- /dev/null
+++ b/apps/async-worker/async_worker/services/file_id_consumer.py
@@ -0,0 +1,141 @@
+import asyncio
+import json
+
+import redis.asyncio as aioredis
+
+from async_worker.config import settings
+from fastfetchbot_shared.utils.logger import logger
+
+FILEID_QUEUE_KEY = "fileid:updates"
+FILEID_DLQ_KEY = "fileid:updates:dlq"
+_MAX_RETRIES = 3
+
+_redis: aioredis.Redis | None = None
+_consumer_task: asyncio.Task | None = None
+
+
+async def _get_redis() -> aioredis.Redis:
+    global _redis
+    if _redis is None:
+        _redis = aioredis.from_url(settings.OUTBOX_REDIS_URL, decode_responses=True)
+    return _redis
+
+
+async def _consume_loop() -> None:
+    """Background loop: BRPOP from the file_id updates queue and persist to MongoDB."""
+    r = await _get_redis()
+    logger.info(f"file_id consumer started, listening on '{FILEID_QUEUE_KEY}'")
+
+    while True:
+        raw_payload = None
+        try:
+            result = await r.brpop(FILEID_QUEUE_KEY, timeout=0)
+            if result is None:
+                continue
+
+            _, raw_payload = result
+            payload = json.loads(raw_payload)
+            await _process_file_id_update(payload)
+
+        except asyncio.CancelledError:
+            # Shutdown requested — requeue unprocessed payload so it isn't lost.
+            if raw_payload is not None:
+                try:
+                    await r.lpush(FILEID_QUEUE_KEY, raw_payload)
+                    logger.info("Requeued in-flight payload before shutdown")
+                except Exception:
+                    logger.warning(f"Failed to requeue payload on shutdown: {raw_payload}")
+            logger.info("file_id consumer cancelled, shutting down")
+            break
+        except json.JSONDecodeError as e:
+            # Permanent failure — payload is malformed, send to dead-letter queue.
+            logger.error(f"Malformed JSON in file_id queue, moving to DLQ: {e}")
+            try:
+                await r.lpush(FILEID_DLQ_KEY, raw_payload)
+            except Exception:
+                logger.warning(f"Failed to push to DLQ: {raw_payload}")
+        except Exception as e:
+            # Transient failure (DB unavailable, network blip, etc.) — requeue
+            # the payload so it can be retried on the next loop iteration.
+            logger.error(f"file_id consumer error, requeuing payload: {e}")
+            if raw_payload is not None:
+                retry_count = 0
+                try:
+                    payload = json.loads(raw_payload)
+                    retry_count = payload.get("_retry_count", 0)
+                except (json.JSONDecodeError, TypeError):
+                    pass
+
+                if retry_count < _MAX_RETRIES:
+                    try:
+                        # Stamp retry count so we can detect repeated failures.
+                        payload["_retry_count"] = retry_count + 1
+                        await r.lpush(FILEID_QUEUE_KEY, json.dumps(payload, ensure_ascii=False))
+                    except Exception:
+                        logger.warning(f"Failed to requeue payload: {raw_payload}")
+                else:
+                    logger.error(f"Max retries ({_MAX_RETRIES}) exceeded, moving to DLQ: {raw_payload}")
+                    try:
+                        await r.lpush(FILEID_DLQ_KEY, raw_payload)
+                    except Exception:
+                        logger.warning(f"Failed to push to DLQ: {raw_payload}")
+            await asyncio.sleep(1)
+
+
+async def _process_file_id_update(payload: dict) -> None:
+    """Update the latest Metadata document with telegram file_ids."""
+    from fastfetchbot_shared.database.mongodb.models.metadata import Metadata
+
+    metadata_url = payload.get("metadata_url", "")
+    updates = payload.get("file_id_updates", [])
+
+    if not metadata_url or not updates:
+        logger.warning(f"Invalid file_id update payload: {payload}")
+        return
+
+    doc = await Metadata.find(
+        Metadata.url == metadata_url
+    ).sort("-version").limit(1).first_or_none()
+
+    if doc is None or not doc.media_files:
+        logger.warning(f"No metadata found for file_id update: {metadata_url}")
+        return
+
+    matched = 0
+    for update in updates:
+        for mf in doc.media_files:
+            if mf.url == update["url"] and mf.telegram_file_id is None:
+                mf.telegram_file_id = update["telegram_file_id"]
+                matched += 1
+
+    if matched > 0:
+        await doc.save()
+        logger.info(f"Updated {matched}/{len(updates)} file_ids for {metadata_url}")
+
+
+async def start() -> None:
+    """Start the file_id consumer as a background asyncio task."""
+    global _consumer_task
+    if _consumer_task is not None:
+        logger.warning("file_id consumer already running")
+        return
+    _consumer_task = asyncio.create_task(_consume_loop())
+    logger.info("file_id consumer task created")
+
+
+async def stop() -> None:
+    """Stop the file_id consumer and close the Redis connection."""
+    global _consumer_task, _redis
+
+    if _consumer_task is not None:
+        _consumer_task.cancel()
+        try:
+            await _consumer_task
+        except asyncio.CancelledError:
+            pass
+        _consumer_task = None
+        logger.info("file_id consumer task stopped")
+
+    if _redis is not None:
+        await _redis.aclose()
+        _redis = None
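To exercise the consumer end to end without the bot, a payload of the expected shape can be pushed by hand — a sketch only, assuming a local Redis reachable at the worker's `OUTBOX_REDIS_URL` and illustrative URL/file_id values:

```python
import asyncio
import json

import redis.asyncio as aioredis

async def push_sample() -> None:
    r = aioredis.from_url("redis://localhost:6379", decode_responses=True)
    payload = {
        "metadata_url": "https://example.com/post/1",
        "file_id_updates": [
            {"url": "https://img.com/1.jpg", "media_type": "image",
             "telegram_file_id": "AgACAgI123"},
        ],
    }
    await r.lpush("fileid:updates", json.dumps(payload, ensure_ascii=False))
    await r.aclose()

asyncio.run(push_sample())
```

Transient processing failures are requeued up to `_MAX_RETRIES` times before the payload is parked in `fileid:updates:dlq`.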
diff --git a/apps/telegram-bot/core/services/file_id_capture.py b/apps/telegram-bot/core/services/file_id_capture.py
new file mode 100644
index 0000000..0f04779
--- /dev/null
+++ b/apps/telegram-bot/core/services/file_id_capture.py
@@ -0,0 +1,78 @@
+import json
+from typing import Optional
+
+import redis.asyncio as aioredis
+from telegram import Message
+
+from core.config import settings
+from fastfetchbot_shared.utils.logger import logger
+
+FILEID_QUEUE_KEY = "fileid:updates"
+
+_redis: aioredis.Redis | None = None
+
+
+async def _get_redis() -> aioredis.Redis:
+    global _redis
+    if _redis is None:
+        _redis = aioredis.from_url(settings.OUTBOX_REDIS_URL, decode_responses=True)
+    return _redis
+
+
+def extract_file_id(message: Message, media_type: str) -> Optional[str]:
+    """Extract the telegram file_id from a sent Message based on media type."""
+    if media_type == "image" and message.photo:
+        return message.photo[-1].file_id
+    elif media_type == "video" and message.video:
+        return message.video.file_id
+    elif media_type == "gif" and message.animation:
+        return message.animation.file_id
+    elif media_type == "audio" and message.audio:
+        return message.audio.file_id
+    elif media_type == "document" and message.document:
+        return message.document.file_id
+    return None
+
+
+async def capture_and_push_file_ids(
+    uncached_info: list[dict],
+    sent_messages: tuple[Message, ...],
+    metadata_url: str,
+) -> None:
+    """Extract file_ids from sent messages and push updates to Redis for the async worker.
+
+    Args:
+        uncached_info: list of {"url": str, "media_type": str} for items that were
+            downloaded (not served from cached file_id). Parallel to sent_messages —
+            each entry corresponds to a message at the same position.
+        sent_messages: tuple of Message objects returned by send_media_group.
+        metadata_url: the original scraped page URL, used as the key for MongoDB lookup.
+    """
+    file_id_updates = []
+
+    for i, info in enumerate(uncached_info):
+        if info is None:
+            continue
+        if i >= len(sent_messages):
+            break
+        file_id = extract_file_id(sent_messages[i], info["media_type"])
+        if file_id:
+            file_id_updates.append({
+                "url": info["url"],
+                "media_type": info["media_type"],
+                "telegram_file_id": file_id,
+            })
+
+    if not file_id_updates:
+        return
+
+    try:
+        r = await _get_redis()
+        payload = json.dumps({
+            "metadata_url": metadata_url,
+            "file_id_updates": file_id_updates,
+        }, ensure_ascii=False)
+        await r.lpush(FILEID_QUEUE_KEY, payload)
+        logger.info(f"Pushed {len(file_id_updates)} file_id updates for {metadata_url}")
+    except Exception as e:
+        logger.warning(f"Failed to push file_id updates for {metadata_url}: {e}")
diff --git a/apps/telegram-bot/core/services/message_sender.py b/apps/telegram-bot/core/services/message_sender.py
index 1e6480d..9f3bf4a 100644
--- a/apps/telegram-bot/core/services/message_sender.py
+++ b/apps/telegram-bot/core/services/message_sender.py
@@ -44,6 +44,12 @@ def _get_application():
     return application
 
 
+def _log_file_id_task_exception(task: asyncio.Task):
+    """Callback for fire-and-forget file_id capture tasks."""
+    if not task.cancelled() and task.exception():
+        logger.opt(exception=task.exception()).error("Error in file_id capture task")
+
+
 async def send_item_message(
     data: dict, chat_id: Union[int, str] = None, message: Message = None, message_id: int = None,
@@ -74,12 +80,13 @@ async def send_item_message(
     if len(data["media_files"]) > 0:
         # if the message type is short and there are some media files, send media group
         reply_to_message_id = None
-        media_message_group, file_message_group = await media_files_packaging(
+        media_message_group, file_message_group, uncached_media_info = await media_files_packaging(
             media_files=data["media_files"], data=data
         )
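+        # uncached_media_info runs parallel to the flattened media items:
+        # None marks an item served from a cached telegram_file_id, a dict
+        # marks a fresh download whose file_id is captured after sending.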
         if (
             len(media_message_group) > 0
         ):  # if there are some media groups to send, send it
+            all_sent_messages = []
             for i, media_group in enumerate(media_message_group):
                 caption_text = (
                     caption_text
@@ -98,11 +105,24 @@ async def send_item_message(
                     write_timeout=settings.TELEBOT_WRITE_TIMEOUT,
                     reply_to_message_id=_reply_to,
                 )
+                all_sent_messages.extend(sent_media_files_message)
                 if sent_media_files_message is tuple:
                     reply_to_message_id = sent_media_files_message[0].message_id
                 elif sent_media_files_message is Message:
                     reply_to_message_id = sent_media_files_message.message_id
                 logger.debug(f"sent media files message: {sent_media_files_message}")
+            # Background file_id capture: extract file_ids from sent messages
+            has_uncached = any(info is not None for info in uncached_media_info)
+            if settings.SCRAPE_MODE == "queue" and has_uncached and all_sent_messages:
+                from core.services.file_id_capture import capture_and_push_file_ids
+                task = asyncio.create_task(
+                    capture_and_push_file_ids(
+                        uncached_info=uncached_media_info,
+                        sent_messages=tuple(all_sent_messages),
+                        metadata_url=data.get("url", ""),
+                    )
+                )
+                task.add_done_callback(_log_file_id_task_exception)
         else:
             sent_message = await application.bot.send_message(
                 chat_id=chat_id,
@@ -200,17 +220,16 @@ async def media_files_packaging(media_files: list, data: dict) -> tuple:
     sending them by send_media_group method or send_document method.
     :param data: (dict) metadata of the item
     :param media_files: (list) a list of media files,
-    :return: (tuple) a tuple of media group and file group
-        media_message_group: (list) a list of media items, the type of each item is InputMediaPhoto or InputMediaVideo
-        file_group: (list) a list of file items, the type of each item is InputFile
-    TODO: It's not a good practice for this function. This method will still download all the media files even when
-        media files are too large and it can be memory consuming even if we use a database to store the media files.
-        The function should be optimized to resolve the media files one group by one group and send each group
-        immediately after it is resolved.
-        This processing method should be optimized in the future.
+    :return: (tuple) a tuple of (media_message_group, file_message_group, uncached_media_info)
+        media_message_group: (list) a list of media groups, each is a list of InputMedia* items
+        file_message_group: (list) a list of file groups, each is a list of InputMediaDocument items
+        uncached_media_info: (list) parallel to the flattened media items across groups — each entry is
+            {"url": str, "media_type": str} for items that were downloaded (need file_id capture),
+            or None for items served from a cached file_id
     """
     media_counter, file_counter = 0, 0
     media_message_group, media_group, file_message_group, file_group = [], [], [], []
+    uncached_media_info = []
     for (
         media_item
     ) in media_files:  # To traverse all media items in the media files list
@@ -225,6 +244,27 @@ async def media_files_packaging(media_files: list, data: dict) -> tuple:
             file_message_group.append(file_group)
             file_group = []
             file_counter = 0
+        # Check for cached telegram file_id — skip download entirely if available
+        file_id = media_item.get("telegram_file_id")
+        if file_id:
+            media_type = media_item["media_type"]
+            if media_type == "image":
+                media_group.append(InputMediaPhoto(file_id))
+            elif media_type == "gif":
+                media_group.append(InputMediaAnimation(file_id))
+            elif media_type == "video":
+                media_group.append(InputMediaVideo(file_id, supports_streaming=True))
+            elif media_type == "audio":
+                media_group.append(InputMediaAudio(file_id))
+            elif media_type == "document":
+                file_group.append(InputMediaDocument(file_id, parse_mode=ParseMode.HTML))
+                file_counter += 1
+            uncached_media_info.append(None)
+            media_counter += 1
+            logger.info(
+                f"get the {media_counter}th media item (cached file_id), type: {media_type}, url: {media_item['url']}"
+            )
+            continue
         if not (
             media_item["media_type"] in ["image", "gif", "video"]
             and data["message_type"] == "long"
@@ -343,6 +383,10 @@ async def media_files_packaging(media_files: list, data: dict) -> tuple:
                     InputMediaDocument(io_object, parse_mode=ParseMode.HTML)
                 )
                 file_counter += 1
+        uncached_media_info.append({
+            "url": media_item["url"],
+            "media_type": media_item["media_type"],
+        })
         media_counter += 1
         logger.info(
             f"get the {media_counter}th media item,type: {media_item['media_type']}, url: {media_item['url']}"
         )
@@ -352,4 +396,4 @@ async def media_files_packaging(media_files: list, data: dict) -> tuple:
         media_message_group.append(media_group)
     if len(file_group) > 0:
         file_message_group.append(file_group)
-    return media_message_group, file_message_group
+    return media_message_group, file_message_group, uncached_media_info
diff --git a/packages/shared/fastfetchbot_shared/database/mongodb/models/metadata.py b/packages/shared/fastfetchbot_shared/database/mongodb/models/metadata.py
index cd893a8..358f87f 100644
--- a/packages/shared/fastfetchbot_shared/database/mongodb/models/metadata.py
+++ b/packages/shared/fastfetchbot_shared/database/mongodb/models/metadata.py
@@ -1,4 +1,5 @@
 from typing import Optional, Any
+from dataclasses import asdict
 from datetime import datetime
 
 from pydantic import Field
@@ -41,6 +42,9 @@ class Settings:
         indexes = [
             [("url", DESCENDING), ("version", DESCENDING)],
         ]
+        bson_encoders = {
+            DatabaseMediaFile: asdict,
+        }
 
     @before_event(Insert)
     def prepare_for_insert(self):
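The `bson_encoders` entry is the substantive fix here: mapping `DatabaseMediaFile` to `dataclasses.asdict` flattens each media file to a plain dict when Beanie encodes the document for MongoDB — exactly what the `TestBsonEncoders` cases below verify.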
diff --git a/packages/shared/fastfetchbot_shared/models/metadata_item.py b/packages/shared/fastfetchbot_shared/models/metadata_item.py
index 6b5820d..4c37c1c 100644
--- a/packages/shared/fastfetchbot_shared/models/metadata_item.py
+++ b/packages/shared/fastfetchbot_shared/models/metadata_item.py
@@ -42,6 +42,7 @@ class MediaFile:
     url: str
     original_url: Optional[str] = None
     caption: Optional[str] = None
+    telegram_file_id: Optional[str] = None
 
     @staticmethod
     def from_dict(obj: Any) -> "MediaFile":
@@ -49,13 +50,16 @@ def from_dict(obj: Any) -> "MediaFile":
         media_type = from_str(obj.get("media_type"))
         url = from_str(obj.get("url"))
         caption = from_str(obj.get("caption"))
-        return MediaFile(media_type, url, caption)
+        telegram_file_id = obj.get("telegram_file_id")
+        return MediaFile(media_type, url, caption=caption, telegram_file_id=telegram_file_id)
 
     def to_dict(self) -> dict:
         result: dict = {}
         result["media_type"] = from_str(self.media_type)
         result["url"] = from_str(self.url)
         result["caption"] = self.caption
+        if self.telegram_file_id is not None:
+            result["telegram_file_id"] = self.telegram_file_id
         return result
diff --git a/tests/unit/async_worker/test_file_id_consumer.py b/tests/unit/async_worker/test_file_id_consumer.py
new file mode 100644
index 0000000..37cf073
--- /dev/null
+++ b/tests/unit/async_worker/test_file_id_consumer.py
@@ -0,0 +1,529 @@
+"""Tests for apps/async-worker/async_worker/services/file_id_consumer.py"""
+
+import asyncio
+import json
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+def reset_module_state():
+    """Reset module-level global state before each test."""
+    from async_worker.services import file_id_consumer as fic
+
+    fic._redis = None
+    fic._consumer_task = None
+    yield
+    fic._redis = None
+    fic._consumer_task = None
+
+
+@pytest.fixture
+def mock_redis():
+    r = AsyncMock()
+    r.brpop = AsyncMock()
+    r.lpush = AsyncMock()
+    r.aclose = AsyncMock()
+    return r
+
+
+def _make_payload(metadata_url="https://example.com/post/1", file_id_updates=None):
+    """Create a JSON-encoded file_id update payload."""
+    if file_id_updates is None:
+        file_id_updates = [
+            {
+                "url": "https://img.com/1.jpg",
+                "media_type": "image",
+                "telegram_file_id": "AgACAgI123",
+            },
+        ]
+    return json.dumps(
+        {
+            "metadata_url": metadata_url,
+            "file_id_updates": file_id_updates,
+        },
+        ensure_ascii=False,
+    )
+
+
+# ---------------------------------------------------------------------------
+# _get_redis
+# ---------------------------------------------------------------------------
+
+
+class TestGetRedis:
+    @pytest.mark.asyncio
+    async def test_creates_connection(self, mock_redis):
+        with patch(
+            "async_worker.services.file_id_consumer.aioredis.from_url",
+            return_value=mock_redis,
+        ):
+            from async_worker.services.file_id_consumer import _get_redis
+
+            r = await _get_redis()
+            assert r is mock_redis
+
+    @pytest.mark.asyncio
+    async def test_reuses_connection(self, mock_redis):
+        with patch(
+            "async_worker.services.file_id_consumer.aioredis.from_url",
+            return_value=mock_redis,
+        ) as mock_from_url:
+            from async_worker.services.file_id_consumer import _get_redis
+
+            r1 = await _get_redis()
+            r2 = await _get_redis()
+            assert r1 is r2
+            mock_from_url.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# _process_file_id_update
+# ---------------------------------------------------------------------------
+
+
+class TestProcessFileIdUpdate:
+    @pytest.mark.asyncio
+    async def test_updates_matching_media_file(self):
+        from async_worker.services.file_id_consumer import _process_file_id_update
+
+        # Create a mock Metadata document with media_files
+        mock_mf = MagicMock()
+        mock_mf.url = "https://img.com/1.jpg"
+        mock_mf.telegram_file_id = None
+
+        mock_doc = MagicMock()
+        mock_doc.media_files = [mock_mf]
+        mock_doc.save = AsyncMock()
+
+        mock_query = MagicMock()
+        mock_query.sort = MagicMock(return_value=mock_query)
+        mock_query.limit = MagicMock(return_value=mock_query)
+        mock_query.first_or_none = AsyncMock(return_value=mock_doc)
+
+        with patch(
+            "fastfetchbot_shared.database.mongodb.models.metadata.Metadata"
+        ) as MockMetadata:
+            MockMetadata.find = MagicMock(return_value=mock_query)
+            MockMetadata.url = "url"
+
+            await _process_file_id_update({
+                "metadata_url": "https://example.com/post/1",
+                "file_id_updates": [
+                    {
+                        "url": "https://img.com/1.jpg",
+                        "media_type": "image",
+                        "telegram_file_id": "AgACAgI123",
+                    },
+                ],
+            })
+
+        assert mock_mf.telegram_file_id == "AgACAgI123"
+        mock_doc.save.assert_awaited_once()
+
+    @pytest.mark.asyncio
+    async def test_skips_already_set_file_id(self):
+        from async_worker.services.file_id_consumer import _process_file_id_update
+
+        mock_mf = MagicMock()
+        mock_mf.url = "https://img.com/1.jpg"
+        mock_mf.telegram_file_id = "existing_id"  # already set
+
+        mock_doc = MagicMock()
+        mock_doc.media_files = [mock_mf]
+        mock_doc.save = AsyncMock()
+
+        mock_query = MagicMock()
+        mock_query.sort = MagicMock(return_value=mock_query)
+        mock_query.limit = MagicMock(return_value=mock_query)
+        mock_query.first_or_none = AsyncMock(return_value=mock_doc)
+
+        with patch(
+            "fastfetchbot_shared.database.mongodb.models.metadata.Metadata"
+        ) as MockMetadata:
+            MockMetadata.find = MagicMock(return_value=mock_query)
+            MockMetadata.url = "url"
+
+            await _process_file_id_update({
+                "metadata_url": "https://example.com/post/1",
+                "file_id_updates": [
+                    {
+                        "url": "https://img.com/1.jpg",
+                        "media_type": "image",
+                        "telegram_file_id": "new_id",
+                    },
+                ],
+            })
+
+        # Should not overwrite existing file_id
+        assert mock_mf.telegram_file_id == "existing_id"
+        # No changes → should not save
+        mock_doc.save.assert_not_awaited()
+
+    @pytest.mark.asyncio
+    async def test_no_save_when_no_match(self):
+        from async_worker.services.file_id_consumer import _process_file_id_update
+
+        mock_mf = MagicMock()
+        mock_mf.url = "https://other.com/2.jpg"  # different URL
+        mock_mf.telegram_file_id = None
+
+        mock_doc = MagicMock()
+        mock_doc.media_files = [mock_mf]
+        mock_doc.save = AsyncMock()
+
+        mock_query = MagicMock()
+        mock_query.sort = MagicMock(return_value=mock_query)
+        mock_query.limit = MagicMock(return_value=mock_query)
+        mock_query.first_or_none = AsyncMock(return_value=mock_doc)
+
+        with patch(
+            "fastfetchbot_shared.database.mongodb.models.metadata.Metadata"
+        ) as MockMetadata:
+            MockMetadata.find = MagicMock(return_value=mock_query)
+            MockMetadata.url = "url"
+
+            await _process_file_id_update({
+                "metadata_url": "https://example.com/post/1",
+                "file_id_updates": [
+                    {
+                        "url": "https://img.com/1.jpg",
+                        "media_type": "image",
+                        "telegram_file_id": "AgACAgI123",
+                    },
+                ],
+            })
+
+        mock_doc.save.assert_not_awaited()
+
+    @pytest.mark.asyncio
+    async def test_handles_no_document_found(self):
+        from async_worker.services.file_id_consumer import _process_file_id_update
+
+        mock_query = MagicMock()
+        mock_query.sort = MagicMock(return_value=mock_query)
+        mock_query.limit = MagicMock(return_value=mock_query)
+        mock_query.first_or_none = AsyncMock(return_value=None)
+
+        with patch(
+            "fastfetchbot_shared.database.mongodb.models.metadata.Metadata"
+        ) as MockMetadata:
+            MockMetadata.find = MagicMock(return_value=mock_query)
+            MockMetadata.url = "url"
+
+            # Should not raise
+            await _process_file_id_update({
+                "metadata_url": "https://example.com/missing",
+                "file_id_updates": [
+                    {
+                        "url": "https://img.com/1.jpg",
+                        "media_type": "image",
+                        "telegram_file_id": "AgACAgI123",
+                    },
+                ],
+            })
+
+    @pytest.mark.asyncio
+    async def test_handles_empty_payload(self):
+        from async_worker.services.file_id_consumer import _process_file_id_update
+
+        # Should not raise with empty/missing fields
+        await _process_file_id_update({"metadata_url": "", "file_id_updates": []})
+        await _process_file_id_update({})
+
+    @pytest.mark.asyncio
+    async def test_updates_multiple_media_files(self):
+        from async_worker.services.file_id_consumer import _process_file_id_update
+
+        mock_mf1 = MagicMock()
+        mock_mf1.url = "https://img.com/1.jpg"
+        mock_mf1.telegram_file_id = None
+
+        mock_mf2 = MagicMock()
+        mock_mf2.url = "https://vid.com/v.mp4"
+        mock_mf2.telegram_file_id = None
+
+        mock_doc = MagicMock()
+        mock_doc.media_files = [mock_mf1, mock_mf2]
+        mock_doc.save = AsyncMock()
+
+        mock_query = MagicMock()
+        mock_query.sort = MagicMock(return_value=mock_query)
+        mock_query.limit = MagicMock(return_value=mock_query)
+        mock_query.first_or_none = AsyncMock(return_value=mock_doc)
+
+        with patch(
+            "fastfetchbot_shared.database.mongodb.models.metadata.Metadata"
+        ) as MockMetadata:
+            MockMetadata.find = MagicMock(return_value=mock_query)
+            MockMetadata.url = "url"
+
+            await _process_file_id_update({
+                "metadata_url": "https://example.com/post/1",
+                "file_id_updates": [
+                    {
+                        "url": "https://img.com/1.jpg",
+                        "media_type": "image",
+                        "telegram_file_id": "photo_id",
+                    },
+                    {
+                        "url": "https://vid.com/v.mp4",
+                        "media_type": "video",
+                        "telegram_file_id": "video_id",
+                    },
+                ],
+            })
+
+        assert mock_mf1.telegram_file_id == "photo_id"
+        assert mock_mf2.telegram_file_id == "video_id"
+        mock_doc.save.assert_awaited_once()
+
+
+# ---------------------------------------------------------------------------
+# _consume_loop
+# ---------------------------------------------------------------------------
+
+
+class TestConsumeLoop:
+    @pytest.mark.asyncio
+    async def test_processes_valid_payload(self, mock_redis):
+        payload = _make_payload()
+        call_count = 0
+
+        async def brpop_side_effect(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                return ("fileid:updates", payload)
+            raise asyncio.CancelledError()
+
+        mock_redis.brpop = AsyncMock(side_effect=brpop_side_effect)
+
+        with patch(
+            "async_worker.services.file_id_consumer.aioredis.from_url",
+            return_value=mock_redis,
+        ), patch(
+            "async_worker.services.file_id_consumer._process_file_id_update",
+            new_callable=AsyncMock,
+        ) as mock_process:
+            from async_worker.services.file_id_consumer import _consume_loop
+
+            await _consume_loop()
+
+        mock_process.assert_awaited_once()
+        call_args = mock_process.call_args[0][0]
+        assert call_args["metadata_url"] == "https://example.com/post/1"
+
+    @pytest.mark.asyncio
+    async def test_malformed_json_goes_to_dlq(self, mock_redis):
+        """Malformed JSON should be moved to the dead-letter queue, not requeued."""
+        call_count = 0
+
+        async def brpop_side_effect(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                return ("fileid:updates", "NOT-VALID-JSON{{{")
+            raise asyncio.CancelledError()
+
+        mock_redis.brpop = AsyncMock(side_effect=brpop_side_effect)
+
+        with patch(
+            "async_worker.services.file_id_consumer.aioredis.from_url",
+            return_value=mock_redis,
+        ):
+            from async_worker.services.file_id_consumer import _consume_loop
+
+            await _consume_loop()
+
+        # Should have pushed to DLQ, not the main queue
+        mock_redis.lpush.assert_awaited_once()
+        dlq_key = mock_redis.lpush.call_args[0][0]
+        assert dlq_key == "fileid:updates:dlq"
+
+    @pytest.mark.asyncio
+    async def test_transient_error_requeues_payload(self, mock_redis):
+        """A DB/IO error during processing should requeue the payload for retry."""
+        payload = _make_payload()
+        call_count = 0
+
+        async def brpop_side_effect(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                return ("fileid:updates", payload)
+            raise asyncio.CancelledError()
+
+        mock_redis.brpop = AsyncMock(side_effect=brpop_side_effect)
+
+        with patch(
+            "async_worker.services.file_id_consumer.aioredis.from_url",
+            return_value=mock_redis,
+        ), patch(
+            "async_worker.services.file_id_consumer._process_file_id_update",
+            new_callable=AsyncMock,
+            side_effect=RuntimeError("MongoDB down"),
+        ):
+            from async_worker.services.file_id_consumer import _consume_loop
+
+            await _consume_loop()
+
+        # Should have requeued to the main queue with _retry_count stamped
+        mock_redis.lpush.assert_awaited_once()
+        requeue_key = mock_redis.lpush.call_args[0][0]
+        assert requeue_key == "fileid:updates"
+        requeued = json.loads(mock_redis.lpush.call_args[0][1])
+        assert requeued["_retry_count"] == 1
+        assert requeued["metadata_url"] == "https://example.com/post/1"
+
+    @pytest.mark.asyncio
+    async def test_max_retries_exceeded_goes_to_dlq(self, mock_redis):
+        """After _MAX_RETRIES failures, the payload should go to DLQ."""
+        payload_dict = json.loads(_make_payload())
+        payload_dict["_retry_count"] = 3  # already at max
+        payload = json.dumps(payload_dict, ensure_ascii=False)
+        call_count = 0
+
+        async def brpop_side_effect(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                return ("fileid:updates", payload)
+            raise asyncio.CancelledError()
+
+        mock_redis.brpop = AsyncMock(side_effect=brpop_side_effect)
+
+        with patch(
+            "async_worker.services.file_id_consumer.aioredis.from_url",
+            return_value=mock_redis,
+        ), patch(
+            "async_worker.services.file_id_consumer._process_file_id_update",
+            new_callable=AsyncMock,
+            side_effect=RuntimeError("still failing"),
+        ):
+            from async_worker.services.file_id_consumer import _consume_loop
+
+            await _consume_loop()
+
+        # Should go to DLQ, not requeued
+        mock_redis.lpush.assert_awaited_once()
+        dlq_key = mock_redis.lpush.call_args[0][0]
+        assert dlq_key == "fileid:updates:dlq"
+
+    @pytest.mark.asyncio
+    async def test_shutdown_requeues_in_flight_payload(self, mock_redis):
+        """If CancelledError hits after BRPOP but before processing completes,
+        the payload should be requeued."""
+        payload = _make_payload()
+
+        async def brpop_side_effect(*args, **kwargs):
+            return ("fileid:updates", payload)
+
+        mock_redis.brpop = AsyncMock(side_effect=brpop_side_effect)
+
+        async def process_side_effect(p):
+            raise asyncio.CancelledError()
+
+        with patch(
+            "async_worker.services.file_id_consumer.aioredis.from_url",
+            return_value=mock_redis,
+        ), patch(
+            "async_worker.services.file_id_consumer._process_file_id_update",
+            new_callable=AsyncMock,
+            side_effect=process_side_effect,
+        ):
+            from async_worker.services.file_id_consumer import _consume_loop
+
+            await _consume_loop()
+
+        # Payload should have been requeued before shutdown
+        mock_redis.lpush.assert_awaited_once()
+        requeue_key = mock_redis.lpush.call_args[0][0]
+        assert requeue_key == "fileid:updates"
+
+    @pytest.mark.asyncio
+    async def test_brpop_none_continues(self, mock_redis):
+        call_count = 0
+
+        async def brpop_side_effect(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count <= 2:
+                return None
+            raise asyncio.CancelledError()
+
+        mock_redis.brpop = AsyncMock(side_effect=brpop_side_effect)
+
+        with patch(
+            "async_worker.services.file_id_consumer.aioredis.from_url",
+            return_value=mock_redis,
+        ):
+            from async_worker.services.file_id_consumer import _consume_loop
+
+            await _consume_loop()
+
+        assert call_count == 3
+
+
+# ---------------------------------------------------------------------------
+# start / stop
+# ---------------------------------------------------------------------------
+
+
+class TestStartStop:
+    @pytest.mark.asyncio
+    async def test_start_creates_task(self, mock_redis):
+        from async_worker.services import file_id_consumer as fic
+
+        with patch.object(
+            fic, "_consume_loop", new_callable=AsyncMock,
+        ):
+            await fic.start()
+            assert fic._consumer_task is not None
+            await fic.stop()
+
+    @pytest.mark.asyncio
+    async def test_start_idempotent(self):
+        from async_worker.services import file_id_consumer as fic
+
+        fake_task = MagicMock()
+        fic._consumer_task = fake_task
+
+        await fic.start()
+        assert fic._consumer_task is fake_task
+        fic._consumer_task = None
+
+    @pytest.mark.asyncio
+    async def test_stop_cancels_task(self, mock_redis):
+        from async_worker.services import file_id_consumer as fic
+
+        async def _noop():
+            try:
+                await asyncio.sleep(999)
+            except asyncio.CancelledError:
+                pass
+
+        task = asyncio.create_task(_noop())
+        fic._consumer_task = task
+        fic._redis = mock_redis
+
+        await fic.stop()
+
+        assert task.cancelled() or task.done()
+        mock_redis.aclose.assert_awaited_once()
+        assert fic._consumer_task is None
+        assert fic._redis is None
+
+    @pytest.mark.asyncio
+    async def test_stop_when_not_running(self):
+        from async_worker.services import file_id_consumer as fic
+
+        fic._consumer_task = None
+        fic._redis = None
+        await fic.stop()  # should not raise
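These consumer tests stub both Redis and MongoDB, so the suite runs without live services — e.g. `uv run pytest tests/unit/async_worker/test_file_id_consumer.py` (command assumes the workspace's usual pytest setup).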
mock_settings.DATABASE_ON = False + mock_settings.file_id_consumer_ready = False await WorkerSettings.on_shutdown({}) diff --git a/tests/unit/database/mongodb/test_file_id_fields.py b/tests/unit/database/mongodb/test_file_id_fields.py new file mode 100644 index 0000000..41d015e --- /dev/null +++ b/tests/unit/database/mongodb/test_file_id_fields.py @@ -0,0 +1,244 @@ +"""Tests for telegram_file_id field on MediaFile and DatabaseMediaFile, +and the bson_encoders fix for Beanie serialization. + +Does NOT require a running MongoDB instance. +""" + +from dataclasses import asdict +from unittest.mock import patch + +import pytest + +from fastfetchbot_shared.models.metadata_item import MediaFile, MessageType +from fastfetchbot_shared.database.mongodb.models.metadata import ( + DatabaseMediaFile, + Metadata, +) + + +# --------------------------------------------------------------------------- +# MediaFile.telegram_file_id +# --------------------------------------------------------------------------- + + +class TestMediaFileTelegramFileId: + def test_default_is_none(self): + mf = MediaFile(media_type="image", url="https://img.com/1.jpg") + assert mf.telegram_file_id is None + + def test_can_set_value(self): + mf = MediaFile( + media_type="image", + url="https://img.com/1.jpg", + telegram_file_id="AgACAgI123", + ) + assert mf.telegram_file_id == "AgACAgI123" + + def test_to_dict_includes_field_when_set(self): + mf = MediaFile( + media_type="image", + url="https://img.com/1.jpg", + telegram_file_id="AgACAgI123", + ) + d = mf.to_dict() + assert d["telegram_file_id"] == "AgACAgI123" + + def test_to_dict_excludes_field_when_none(self): + mf = MediaFile(media_type="image", url="https://img.com/1.jpg") + d = mf.to_dict() + assert "telegram_file_id" not in d + + def test_from_dict_reads_field(self): + d = { + "media_type": "image", + "url": "https://img.com/1.jpg", + "telegram_file_id": "AgACAgI123", + } + mf = MediaFile.from_dict(d) + assert mf.telegram_file_id == "AgACAgI123" + + def test_from_dict_without_field_defaults_none(self): + d = {"media_type": "image", "url": "https://img.com/1.jpg"} + mf = MediaFile.from_dict(d) + assert mf.telegram_file_id is None + + def test_round_trip_with_file_id(self): + original = MediaFile( + media_type="video", + url="https://vid.com/v.mp4", + telegram_file_id="BAACAgI456", + ) + restored = MediaFile.from_dict(original.to_dict()) + assert restored.telegram_file_id == "BAACAgI456" + assert restored.media_type == "video" + assert restored.url == "https://vid.com/v.mp4" + + def test_round_trip_without_file_id(self): + original = MediaFile(media_type="image", url="https://img.com/1.jpg") + restored = MediaFile.from_dict(original.to_dict()) + assert restored.telegram_file_id is None + + +# --------------------------------------------------------------------------- +# DatabaseMediaFile inherits telegram_file_id +# --------------------------------------------------------------------------- + + +class TestDatabaseMediaFileTelegramFileId: + def test_inherits_field(self): + dmf = DatabaseMediaFile( + media_type="image", + url="https://img.com/1.jpg", + telegram_file_id="AgACAgI123", + ) + assert dmf.telegram_file_id == "AgACAgI123" + + def test_default_is_none(self): + dmf = DatabaseMediaFile(media_type="image", url="https://img.com/1.jpg") + assert dmf.telegram_file_id is None + + def test_has_both_file_key_and_telegram_file_id(self): + dmf = DatabaseMediaFile( + media_type="image", + url="https://img.com/1.jpg", + file_key="s3://bucket/key.jpg", + 
telegram_file_id="AgACAgI123", + ) + assert dmf.file_key == "s3://bucket/key.jpg" + assert dmf.telegram_file_id == "AgACAgI123" + + +# --------------------------------------------------------------------------- +# bson_encoders — Beanie serialization fix +# --------------------------------------------------------------------------- + + +class TestBsonEncoders: + def test_settings_has_bson_encoders(self): + assert hasattr(Metadata.Settings, "bson_encoders") + assert DatabaseMediaFile in Metadata.Settings.bson_encoders + + def test_encoder_uses_asdict(self): + encoder = Metadata.Settings.bson_encoders[DatabaseMediaFile] + assert encoder is asdict + + def test_encoder_converts_to_dict(self): + dmf = DatabaseMediaFile( + media_type="image", + url="https://img.com/1.jpg", + telegram_file_id="AgACAgI123", + file_key="s3://bucket/key.jpg", + ) + encoder = Metadata.Settings.bson_encoders[DatabaseMediaFile] + result = encoder(dmf) + assert isinstance(result, dict) + assert result["media_type"] == "image" + assert result["url"] == "https://img.com/1.jpg" + assert result["telegram_file_id"] == "AgACAgI123" + assert result["file_key"] == "s3://bucket/key.jpg" + + def test_encoder_handles_none_fields(self): + dmf = DatabaseMediaFile(media_type="image", url="https://img.com/1.jpg") + encoder = Metadata.Settings.bson_encoders[DatabaseMediaFile] + result = encoder(dmf) + assert result["telegram_file_id"] is None + assert result["file_key"] is None + + def test_beanie_encoder_can_encode_database_media_file(self): + """Verify the Beanie Encoder uses our custom encoder.""" + from beanie.odm.utils.encoder import Encoder + + dmf = DatabaseMediaFile( + media_type="image", + url="https://img.com/1.jpg", + telegram_file_id="AgACAgI123", + ) + encoder = Encoder( + custom_encoders=Metadata.Settings.bson_encoders, + to_db=True, + keep_nulls=True, + ) + result = encoder.encode(dmf) + assert isinstance(result, dict) + assert result["media_type"] == "image" + assert result["telegram_file_id"] == "AgACAgI123" + + def test_beanie_encoder_can_encode_list_of_database_media_files(self): + """Verify list of DatabaseMediaFile encodes to list of dicts.""" + from beanie.odm.utils.encoder import Encoder + + files = [ + DatabaseMediaFile(media_type="image", url="https://img.com/1.jpg"), + DatabaseMediaFile( + media_type="video", + url="https://vid.com/v.mp4", + telegram_file_id="BAACAgI456", + ), + ] + encoder = Encoder( + custom_encoders=Metadata.Settings.bson_encoders, + to_db=True, + keep_nulls=True, + ) + result = encoder.encode(files) + assert isinstance(result, list) + assert len(result) == 2 + assert isinstance(result[0], dict) + assert isinstance(result[1], dict) + assert result[1]["telegram_file_id"] == "BAACAgI456" + + +# --------------------------------------------------------------------------- +# prepare_for_insert preserves telegram_file_id +# --------------------------------------------------------------------------- + + +def _make_metadata(**kwargs): + defaults = { + "url": "https://example.com", + "title": "untitled", + "message_type": MessageType.SHORT, + "text_length": 0, + "content_length": 0, + "scrape_status": False, + "version": 1, + } + defaults.update(kwargs) + return Metadata.model_construct(**defaults) + + +class TestPrepareForInsertFileId: + def test_preserves_telegram_file_id_from_dict(self): + m = _make_metadata( + media_files=[ + { + "media_type": "image", + "url": "https://img.com/1.jpg", + "telegram_file_id": "AgACAgI123", + }, + ], + ) + with patch( + 
"fastfetchbot_shared.database.mongodb.models.metadata.get_html_text_length", + return_value=0, + ): + m.prepare_for_insert() + + assert isinstance(m.media_files[0], DatabaseMediaFile) + assert m.media_files[0].telegram_file_id == "AgACAgI123" + + def test_preserves_telegram_file_id_from_media_file(self): + mf = MediaFile( + media_type="image", + url="https://img.com/1.jpg", + telegram_file_id="AgACAgI123", + ) + m = _make_metadata(media_files=[mf]) + with patch( + "fastfetchbot_shared.database.mongodb.models.metadata.get_html_text_length", + return_value=0, + ): + m.prepare_for_insert() + + assert isinstance(m.media_files[0], DatabaseMediaFile) + assert m.media_files[0].telegram_file_id == "AgACAgI123" diff --git a/tests/unit/telegram_bot/test_file_id_capture.py b/tests/unit/telegram_bot/test_file_id_capture.py new file mode 100644 index 0000000..bc6bfdb --- /dev/null +++ b/tests/unit/telegram_bot/test_file_id_capture.py @@ -0,0 +1,325 @@ +"""Tests for apps/telegram-bot/core/services/file_id_capture.py""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def reset_module_state(): + """Reset module-level global state before each test.""" + import core.services.file_id_capture as fic + + fic._redis = None + yield + fic._redis = None + + +@pytest.fixture +def mock_redis(): + r = AsyncMock() + r.lpush = AsyncMock() + r.aclose = AsyncMock() + return r + + +# --------------------------------------------------------------------------- +# extract_file_id +# --------------------------------------------------------------------------- + + +class TestExtractFileId: + def test_extracts_photo_file_id(self): + from core.services.file_id_capture import extract_file_id + + msg = MagicMock() + photo_small = MagicMock(file_id="small_id") + photo_large = MagicMock(file_id="large_id") + msg.photo = [photo_small, photo_large] + + result = extract_file_id(msg, "image") + assert result == "large_id" # last = largest + + def test_extracts_video_file_id(self): + from core.services.file_id_capture import extract_file_id + + msg = MagicMock() + msg.video = MagicMock(file_id="vid_id") + + result = extract_file_id(msg, "video") + assert result == "vid_id" + + def test_extracts_animation_file_id(self): + from core.services.file_id_capture import extract_file_id + + msg = MagicMock() + msg.animation = MagicMock(file_id="gif_id") + + result = extract_file_id(msg, "gif") + assert result == "gif_id" + + def test_extracts_audio_file_id(self): + from core.services.file_id_capture import extract_file_id + + msg = MagicMock() + msg.audio = MagicMock(file_id="audio_id") + + result = extract_file_id(msg, "audio") + assert result == "audio_id" + + def test_extracts_document_file_id(self): + from core.services.file_id_capture import extract_file_id + + msg = MagicMock() + msg.document = MagicMock(file_id="doc_id") + + result = extract_file_id(msg, "document") + assert result == "doc_id" + + def test_returns_none_for_empty_photo(self): + from core.services.file_id_capture import extract_file_id + + msg = MagicMock() + msg.photo = [] + + result = extract_file_id(msg, "image") + assert result is None + + def test_returns_none_for_missing_video(self): + from core.services.file_id_capture import extract_file_id + + msg = MagicMock() + msg.video = None + + result = 
extract_file_id(msg, "video") + assert result is None + + def test_returns_none_for_unknown_type(self): + from core.services.file_id_capture import extract_file_id + + msg = MagicMock() + result = extract_file_id(msg, "unknown_type") + assert result is None + + +# --------------------------------------------------------------------------- +# _get_redis +# --------------------------------------------------------------------------- + + +class TestGetRedis: + @pytest.mark.asyncio + async def test_creates_connection(self, mock_redis): + with patch( + "core.services.file_id_capture.aioredis.from_url", + return_value=mock_redis, + ): + from core.services.file_id_capture import _get_redis + + r = await _get_redis() + assert r is mock_redis + + @pytest.mark.asyncio + async def test_reuses_connection(self, mock_redis): + with patch( + "core.services.file_id_capture.aioredis.from_url", + return_value=mock_redis, + ) as mock_from_url: + from core.services.file_id_capture import _get_redis + + r1 = await _get_redis() + r2 = await _get_redis() + assert r1 is r2 + mock_from_url.assert_called_once() + + +# --------------------------------------------------------------------------- +# capture_and_push_file_ids +# --------------------------------------------------------------------------- + + +class TestCaptureAndPushFileIds: + @pytest.mark.asyncio + async def test_pushes_file_id_updates(self, mock_redis): + from core.services.file_id_capture import capture_and_push_file_ids + + # Create mock sent messages + msg1 = MagicMock() + photo = MagicMock(file_id="AgACAgI123") + msg1.photo = [photo] + + msg2 = MagicMock() + msg2.video = MagicMock(file_id="BAACAgI456") + + uncached_info = [ + {"url": "https://img.com/1.jpg", "media_type": "image"}, + {"url": "https://vid.com/v.mp4", "media_type": "video"}, + ] + + with patch( + "core.services.file_id_capture.aioredis.from_url", + return_value=mock_redis, + ): + await capture_and_push_file_ids( + uncached_info=uncached_info, + sent_messages=(msg1, msg2), + metadata_url="https://example.com/post/1", + ) + + mock_redis.lpush.assert_awaited_once() + queue_key = mock_redis.lpush.call_args[0][0] + payload = json.loads(mock_redis.lpush.call_args[0][1]) + + assert queue_key == "fileid:updates" + assert payload["metadata_url"] == "https://example.com/post/1" + assert len(payload["file_id_updates"]) == 2 + assert payload["file_id_updates"][0]["telegram_file_id"] == "AgACAgI123" + assert payload["file_id_updates"][1]["telegram_file_id"] == "BAACAgI456" + + @pytest.mark.asyncio + async def test_skips_none_entries(self, mock_redis): + """None entries in uncached_info are cached items — skip them.""" + from core.services.file_id_capture import capture_and_push_file_ids + + msg1 = MagicMock() + msg1.video = MagicMock(file_id="vid_id") + + # First item was cached (None), second was downloaded + uncached_info = [ + None, + {"url": "https://vid.com/v.mp4", "media_type": "video"}, + ] + + with patch( + "core.services.file_id_capture.aioredis.from_url", + return_value=mock_redis, + ): + await capture_and_push_file_ids( + uncached_info=uncached_info, + sent_messages=(MagicMock(), msg1), + metadata_url="https://example.com/post/1", + ) + + payload = json.loads(mock_redis.lpush.call_args[0][1]) + assert len(payload["file_id_updates"]) == 1 + assert payload["file_id_updates"][0]["url"] == "https://vid.com/v.mp4" + + @pytest.mark.asyncio + async def test_no_push_when_all_cached(self, mock_redis): + """When all items are cached (None), nothing should be pushed.""" + from 
+        from core.services.file_id_capture import capture_and_push_file_ids
+
+        with patch(
+            "core.services.file_id_capture.aioredis.from_url",
+            return_value=mock_redis,
+        ):
+            await capture_and_push_file_ids(
+                uncached_info=[None, None],
+                sent_messages=(MagicMock(), MagicMock()),
+                metadata_url="https://example.com/post/1",
+            )
+
+        mock_redis.lpush.assert_not_awaited()
+
+    @pytest.mark.asyncio
+    async def test_no_push_when_empty_uncached(self, mock_redis):
+        from core.services.file_id_capture import capture_and_push_file_ids
+
+        with patch(
+            "core.services.file_id_capture.aioredis.from_url",
+            return_value=mock_redis,
+        ):
+            await capture_and_push_file_ids(
+                uncached_info=[],
+                sent_messages=(),
+                metadata_url="https://example.com/post/1",
+            )
+
+        mock_redis.lpush.assert_not_awaited()
+
+    @pytest.mark.asyncio
+    async def test_handles_more_uncached_than_messages(self, mock_redis):
+        """If uncached_info has more entries than sent_messages, stop at messages length."""
+        from core.services.file_id_capture import capture_and_push_file_ids
+
+        msg1 = MagicMock()
+        msg1.photo = [MagicMock(file_id="photo_id")]
+
+        uncached_info = [
+            {"url": "https://img.com/1.jpg", "media_type": "image"},
+            {"url": "https://img.com/2.jpg", "media_type": "image"},  # no message for this
+        ]
+
+        with patch(
+            "core.services.file_id_capture.aioredis.from_url",
+            return_value=mock_redis,
+        ):
+            await capture_and_push_file_ids(
+                uncached_info=uncached_info,
+                sent_messages=(msg1,),  # only 1 message
+                metadata_url="https://example.com/post/1",
+            )
+
+        payload = json.loads(mock_redis.lpush.call_args[0][1])
+        assert len(payload["file_id_updates"]) == 1
+
+    @pytest.mark.asyncio
+    async def test_skips_unextractable_file_ids(self, mock_redis):
+        """If extract_file_id returns None, skip that item."""
+        from core.services.file_id_capture import capture_and_push_file_ids
+
+        msg1 = MagicMock()
+        msg1.photo = []  # empty photo list → extract returns None
+        msg1.video = None
+        msg1.animation = None
+        msg1.audio = None
+        msg1.document = None
+
+        uncached_info = [
+            {"url": "https://img.com/1.jpg", "media_type": "image"},
+        ]
+
+        with patch(
+            "core.services.file_id_capture.aioredis.from_url",
+            return_value=mock_redis,
+        ):
+            await capture_and_push_file_ids(
+                uncached_info=uncached_info,
+                sent_messages=(msg1,),
+                metadata_url="https://example.com/post/1",
+            )
+
+        # No valid file_ids extracted, so nothing pushed
+        mock_redis.lpush.assert_not_awaited()
+
+    @pytest.mark.asyncio
+    async def test_redis_error_is_caught(self, mock_redis):
+        """Redis failures should be caught and logged, not raised."""
+        from core.services.file_id_capture import capture_and_push_file_ids
+
+        mock_redis.lpush = AsyncMock(side_effect=ConnectionError("Redis down"))
+
+        msg1 = MagicMock()
+        msg1.photo = [MagicMock(file_id="photo_id")]
+
+        uncached_info = [
+            {"url": "https://img.com/1.jpg", "media_type": "image"},
+        ]
+
+        with patch(
+            "core.services.file_id_capture.aioredis.from_url",
+            return_value=mock_redis,
+        ):
+            # Should not raise
+            await capture_and_push_file_ids(
+                uncached_info=uncached_info,
+                sent_messages=(msg1,),
+                metadata_url="https://example.com/post/1",
+            )
diff --git a/tests/unit/telegram_bot/test_message_sender.py b/tests/unit/telegram_bot/test_message_sender.py
index 962777b..4f1b321 100644
--- a/tests/unit/telegram_bot/test_message_sender.py
+++ b/tests/unit/telegram_bot/test_message_sender.py
@@ -1,4 +1,8 @@
-"""Tests for exception handling in apps/telegram-bot/core/services/message_sender.py"""
apps/telegram-bot/core/services/message_sender.py + +Covers exception handling, the telegram_file_id shortcut in media_files_packaging, +and the background file_id capture wiring. +""" from unittest.mock import AsyncMock, MagicMock, patch @@ -26,7 +30,7 @@ async def test_http_download_failure_skips_media(self): new_callable=AsyncMock, side_effect=RuntimeError("network error"), ): - media_group, file_group = await media_files_packaging(media_files, data) + media_group, file_group, uncached = await media_files_packaging(media_files, data) # Media item should be skipped, not crash assert media_group == [] @@ -50,7 +54,7 @@ async def test_gif_download_failure_skips_media(self): new_callable=AsyncMock, side_effect=ConnectionError("timeout"), ): - media_group, file_group = await media_files_packaging(media_files, data) + media_group, file_group, uncached = await media_files_packaging(media_files, data) assert media_group == [] assert file_group == [] @@ -107,13 +111,212 @@ async def download_side_effect(*args, **kwargs): mock_settings.TELEGRAM_IMAGE_DIMENSION_LIMIT = 2000 mock_settings.TELEGRAM_IMAGE_SIZE_LIMIT = 10 * 1024 * 1024 # 10MB - media_group, file_group = await media_files_packaging(media_files, data) + media_group, file_group, uncached = await media_files_packaging(media_files, data) # The image media_group should have the photo, but file_group should be empty # because the document download failed and was skipped assert file_group == [] +# --------------------------------------------------------------------------- +# file_id shortcut in media_files_packaging +# --------------------------------------------------------------------------- + + +class TestMediaFilesPackagingFileIdShortcut: + """Test that media_files_packaging uses telegram_file_id when available.""" + + @pytest.mark.asyncio + async def test_uses_file_id_for_image(self): + from core.services.message_sender import media_files_packaging + from telegram import InputMediaPhoto + + media_files = [ + { + "media_type": "image", + "url": "https://img.com/1.jpg", + "telegram_file_id": "AgACAgI123", + }, + ] + data = {"url": "https://example.com", "category": "twitter", "message_type": "short"} + + media_group, file_group, uncached = await media_files_packaging(media_files, data) + + assert len(media_group) == 1 + assert len(media_group[0]) == 1 + # file_id was used — no download should have been attempted + assert uncached == [None] + + @pytest.mark.asyncio + async def test_uses_file_id_for_video(self): + from core.services.message_sender import media_files_packaging + + media_files = [ + { + "media_type": "video", + "url": "https://vid.com/v.mp4", + "telegram_file_id": "BAACAgI456", + }, + ] + data = {"url": "https://example.com", "category": "twitter", "message_type": "short"} + + media_group, file_group, uncached = await media_files_packaging(media_files, data) + + assert len(media_group) == 1 + assert len(media_group[0]) == 1 + assert uncached == [None] + + @pytest.mark.asyncio + async def test_uses_file_id_for_gif(self): + from core.services.message_sender import media_files_packaging + + media_files = [ + { + "media_type": "gif", + "url": "https://img.com/anim.gif", + "telegram_file_id": "CgACAgI789", + }, + ] + data = {"url": "https://example.com", "category": "twitter", "message_type": "short"} + + media_group, file_group, uncached = await media_files_packaging(media_files, data) + + assert len(media_group) == 1 + assert uncached == [None] + + @pytest.mark.asyncio + async def test_uses_file_id_for_document(self): + from 
core.services.message_sender import media_files_packaging + + media_files = [ + { + "media_type": "document", + "url": "https://example.com/doc.pdf", + "telegram_file_id": "BQACAgI000", + }, + ] + data = {"url": "https://example.com", "category": "twitter", "message_type": "short"} + + media_group, file_group, uncached = await media_files_packaging(media_files, data) + + assert len(file_group) == 1 + assert uncached == [None] + + @pytest.mark.asyncio + async def test_no_download_when_file_id_present(self): + """Verify that download_file_by_metadata_item is NOT called for cached items.""" + from core.services.message_sender import media_files_packaging + + media_files = [ + { + "media_type": "image", + "url": "https://img.com/1.jpg", + "telegram_file_id": "AgACAgI123", + }, + ] + data = {"url": "https://example.com", "category": "twitter", "message_type": "short"} + + with patch( + "core.services.message_sender.download_file_by_metadata_item", + new_callable=AsyncMock, + ) as mock_download: + await media_files_packaging(media_files, data) + + mock_download.assert_not_awaited() + + @pytest.mark.asyncio + async def test_uncached_info_for_downloaded_items(self): + """Items without file_id should appear in uncached_media_info.""" + from core.services.message_sender import media_files_packaging + + mock_io = MagicMock() + mock_io.name = "video.mp4" + mock_io.size = 1024 + + media_files = [ + { + "media_type": "video", + "url": "https://vid.com/v.mp4", + }, + ] + data = {"url": "https://example.com", "category": "twitter", "message_type": "short"} + + with patch( + "core.services.message_sender.download_file_by_metadata_item", + new_callable=AsyncMock, + return_value=mock_io, + ), patch( + "core.services.message_sender.settings" + ) as mock_settings: + mock_settings.TELEBOT_API_SERVER = "http://local:8081/bot" + media_group, file_group, uncached = await media_files_packaging(media_files, data) + + assert len(uncached) == 1 + assert uncached[0] is not None + assert uncached[0]["url"] == "https://vid.com/v.mp4" + assert uncached[0]["media_type"] == "video" + + @pytest.mark.asyncio + async def test_mixed_cached_and_uncached(self): + """Test a mix of items with and without file_ids.""" + from core.services.message_sender import media_files_packaging + + mock_io = MagicMock() + mock_io.name = "video.mp4" + mock_io.size = 1024 + + media_files = [ + { + "media_type": "image", + "url": "https://img.com/1.jpg", + "telegram_file_id": "AgACAgI123", + }, + { + "media_type": "video", + "url": "https://vid.com/v.mp4", + # no telegram_file_id — will be downloaded + }, + ] + data = {"url": "https://example.com", "category": "twitter", "message_type": "short"} + + with patch( + "core.services.message_sender.download_file_by_metadata_item", + new_callable=AsyncMock, + return_value=mock_io, + ), patch( + "core.services.message_sender.settings" + ) as mock_settings: + mock_settings.TELEBOT_API_SERVER = "http://local:8081/bot" + media_group, file_group, uncached = await media_files_packaging(media_files, data) + + assert len(media_group) == 1 # one group with 2 items + assert len(media_group[0]) == 2 + assert len(uncached) == 2 + assert uncached[0] is None # cached + assert uncached[1] is not None # downloaded + assert uncached[1]["media_type"] == "video" + + @pytest.mark.asyncio + async def test_file_id_skips_even_for_long_messages(self): + """file_id check is above the message_type guard, so it works for long messages too.""" + from core.services.message_sender import media_files_packaging + + media_files = [ + { + 
"media_type": "image", + "url": "https://img.com/1.jpg", + "telegram_file_id": "AgACAgI123", + }, + ] + data = {"url": "https://example.com", "category": "twitter", "message_type": "long"} + + media_group, file_group, uncached = await media_files_packaging(media_files, data) + + # Should still use file_id even though message_type is "long" + assert len(media_group) == 1 + assert uncached == [None] + + class TestSendItemMessageExceptionHandling: @pytest.mark.asyncio async def test_exception_logged_and_sent_to_debug_channel(self):