diff --git a/pyproject.toml b/pyproject.toml index e54707c..83f9a74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ server = [ "markdownify>=1.1.0", # Scheduling "apscheduler>=3.10.0,<4.0.0", + "opencite>=0.5.2", ] observability = [ diff --git a/src/knowledge/papers_sync.py b/src/knowledge/papers_sync.py index 95eea0d..640ad4c 100644 --- a/src/knowledge/papers_sync.py +++ b/src/knowledge/papers_sync.py @@ -1,258 +1,289 @@ -"""Paper sync from OpenALEX, Semantic Scholar, and PubMed Central. +"""Paper sync backed by opencite. -Syncs papers for community-configured search queries. -Only stores title, abstract snippet, URL, and publication date. +Fetches papers through the `opencite` multi-source search/citation client and +writes them into the local knowledge database. opencite aggregates and +deduplicates across OpenAlex, Semantic Scholar, PubMed (and more), replacing +the previous hand-rolled per-source fetchers and inverted-index handling. -Rate limits: -- OpenALEX: No key required, generous limits -- Semantic Scholar: ~100 requests/5 min (free), higher with API key -- PubMed: ~3 requests/sec without key, 10/sec with key +Public sync functions keep their original signatures so the CLI +(`src/cli/sync.py`) and the scheduler (`src/api/scheduler.py`) call them +unchanged; only the fetch layer is swapped. + +See: https://github.com/neuromechanist/opencite """ +import asyncio import logging -import time -import xml.etree.ElementTree as ET -from typing import Any +from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor -import httpx -import pyalex -from pyalex import Works +from opencite import Config, Paper +from opencite.citations import CitationExplorer +from opencite.search import SearchOrchestrator from src.knowledge.db import get_connection, update_sync_metadata, upsert_paper logger = logging.getLogger(__name__) -# Rate limiting settings -SEMANTIC_SCHOLAR_DELAY = 3.0 # seconds between requests (to stay under 100/5min) -PUBMED_DELAY = 0.4 # seconds between requests (to stay under 3/sec) +# Scholarly sources synced by default. opencite also supports arxiv, biorxiv, +# medrxiv, osf, zenodo, figshare, crossref and core; those broader sources are +# reserved for the opt-in live-search feature (issue #308) so batch sync stays +# focused on peer-reviewed literature and matches prior coverage. +DEFAULT_SOURCES: tuple[str, ...] = ("openalex", "s2", "pubmed") + +# opencite source name -> OSA `papers.source` label. Kept stable so dedup and +# the existing rows in the database (openalex / semanticscholar / pubmed) line +# up with newly synced papers. +_OSA_SOURCE_BY_OPENCITE: dict[str, str] = { + "openalex": "openalex", + "s2": "semanticscholar", + "pubmed": "pubmed", +} +# OSA source label -> opencite source name (used to restrict per-source syncs). +_OPENCITE_SOURCE_BY_OSA: dict[str, str] = {v: k for k, v in _OSA_SOURCE_BY_OPENCITE.items()} + +# OpenAlex credentials set via configure_openalex(); merged into the per-sync +# Config as a fallback when explicit call arguments are not supplied. This +# preserves the CLI's "configure once, sync many" pattern. +_OPENALEX_API_KEY: str | None = None +_OPENALEX_EMAIL: str | None = None def configure_openalex(api_key: str | None = None, email: str | None = None) -> None: - """Configure pyalex with API key or email for polite pool access. + """Store OpenAlex credentials for subsequent opencite-backed syncs. + + OpenAlex works anonymously; an API key grants premium limits and a contact + email enables the faster polite pool. Values are merged into the opencite + Config built for each sync (explicit per-call arguments still win). Args: - api_key: OpenAlex API key for premium access (~2M requests). - email: Email for polite pool access (faster than anonymous). + api_key: OpenAlex API key for premium access. + email: Contact email for OpenAlex polite pool access. """ - # Treat empty strings as None - api_key = api_key.strip() if api_key else None - email = email.strip() if email else None + global _OPENALEX_API_KEY, _OPENALEX_EMAIL + _OPENALEX_API_KEY = api_key.strip() if api_key and api_key.strip() else None + _OPENALEX_EMAIL = email.strip() if email and email.strip() else None - if api_key: - pyalex.config.api_key = api_key + if _OPENALEX_API_KEY: logger.info("OpenAlex configured with API key") - elif email: - pyalex.config.email = email - logger.info("OpenAlex configured with email: %s (polite pool)", email) + elif _OPENALEX_EMAIL: + logger.info("OpenAlex configured with email: %s (polite pool)", _OPENALEX_EMAIL) else: logger.debug("OpenAlex using anonymous access (lower rate limits)") -def _reconstruct_abstract(inverted_index: dict[str, list[int]] | None) -> str: - """Reconstruct abstract from OpenALEX inverted index format. - - OpenALEX stores abstracts as inverted indexes: {"word": [positions]} - This function reconstructs the original text. - - Args: - inverted_index: Dict mapping words to their positions - - Returns: - Reconstructed abstract text - """ - if not inverted_index: - return "" - - # Find max position to size the array - max_pos = 0 - for positions in inverted_index.values(): - if positions: - max_pos = max(max_pos, max(positions)) - - # Build word array - words = [""] * (max_pos + 1) - for word, positions in inverted_index.items(): - for pos in positions: - words[pos] = word - - return " ".join(words) - - -def _get_paper_url(doi: str | None, fallback_id: str) -> str: - """Get paper URL, preferring DOI when available. - - Args: - doi: The DOI (may be full URL or bare DOI) - fallback_id: Fallback URL/ID if no DOI - - Returns: - URL string +def _build_config( + *, + openalex_api_key: str | None = None, + openalex_email: str | None = None, + semantic_scholar_api_key: str | None = None, + pubmed_api_key: str | None = None, +) -> Config: + """Build an opencite Config from explicit args and configure_openalex(). + + Credentials come from OSA settings (passed explicitly) with a fallback to + values set via configure_openalex(). We construct Config directly rather + than Config.from_env() so paper sync never depends on ambient ``.env`` + files in the working directory, which are environment-specific and have + tripped opencite's dotenv loader. """ - if not doi: - return fallback_id - return doi if doi.startswith("http") else f"https://doi.org/{doi}" - - -def _get_openalex_external_id(openalex_id: str) -> str: - """Extract external ID from OpenALEX URL. - - Args: - openalex_id: Full OpenALEX URL or bare ID - - Returns: - Bare external ID (e.g., "W12345") + return Config( + openalex_api_key=openalex_api_key or _OPENALEX_API_KEY or "", + contact_email=openalex_email or _OPENALEX_EMAIL or "", + semantic_scholar_api_key=semantic_scholar_api_key or "", + pubmed_api_key=pubmed_api_key or "", + ) + + +def _native_id(paper: Paper, osa_source: str) -> str: + """Return the identifier matching a specific OSA source label, or ''.""" + ids = paper.ids + if osa_source == "openalex": + return ids.openalex_id.removeprefix("https://openalex.org/") if ids.openalex_id else "" + if osa_source == "semanticscholar": + return ids.s2_id or "" + if osa_source == "pubmed": + return ids.pmid or "" + return "" + + +def _paper_source_and_id(paper: Paper) -> tuple[str | None, str | None]: + """Pick a stable (source, external_id) for the papers table. + + Prefers identifiers in the order OpenAlex > Semantic Scholar > PubMed > DOI + > arXiv so a paper maps to the same row across syncs and aligns with rows + already stored from the previous per-source fetchers. Returns (None, None) + when no usable identifier is present (such papers are skipped). """ - return openalex_id.removeprefix("https://openalex.org/") - - -def sync_openalex_papers(query: str, max_results: int = 100, project: str = "hed") -> int: - """Sync papers from OpenALEX matching query. + ids = paper.ids + openalex = ids.openalex_id.removeprefix("https://openalex.org/") if ids.openalex_id else "" + if openalex: + return "openalex", openalex + if ids.s2_id: + return "semanticscholar", ids.s2_id + if ids.pmid: + return "pubmed", ids.pmid + if ids.doi: + return "doi", ids.doi.lower() + if ids.arxiv_id: + return "arxiv", ids.arxiv_id + return None, None + + +def _paper_url(paper: Paper) -> str: + """Best link for a paper, preferring a stable DOI landing page.""" + if paper.doi: + return f"https://doi.org/{paper.doi}" + if paper.url: + return paper.url + if paper.best_pdf_url: + return paper.best_pdf_url + return "" + + +def _store_papers( + papers: Iterable[Paper], + project: str, + *, + force_source: str | None = None, +) -> dict[str, int]: + """Upsert opencite papers into the knowledge DB, returning counts by source. Args: - query: Search query - max_results: Maximum number of papers to sync - project: Assistant/project name for database isolation. Defaults to 'hed'. - - Returns: - Number of papers synced + papers: opencite Paper objects to store. + project: Community/project ID for database isolation. + force_source: When set (a single-source sync), record this OSA source + label using its native identifier; falls back to the priority + mapping if that identifier is missing. """ - logger.info("Syncing OpenALEX papers for query: %s", query) - - try: - # Build query and fetch results - # pyalex returns a lazy query object, need to call .get() to fetch results - works_query = ( - Works() - .search(query) - .select( - [ - "id", - "title", - "abstract_inverted_index", - "publication_date", - "doi", - "primary_location", - ] - ) - ) - # Fetch up to max_results using pagination - works = list(works_query.get(per_page=min(max_results, 200))) - except Exception as e: - logger.warning("OpenALEX error for '%s': %s", query, e) - return 0 - - count = 0 + counts: dict[str, int] = {} with get_connection(project) as conn: - for work in works: - if count >= max_results: - break - - # Skip if no title - title = work.get("title") - if not title: + for paper in papers: + if not paper.title: continue - abstract = _reconstruct_abstract(work.get("abstract_inverted_index")) - url = _get_paper_url(work.get("doi"), work.get("id", "")) - external_id = _get_openalex_external_id(work.get("id", "")) + if force_source: + external_id = _native_id(paper, force_source) + source: str | None = force_source if external_id else None + if not source: + source, external_id = _paper_source_and_id(paper) + else: + source, external_id = _paper_source_and_id(paper) + + if not source or not external_id: + continue upsert_paper( conn, - source="openalex", + source=source, external_id=external_id, - title=title, - first_message=abstract, - url=url, - created_at=work.get("publication_date"), + title=paper.title, + first_message=paper.abstract or None, + url=_paper_url(paper), + created_at=paper.publication_date or (str(paper.year) if paper.year else None), ) - count += 1 - + counts[source] = counts.get(source, 0) + 1 conn.commit() + return counts - logger.info("Synced %d papers from OpenALEX for '%s'", count, query) - update_sync_metadata("papers", f"openalex:{query}", count, project) - return count +def _run(coro): + """Execute an async coroutine from synchronous code. -def sync_semanticscholar_papers( + OSA's sync callers (CLI command, scheduler thread) have no running event + loop, so asyncio.run is used directly. If a loop is already running in the + calling thread, the coroutine runs in a dedicated worker thread so these + public sync functions stay safe to call from any context. + """ + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + with ThreadPoolExecutor(max_workers=1) as pool: + return pool.submit(asyncio.run, coro).result() + + +async def _search_queries( + config: Config, + queries: list[str], + max_results: int, + sources: tuple[str, ...] | None, +) -> list[tuple[str, list[Paper]]]: + """Search every query through one shared opencite orchestrator. + + A single orchestrator (and its HTTP client pool) is opened for the whole + batch. A failure for an individual query is logged and yields an empty + result for that query rather than aborting the batch. + """ + out: list[tuple[str, list[Paper]]] = [] + async with SearchOrchestrator(config) as searcher: + for query in queries: + try: + result = await searcher.search(query, max_results=max_results, sources=sources) + out.append((query, result.papers)) + except Exception as e: + logger.warning("opencite search error for '%s': %s", query, e) + out.append((query, [])) + return out + + +async def _citing_for_dois( + config: Config, + dois: list[str], + max_results: int, +) -> list[tuple[str, list[Paper]]]: + """Fetch citing papers for every DOI through one shared CitationExplorer.""" + out: list[tuple[str, list[Paper]]] = [] + async with CitationExplorer(config) as explorer: + for doi in dois: + try: + result = await explorer.citing_papers(doi, max_results=max_results) + out.append((doi, result.papers)) + except Exception as e: + logger.warning("opencite citation error for DOI %s: %s", doi, e) + out.append((doi, [])) + return out + + +def _sync_single_source( query: str, - max_results: int = 100, - api_key: str | None = None, - project: str = "hed", + max_results: int, + project: str, + osa_source: str, + config: Config, ) -> int: - """Sync papers from Semantic Scholar matching query. - - Args: - query: Search query - max_results: Maximum number of papers to sync - api_key: Optional API key for higher rate limits - project: Assistant/project name for database isolation. Defaults to 'hed'. - - Returns: - Number of papers synced - """ - logger.info("Syncing Semantic Scholar papers for query: %s", query) - - url = "https://api.semanticscholar.org/graph/v1/paper/search" - params: dict[str, Any] = { - "query": query, - "limit": min(max_results, 100), # API limit per request - "fields": "paperId,title,abstract,year,url,openAccessPdf", - } - - headers = {} - if api_key: - headers["x-api-key"] = api_key - + """Sync papers for one source (restricted opencite search) into the DB.""" + opencite_source = _OPENCITE_SOURCE_BY_OSA[osa_source] try: - response = httpx.get(url, params=params, headers=headers, timeout=30.0) - response.raise_for_status() - data = response.json() - except httpx.HTTPStatusError as e: - logger.warning("Semantic Scholar HTTP error for '%s': %s", query, e) - return 0 - except httpx.RequestError as e: - logger.warning("Semantic Scholar request error for '%s': %s", query, e) + searched = _run(_search_queries(config, [query], max_results, (opencite_source,))) + except Exception as e: + logger.warning("opencite %s search failed for '%s': %s", osa_source, query, e) return 0 - count = 0 - with get_connection(project) as conn: - for paper in data.get("data", []): - if count >= max_results: - break - - # Skip if no title - title = paper.get("title") - if not title: - continue - - paper_id = paper.get("paperId", "") - paper_url = paper.get("url") or f"https://www.semanticscholar.org/paper/{paper_id}" + _, papers = searched[0] + counts = _store_papers(papers, project, force_source=osa_source) + count = sum(counts.values()) + logger.info("Synced %d papers from %s for '%s'", count, osa_source, query) + update_sync_metadata("papers", f"{osa_source}:{query}", count, project) + return count - # Prefer open access PDF URL if available - open_access = paper.get("openAccessPdf") - if open_access and open_access.get("url"): - paper_url = open_access["url"] - upsert_paper( - conn, - source="semanticscholar", - external_id=paper_id, - title=title, - first_message=paper.get("abstract"), - url=paper_url, - created_at=str(paper.get("year")) if paper.get("year") else None, - ) - count += 1 - - conn.commit() +def sync_openalex_papers(query: str, max_results: int = 100, project: str = "hed") -> int: + """Sync papers from OpenAlex matching query (via opencite).""" + logger.info("Syncing OpenAlex papers for query: %s", query) + return _sync_single_source(query, max_results, project, "openalex", _build_config()) - logger.info("Synced %d papers from Semantic Scholar for '%s'", count, query) - update_sync_metadata("papers", f"semanticscholar:{query}", count, project) - # Rate limiting - time.sleep(SEMANTIC_SCHOLAR_DELAY) - return count +def sync_semanticscholar_papers( + query: str, + max_results: int = 100, + api_key: str | None = None, + project: str = "hed", +) -> int: + """Sync papers from Semantic Scholar matching query (via opencite).""" + logger.info("Syncing Semantic Scholar papers for query: %s", query) + config = _build_config(semantic_scholar_api_key=api_key) + return _sync_single_source(query, max_results, project, "semanticscholar", config) def sync_pubmed_papers( @@ -261,109 +292,10 @@ def sync_pubmed_papers( api_key: str | None = None, project: str = "hed", ) -> int: - """Sync papers from PubMed matching query. - - Uses NCBI E-utilities API (esearch + efetch). - - Args: - query: Search query - max_results: Maximum number of papers to sync - api_key: Optional NCBI API key for higher rate limits - project: Assistant/project name for database isolation. Defaults to 'hed'. - - Returns: - Number of papers synced - """ + """Sync papers from PubMed matching query (via opencite).""" logger.info("Syncing PubMed papers for query: %s", query) - - base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils" - - # Step 1: Search for paper IDs - search_params: dict[str, Any] = { - "db": "pubmed", - "term": query, - "retmax": max_results, - "retmode": "json", - } - if api_key: - search_params["api_key"] = api_key - - try: - search_response = httpx.get(f"{base_url}/esearch.fcgi", params=search_params, timeout=30.0) - search_response.raise_for_status() - search_data = search_response.json() - except (httpx.HTTPStatusError, httpx.RequestError) as e: - logger.warning("PubMed search error for '%s': %s", query, e) - return 0 - - id_list = search_data.get("esearchresult", {}).get("idlist", []) - if not id_list: - logger.info("No PubMed results for '%s'", query) - return 0 - - # Rate limiting between requests - time.sleep(PUBMED_DELAY) - - # Step 2: Fetch paper details - fetch_params: dict[str, Any] = { - "db": "pubmed", - "id": ",".join(id_list), - "retmode": "xml", - } - if api_key: - fetch_params["api_key"] = api_key - - try: - fetch_response = httpx.get(f"{base_url}/efetch.fcgi", params=fetch_params, timeout=60.0) - fetch_response.raise_for_status() - except (httpx.HTTPStatusError, httpx.RequestError) as e: - logger.warning("PubMed fetch error for '%s': %s", query, e) - return 0 - - # Parse XML response - try: - root = ET.fromstring(fetch_response.text) - except ET.ParseError as e: - logger.warning("PubMed XML parse error for '%s': %s", query, e) - return 0 - - count = 0 - with get_connection(project) as conn: - for article in root.findall(".//PubmedArticle"): - pmid_elem = article.find(".//PMID") - title_elem = article.find(".//ArticleTitle") - abstract_elem = article.find(".//AbstractText") - year_elem = article.find(".//PubDate/Year") - - if pmid_elem is None or title_elem is None: - continue - - pmid = pmid_elem.text or "" - title = title_elem.text or "" - abstract = abstract_elem.text if abstract_elem is not None else None - year = year_elem.text if year_elem is not None else None - - url = f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/" - - upsert_paper( - conn, - source="pubmed", - external_id=pmid, - title=title, - first_message=abstract, - url=url, - created_at=year, - ) - count += 1 - - conn.commit() - - logger.info("Synced %d papers from PubMed for '%s'", count, query) - update_sync_metadata("papers", f"pubmed:{query}", count, project) - - # Rate limiting - time.sleep(PUBMED_DELAY) - return count + config = _build_config(pubmed_api_key=api_key) + return _sync_single_source(query, max_results, project, "pubmed", config) def sync_all_papers( @@ -375,19 +307,23 @@ def sync_all_papers( openalex_email: str | None = None, project: str = "hed", ) -> dict[str, int]: - """Sync papers from all sources for given queries. + """Sync papers from all default sources for given queries via opencite. + + A single deduplicated opencite search runs per query across + ``DEFAULT_SOURCES``, replacing the previous three sequential per-source + fetches. Args: - queries: List of search queries (required - no default queries) - max_results: Max results per query per source - semantic_scholar_api_key: Optional Semantic Scholar API key - pubmed_api_key: Optional PubMed/NCBI API key - openalex_api_key: Optional OpenAlex API key for premium access - openalex_email: Optional email for OpenAlex polite pool - project: Project/community ID for database isolation + queries: List of search queries (required - no default queries). + max_results: Max deduplicated results per query. + semantic_scholar_api_key: Optional Semantic Scholar API key. + pubmed_api_key: Optional PubMed/NCBI API key. + openalex_api_key: Optional OpenAlex API key for premium access. + openalex_email: Optional email for OpenAlex polite pool. + project: Project/community ID for database isolation. Returns: - Dict mapping source to total items synced + Dict mapping OSA source label to total papers synced. """ if isinstance(queries, str): raise TypeError(f"queries must be a list of strings, not a bare string: {queries!r}") @@ -395,21 +331,25 @@ def sync_all_papers( logger.warning("No queries provided for paper sync") return {"openalex": 0, "semanticscholar": 0, "pubmed": 0} - # Configure OpenAlex with API key or email if provided - configure_openalex(api_key=openalex_api_key, email=openalex_email) + config = _build_config( + openalex_api_key=openalex_api_key, + openalex_email=openalex_email, + semantic_scholar_api_key=semantic_scholar_api_key, + pubmed_api_key=pubmed_api_key, + ) - results = { - "openalex": 0, - "semanticscholar": 0, - "pubmed": 0, - } + results: dict[str, int] = {"openalex": 0, "semanticscholar": 0, "pubmed": 0} + try: + searched = _run(_search_queries(config, queries, max_results, DEFAULT_SOURCES)) + except Exception as e: + logger.warning("opencite search failed for %s: %s", project, e) + return results - for query in queries: - results["openalex"] += sync_openalex_papers(query, max_results, project=project) - results["semanticscholar"] += sync_semanticscholar_papers( - query, max_results, semantic_scholar_api_key, project=project - ) - results["pubmed"] += sync_pubmed_papers(query, max_results, pubmed_api_key, project=project) + for query, papers in searched: + counts = _store_papers(papers, project) + for source, n in counts.items(): + results[source] = results.get(source, 0) + n + update_sync_metadata("papers", f"opencite:{query}", sum(counts.values()), project) total = sum(results.values()) logger.info("Total papers synced for %s: %d", project, total) @@ -423,92 +363,35 @@ def sync_citing_papers( openalex_api_key: str | None = None, openalex_email: str | None = None, ) -> int: - """Sync papers that cite the given DOIs using OpenALEX. - - OpenALEX supports finding papers that cite a specific work via - the `cites` filter. This is useful for tracking citations to - foundational papers in a field. + """Sync papers that cite the given DOIs using opencite's citation graph. Args: - dois: List of DOIs to find citations for. Should be in bare format - (e.g., "10.1016/j.neuroimage.2021.118809") without the - https://doi.org/ prefix. Invalid or unfound DOIs are skipped - with a warning log. - max_results: Maximum number of citing papers per DOI - project: Project/community ID for database isolation - openalex_api_key: Optional OpenAlex API key for premium access - openalex_email: Optional email for OpenAlex polite pool + dois: List of DOIs to find citations for. Bare format preferred + (e.g. "10.1016/j.neuroimage.2021.118809"); opencite auto-detects + and resolves the identifier. Unresolved DOIs are skipped with a + warning. + max_results: Maximum number of citing papers per DOI. + project: Project/community ID for database isolation. + openalex_api_key: Optional OpenAlex API key for premium access. + openalex_email: Optional email for OpenAlex polite pool. Returns: - Total number of citing papers synced + Total number of citing papers synced. """ if isinstance(dois, str): raise TypeError(f"dois must be a list of strings, not a bare string: {dois!r}") - configure_openalex(api_key=openalex_api_key, email=openalex_email) - total = 0 - for doi in dois: - logger.info("Syncing papers citing DOI: %s", doi) - - try: - # First, look up the OpenALEX work ID for this DOI - work_lookup = Works()[f"https://doi.org/{doi}"] - openalex_id = work_lookup.get("id") - - if not openalex_id: - logger.warning("Could not find OpenALEX ID for DOI %s", doi) - continue + config = _build_config(openalex_api_key=openalex_api_key, openalex_email=openalex_email) + try: + cited = _run(_citing_for_dois(config, dois, max_results)) + except Exception as e: + logger.warning("opencite citation lookup failed for %s: %s", project, e) + return 0 - logger.debug("Found OpenALEX ID %s for DOI %s", openalex_id, doi) - - # Now find papers that cite this work using the OpenALEX ID - works_query = ( - Works() - .filter(cites=openalex_id) - .select( - [ - "id", - "title", - "abstract_inverted_index", - "publication_date", - "doi", - "primary_location", - ] - ) - ) - works = list(works_query.get(per_page=min(max_results, 200))) - except Exception as e: - logger.warning("OpenALEX citation error for DOI %s: %s", doi, e) - continue - - count = 0 - with get_connection(project) as conn: - for work in works: - if count >= max_results: - break - - title = work.get("title") - if not title: - continue - - abstract = _reconstruct_abstract(work.get("abstract_inverted_index")) - url = _get_paper_url(work.get("doi"), work.get("id", "")) - external_id = _get_openalex_external_id(work.get("id", "")) - - upsert_paper( - conn, - source="openalex", - external_id=external_id, - title=title, - first_message=abstract, - url=url, - created_at=work.get("publication_date"), - ) - count += 1 - - conn.commit() - - # Update sync metadata with citing_ prefix to distinguish from query-based syncs + total = 0 + for doi, papers in cited: + counts = _store_papers(papers, project) + count = sum(counts.values()) update_sync_metadata("papers", f"citing_{doi}", count, project) logger.info("Synced %d papers citing %s", count, doi) total += count diff --git a/tests/test_knowledge/test_papers_sync.py b/tests/test_knowledge/test_papers_sync.py index ab2d123..323d0f4 100644 --- a/tests/test_knowledge/test_papers_sync.py +++ b/tests/test_knowledge/test_papers_sync.py @@ -1,17 +1,23 @@ -"""Tests for papers sync module. +"""Tests for the opencite-backed papers sync module. -Note: These are real API tests, not mocks, per project guidelines. +Mapping tests use real opencite ``Paper`` objects and a real SQLite database +(no mocks). The sync smoke tests make real network calls per project +guidelines. """ +import asyncio from pathlib import Path from unittest.mock import patch -import pyalex import pytest +from opencite import IDSet, Paper +import src.knowledge.papers_sync as ps from src.knowledge.db import get_connection, init_db from src.knowledge.papers_sync import ( - _reconstruct_abstract, + _paper_source_and_id, + _paper_url, + _store_papers, configure_openalex, sync_all_papers, sync_citing_papers, @@ -29,117 +35,181 @@ def temp_db(tmp_path: Path): class TestConfigureOpenalex: - """Tests for configure_openalex helper.""" + """Tests for the configure_openalex credential helper.""" def setup_method(self): - """Reset pyalex config before each test.""" - pyalex.config.api_key = None - pyalex.config.email = None + """Reset stored OpenAlex credentials before each test.""" + configure_openalex(api_key=None, email=None) def teardown_method(self): - """Reset pyalex config after each test.""" - pyalex.config.api_key = None - pyalex.config.email = None + """Reset stored OpenAlex credentials after each test.""" + configure_openalex(api_key=None, email=None) def test_sets_api_key(self): - """Should set pyalex.config.api_key when api_key provided.""" configure_openalex(api_key="test-key-123") - assert pyalex.config.api_key == "test-key-123" + assert ps._OPENALEX_API_KEY == "test-key-123" - def test_sets_email_when_no_api_key(self): - """Should set pyalex.config.email when only email provided.""" + def test_sets_email(self): configure_openalex(email="test@example.com") - assert pyalex.config.email == "test@example.com" + assert ps._OPENALEX_EMAIL == "test@example.com" + assert ps._OPENALEX_API_KEY is None - def test_api_key_takes_precedence_over_email(self): - """Should use API key over email when both provided.""" + def test_sets_both_key_and_email(self): configure_openalex(api_key="test-key", email="test@example.com") - assert pyalex.config.api_key == "test-key" + assert ps._OPENALEX_API_KEY == "test-key" + assert ps._OPENALEX_EMAIL == "test@example.com" def test_handles_empty_strings(self): - """Should treat empty strings as None (no config).""" configure_openalex(api_key="", email="") - assert pyalex.config.api_key is None - assert pyalex.config.email is None + assert ps._OPENALEX_API_KEY is None + assert ps._OPENALEX_EMAIL is None def test_handles_whitespace_strings(self): - """Should strip whitespace and treat blank as None.""" configure_openalex(api_key=" ", email=" ") - assert pyalex.config.api_key is None - assert pyalex.config.email is None + assert ps._OPENALEX_API_KEY is None + assert ps._OPENALEX_EMAIL is None def test_handles_none_values(self): - """Should handle None values gracefully (anonymous access).""" configure_openalex(api_key=None, email=None) - assert pyalex.config.api_key is None - assert pyalex.config.email is None - - -class TestAbstractReconstruction: - """Test OpenALEX inverted index reconstruction.""" - - def test_reconstruct_abstract_basic(self): - """Test basic abstract reconstruction from inverted index.""" - inverted_index = { - "hello": [0], - "world": [1], - } - result = _reconstruct_abstract(inverted_index) - assert "hello" in result - assert "world" in result - - def test_reconstruct_abstract_with_gaps(self): - """Test reconstruction with gaps in position array.""" - inverted_index = { - "hello": [0], - "world": [2], # Missing position 1 - } - result = _reconstruct_abstract(inverted_index) - # Should handle gaps gracefully (empty string at position 1) - assert "hello" in result - assert "world" in result - - def test_reconstruct_abstract_empty(self): - """Test reconstruction with empty/None input.""" - assert _reconstruct_abstract(None) == "" - assert _reconstruct_abstract({}) == "" - - def test_reconstruct_abstract_complex(self): - """Test reconstruction with longer text.""" - inverted_index = { - "Hierarchical": [0], - "Event": [1], - "Descriptors": [2], - "(HED)": [3], - "is": [4], - "a": [5], - "framework": [6], - } - result = _reconstruct_abstract(inverted_index) - expected_words = ["Hierarchical", "Event", "Descriptors", "HED", "framework"] - for word in expected_words: - assert word in result + assert ps._OPENALEX_API_KEY is None + assert ps._OPENALEX_EMAIL is None + + +class TestPaperMapping: + """Map opencite Paper objects to (source, external_id) and URLs.""" + + def test_prefers_openalex_id(self): + paper = Paper( + title="X", + ids=IDSet(openalex_id="https://openalex.org/W7", doi="10.1/A", pmid="9"), + ) + assert _paper_source_and_id(paper) == ("openalex", "W7") + + def test_falls_back_to_semantic_scholar(self): + paper = Paper(title="Y", ids=IDSet(s2_id="S99")) + assert _paper_source_and_id(paper) == ("semanticscholar", "S99") + + def test_falls_back_to_pubmed(self): + paper = Paper(title="Y", ids=IDSet(pmid="12345")) + assert _paper_source_and_id(paper) == ("pubmed", "12345") + + def test_falls_back_to_doi_lowercased(self): + paper = Paper(title="Y", ids=IDSet(doi="10.1/AbC")) + assert _paper_source_and_id(paper) == ("doi", "10.1/abc") + + def test_falls_back_to_arxiv(self): + paper = Paper(title="Y", ids=IDSet(arxiv_id="2106.15928")) + assert _paper_source_and_id(paper) == ("arxiv", "2106.15928") + + def test_no_identifier_is_skipped(self): + paper = Paper(title="orphan", ids=IDSet()) + assert _paper_source_and_id(paper) == (None, None) + + def test_url_prefers_doi_landing_page(self): + paper = Paper(title="X", ids=IDSet(doi="10.1/A"), url="https://openalex.org/W7") + assert _paper_url(paper) == "https://doi.org/10.1/A" + + def test_url_falls_back_to_paper_url(self): + paper = Paper(title="X", ids=IDSet(), url="https://example.org/p") + assert _paper_url(paper) == "https://example.org/p" + + def test_url_empty_when_nothing_available(self): + paper = Paper(title="X", ids=IDSet()) + assert _paper_url(paper) == "" + + +class TestStorePapers: + """Persist opencite papers into the knowledge DB (real SQLite, no mocks).""" + + def test_stores_and_labels_sources(self, temp_db: Path): + papers = [ + Paper( + title="EEGLAB toolbox", + ids=IDSet(openalex_id="https://openalex.org/W1", doi="10.1/eeglab"), + year=2004, + abstract="An open source toolbox.", + ), + Paper(title="S2 paper", ids=IDSet(s2_id="S2"), year=2020), + ] + with patch("src.knowledge.db.get_db_path", return_value=temp_db): + counts = _store_papers(papers, "test") + + assert counts == {"openalex": 1, "semanticscholar": 1} + with get_connection("test") as conn: + rows = { + r["source"]: r + for r in conn.execute("SELECT source, external_id, url, title FROM papers") + } + assert rows["openalex"]["external_id"] == "W1" + assert rows["openalex"]["url"] == "https://doi.org/10.1/eeglab" + assert rows["semanticscholar"]["external_id"] == "S2" + + def test_skips_papers_without_title_or_id(self, temp_db: Path): + papers = [ + Paper(title="", ids=IDSet(openalex_id="https://openalex.org/W1")), + Paper(title="no id", ids=IDSet()), + ] + with patch("src.knowledge.db.get_db_path", return_value=temp_db): + counts = _store_papers(papers, "test") + assert counts == {} + + def test_upsert_deduplicates_same_paper(self, temp_db: Path): + paper = Paper(title="dup", ids=IDSet(openalex_id="https://openalex.org/W1"), year=2020) + with patch("src.knowledge.db.get_db_path", return_value=temp_db): + _store_papers([paper], "test") + _store_papers([paper], "test") + with get_connection("test") as conn: + count = conn.execute("SELECT COUNT(*) AS c FROM papers").fetchone()["c"] + assert count == 1 + + def test_force_source_uses_native_id(self, temp_db: Path): + # A PubMed-restricted sync should label the row 'pubmed' using the PMID, + # even though the paper also carries an OpenAlex id. + paper = Paper( + title="P", + ids=IDSet(openalex_id="https://openalex.org/W1", pmid="555"), + ) + with patch("src.knowledge.db.get_db_path", return_value=temp_db): + counts = _store_papers([paper], "test", force_source="pubmed") + assert counts == {"pubmed": 1} + with get_connection("test") as conn: + row = conn.execute("SELECT source, external_id FROM papers").fetchone() + assert row["source"] == "pubmed" + assert row["external_id"] == "555" + + +async def _answer() -> int: + return 42 + + +class TestRunHelper: + """The _run async bridge must work with or without a running event loop.""" + + def test_runs_without_existing_loop(self): + # Sync context (CLI / scheduler thread): asyncio.run path. + assert ps._run(_answer()) == 42 + + def test_runs_inside_running_loop(self): + # If a loop is already running, _run offloads to a worker thread instead + # of raising "asyncio.run() cannot be called from a running event loop". + async def driver() -> int: + return ps._run(_answer()) + + assert asyncio.run(driver()) == 42 class TestPapersSync: - """Test papers sync functionality.""" + """Smoke tests using real opencite/network calls.""" def test_sync_openalex_papers_basic(self, temp_db: Path): - """Test basic OpenALEX papers sync. - - This is a smoke test using a real OpenALEX API call. - """ + """Basic OpenAlex sync through opencite (real API call).""" with patch("src.knowledge.db.get_db_path", return_value=temp_db): - # Sync a small number of papers with a specific query count = sync_openalex_papers( "Hierarchical Event Descriptors", max_results=5, project="test" ) - # Should find at least some results (OpenALEX doesn't require auth) - # Accept 0 for network issues + # Accept 0 for transient network issues. assert count >= 0 - - # If count > 0, verify data was written if count > 0: with get_connection("test") as conn: row = conn.execute( @@ -147,27 +217,15 @@ def test_sync_openalex_papers_basic(self, temp_db: Path): ).fetchone() assert row["count"] > 0 - def test_sync_openalex_papers_no_results(self, temp_db: Path): - """Test OpenALEX sync with query that returns no results.""" - with patch("src.knowledge.db.get_db_path", return_value=temp_db): - # Use an extremely specific nonsense query - count = sync_openalex_papers("xyzabc123nonsensequery", max_results=5, project="test") - - # Should return 0 for no results (not an error) - assert count == 0 - def test_sync_respects_max_results(self, temp_db: Path): - """Test that max_results parameter is respected.""" + """max_results is respected for a single-source sync.""" with patch("src.knowledge.db.get_db_path", return_value=temp_db): - # Request only 2 results count = sync_openalex_papers("neuroscience", max_results=2, project="test") - - # Should not exceed max_results assert count <= 2 class TestPapersSyncTypeGuard: - """Test that sync functions reject bare strings to prevent character iteration.""" + """Sync functions reject bare strings to prevent character iteration.""" def test_sync_all_papers_rejects_bare_string(self) -> None: with pytest.raises(TypeError, match="must be a list of strings"): diff --git a/uv.lock b/uv.lock index 22d7aee..6995cbd 100644 --- a/uv.lock +++ b/uv.lock @@ -2040,6 +2040,7 @@ dev = [ { name = "lxml" }, { name = "markdownify" }, { name = "mypy" }, + { name = "opencite" }, { name = "pre-commit" }, { name = "psycopg", extra = ["binary"] }, { name = "pyalex" }, @@ -2071,6 +2072,7 @@ server = [ { name = "litellm" }, { name = "lxml" }, { name = "markdownify" }, + { name = "opencite" }, { name = "psycopg", extra = ["binary"] }, { name = "pyalex" }, { name = "pydantic-settings" }, @@ -2112,6 +2114,8 @@ requires-dist = [ { name = "markdownify", marker = "extra == 'dev'", specifier = ">=1.1.0" }, { name = "markdownify", marker = "extra == 'server'", specifier = ">=1.1.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.19.0" }, + { name = "opencite", marker = "extra == 'dev'", specifier = ">=0.5.2" }, + { name = "opencite", marker = "extra == 'server'", specifier = ">=0.5.2" }, { name = "platformdirs", specifier = ">=4.5.0" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=4.5.0" }, { name = "psycopg", extras = ["binary"], marker = "extra == 'dev'", specifier = ">=3.3.0" }, @@ -2157,6 +2161,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/27/4b/7c1a00c2c3fbd004253937f7520f692a9650767aa73894d7a34f0d65d3f4/openai-2.14.0-py3-none-any.whl", hash = "sha256:7ea40aca4ffc4c4a776e77679021b47eec1160e341f42ae086ba949c9dcc9183", size = 1067558, upload-time = "2025-12-19T03:28:43.727Z" }, ] +[[package]] +name = "opencite" +version = "0.5.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "pyalex" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7a/18/313a129d8cee89784c409ec19049f09adf6ba515792747b9dd65dbbd9ddb/opencite-0.5.2.tar.gz", hash = "sha256:d7a63482e4d1a0372fd2d01cc246e3d10b74d50393ae153afe53029b490e7d7e", size = 142084, upload-time = "2026-05-21T19:54:15.971Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/b8/d00880ff96e9d3808c6b65e6347b7c4f9fdc8b72213a5dd4cd147224207e/opencite-0.5.2-py3-none-any.whl", hash = "sha256:d4ffb6b553c94df454dbc1de4eed55fdc2ea7604099fea0e3fdbc5a5ea654ce1", size = 96732, upload-time = "2026-05-21T19:54:14.574Z" }, +] + [[package]] name = "opentelemetry-api" version = "1.39.1"