diff --git a/dashboard/osa/index.html b/dashboard/osa/index.html index 991725e..248fcb5 100644 --- a/dashboard/osa/index.html +++ b/dashboard/osa/index.html @@ -697,6 +697,7 @@

Admin Access

let toolsChartInstance = null; let adminTokenChartInstance = null; let adminCostChartInstance = null; + let citationsChartInstance = null; const COLORS = [ '#2563eb', '#1e3a5f', '#059669', '#d97706', '#dc2626', @@ -858,11 +859,12 @@

Communities

document.title = `${safeName.toUpperCase()} - OSA Dashboard`; try { - const [summaryResp, usageResp, syncResp, healthResp] = await Promise.all([ + const [summaryResp, usageResp, syncResp, healthResp, citationsResp] = await Promise.all([ fetch(`${API_BASE}/${encodeURIComponent(communityId)}/metrics/public`), fetch(`${API_BASE}/${encodeURIComponent(communityId)}/metrics/public/usage?period=${activePeriod}`), fetch(`${API_BASE}/sync/status?community_id=${encodeURIComponent(communityId)}`).catch(err => { console.warn('Sync status fetch failed (non-critical):', err.message); return null; }), fetch(`${API_BASE}/sync/health?community_id=${encodeURIComponent(communityId)}`).catch(err => { console.warn('Health check fetch failed (non-critical):', err.message); return null; }), + fetch(`${API_BASE}/${encodeURIComponent(communityId)}/citations`).catch(err => { console.warn('Citations fetch failed (non-critical):', err.message); return null; }), ]); const failedStatus = !summaryResp.ok ? summaryResp.status : (!usageResp.ok ? usageResp.status : null); @@ -872,8 +874,10 @@

Communities

const usage = await usageResp.json(); const sync = syncResp && syncResp.ok ? await syncResp.json() : null; const health = healthResp && healthResp.ok ? await healthResp.json() : null; + // Citations feed is opt-in per community; a 404 just means it is off. + const citations = citationsResp && citationsResp.ok ? await citationsResp.json() : null; - renderCommunityView(summary, usage, sync, health, communityId); + renderCommunityView(summary, usage, sync, health, citations, communityId); document.getElementById('adminCard').style.display = ''; if (adminKey) loadAdminData(communityId); @@ -885,7 +889,7 @@

Communities

} } - function renderCommunityView(summary, usage, sync, health, communityId) { + function renderCommunityView(summary, usage, sync, health, citations, communityId) { const app = document.getElementById('app'); const safeName = escapeHtml(communityId); const meta = communityMeta[communityId] || {}; @@ -917,6 +921,19 @@

Communities

: ''; const links = linkHtml(meta.links, 'community-detail-links'); + // Publication citations card: shown only when the community exposes the + // citations feed and at least one canonical paper has citations. + const hasCitations = citations && citations.by_paper + && Object.keys(citations.by_paper).length > 0; + const citationsCardHtml = hasCitations ? ` +
+

Publication Citations

+

+ ${Number(citations.total || 0).toLocaleString()} papers citing this community's canonical works, by year. +

+
+
` : ''; + app.className = ''; app.innerHTML = `
@@ -974,10 +991,12 @@

Admin: Feedback
Loading feedback...

+ ${citationsCardHtml} `; renderUsageChart(usage); renderToolsChart(summary.top_tools); + renderCitationsChart(citations); } const SYNC_LABELS = { @@ -1154,6 +1173,68 @@

Admin: Feedback byPaper[d]); + const extras = Object.keys(byPaper).filter(d => !configured.includes(d)); + const dois = configured.concat(extras); + + // Union of all years present, sorted ascending for the x-axis. + const yearsSet = new Set(); + dois.forEach(d => Object.keys(byPaper[d]).forEach(y => yearsSet.add(y))); + const years = Array.from(yearsSet).sort((a, b) => Number(a) - Number(b)); + + const datasets = dois.map((doi, idx) => ({ + label: labels[doi] || doi, + data: years.map(y => byPaper[doi][y] || 0), + backgroundColor: seriesColor(idx), + borderWidth: 0, + })); + + citationsChartInstance = new Chart(canvas, { + type: 'bar', + data: { labels: years, datasets }, + options: { + responsive: true, maintainAspectRatio: false, + plugins: { + legend: { position: 'bottom', labels: { boxWidth: 12, font: { size: 11 } } }, + tooltip: { mode: 'index' }, + }, + scales: { + x: { stacked: true }, + y: { stacked: true, beginAtZero: true, ticks: { precision: 0 } }, + }, + } + }); + } + function changePeriod(period, communityId) { activePeriod = period; loadCommunityView(decodeURIComponent(communityId)); diff --git a/src/api/routers/community.py b/src/api/routers/community.py index 0fc6426..c165fb6 100644 --- a/src/api/routers/community.py +++ b/src/api/routers/community.py @@ -18,7 +18,7 @@ from pathlib import Path from typing import Annotated, Any, Literal -from fastapi import APIRouter, Header, HTTPException, Query, Request +from fastapi import APIRouter, Header, HTTPException, Query, Request, Response from fastapi.responses import FileResponse, StreamingResponse from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages.utils import count_tokens_approximately @@ -34,6 +34,7 @@ from src.assistants.registry import AssistantInfo from src.core.config.community import WidgetConfig from src.core.services.litellm_llm import create_openrouter_llm +from src.knowledge.search import FAQResult, get_citation_stats, list_faq_entries from 
src.metrics.cost import COST_BLOCK_THRESHOLD, COST_WARN_THRESHOLD, MODEL_PRICING, estimate_cost from src.metrics.db import ( RequestLogEntry, @@ -205,6 +206,79 @@ class CommunityConfigResponse(BaseModel): status: str = Field(..., description="Health status: healthy, degraded, or error") +class FAQEntryResponse(BaseModel): + """A single FAQ entry exposed via the public feed.""" + + question: str = Field(..., description="Synthesized question") + answer: str = Field(..., description="Synthesized answer") + tags: list[str] = Field(default_factory=list, description="Keyword tags") + category: str = Field(..., description="Entry category (how-to, troubleshooting, etc.)") + quality_score: float = Field(..., description="LLM quality score (0.0-1.0)") + message_count: int = Field(..., description="Number of source messages in the thread") + first_message_date: str = Field(..., description="Date of the first message in the thread") + thread_url: str = Field(..., description="URL of the source discussion thread") + + +class FAQFeedResponse(BaseModel): + """Paginated public FAQ feed for a community.""" + + community_id: str = Field(..., description="Community identifier") + total: int = Field(..., description="Total entries matching the filters") + limit: int = Field(..., description="Page size used for this response") + offset: int = Field(..., description="Offset used for this response") + entries: list[FAQEntryResponse] = Field(default_factory=list, description="FAQ entries") + + +class CitationsFeedResponse(BaseModel): + """Public citation dashboard data for a community's canonical papers.""" + + community_id: str = Field(..., description="Community identifier") + total: int = Field(..., description="Total citing papers with a recorded canonical link") + per_year: dict[str, int] = Field( + default_factory=dict, description="Citing-paper count per year across all papers" + ) + by_paper: dict[str, dict[str, int]] = Field( + default_factory=dict, + description="Stacked 
breakdown: canonical DOI -> year -> citing-paper count", + ) + canonical_dois: list[str] = Field( + default_factory=list, description="Canonical DOIs tracked for this community" + ) + labels: dict[str, str] = Field( + default_factory=dict, + description="Human-readable labels per canonical DOI (DOI -> label), when configured", + ) + + +# Matches bare email addresses so they can be stripped from the public feed. +_EMAIL_PATTERN = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}") + + +def _redact_emails(text: str) -> str: + """Replace any email address in ``text`` with a redaction marker. + + The FAQ feed is derived from public mailing-list content. The summarizer + strips most personal data, but a handful of entries still embed addresses + (mostly vendor support lines). A public JSON feed should not emit raw + addresses, so they are redacted at serialization time. + """ + return _EMAIL_PATTERN.sub("[email redacted]", text) + + +def _faq_result_to_response(entry: FAQResult) -> FAQEntryResponse: + """Convert a knowledge-layer FAQResult into a public response model.""" + return FAQEntryResponse( + question=_redact_emails(entry.question), + answer=_redact_emails(entry.answer), + tags=[_redact_emails(tag) for tag in entry.tags], + category=entry.category, + quality_score=entry.quality_score, + message_count=entry.message_count, + first_message_date=entry.first_message_date, + thread_url=entry.thread_url, + ) + + # --------------------------------------------------------------------------- # Session Management (In-Memory, per-community isolation) # --------------------------------------------------------------------------- @@ -1502,6 +1576,118 @@ async def community_usage_public( detail="Metrics database is temporarily unavailable.", ) + @router.get("/faq", response_model=FAQFeedResponse) + async def community_faq( + response: Response, + q: str | None = Query( + default=None, + description="Optional full-text search phrase. 
If omitted, browses all entries.", + max_length=200, + ), + category: str | None = Query( + default=None, + description="Filter by category (how-to, troubleshooting, reference, etc.)", + max_length=50, + ), + min_quality: float = Query( + default=0.0, ge=0.0, le=1.0, description="Minimum quality score" + ), + limit: int = Query(default=50, ge=1, le=200, description="Page size"), + offset: int = Query(default=0, ge=0, description="Pagination offset"), + ) -> FAQFeedResponse: + """Public, read-only FAQ feed for this community. + + Returns synthesized question/answer entries generated from the + community's mailing-list and forum archives. Disabled by default; + a community opts in via ``public_feeds.faq: true`` in its config. + Email addresses are redacted from the output. ``total`` is the full + match count before pagination, in both browse and search modes. + """ + config = info.community_config + if config is None or config.public_feeds is None or not config.public_feeds.faq: + raise HTTPException( + status_code=404, + detail="Public FAQ feed is not enabled for this community.", + ) + + try: + entries, total = list_faq_entries( + project=community_id, + limit=limit, + offset=offset, + query=q, + category=category, + min_quality=min_quality, + ) + except sqlite3.Error: + logger.exception("Failed to query FAQ feed for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Knowledge database is temporarily unavailable.", + ) + except Exception: + logger.exception("Unexpected error serving FAQ feed for community %s", community_id) + raise HTTPException( + status_code=500, + detail="An unexpected error occurred while building the FAQ feed.", + ) + + # Public, read-only data; cacheable like the other /…/public endpoints. 
+ response.headers["Cache-Control"] = "public, max-age=3600" + return FAQFeedResponse( + community_id=community_id, + total=total, + limit=limit, + offset=offset, + entries=[_faq_result_to_response(e) for e in entries], + ) + + @router.get("/citations", response_model=CitationsFeedResponse) + async def community_citations(response: Response) -> CitationsFeedResponse: + """Public, read-only citation dashboard for this community. + + Returns per-year counts of papers citing the community's canonical + works, plus a stacked breakdown keyed by the cited DOI (the shape + behind a citations-per-year chart). Disabled by default; a community + opts in via ``public_feeds.citations: true`` in its config. + """ + config = info.community_config + if config is None or config.public_feeds is None or not config.public_feeds.citations: + raise HTTPException( + status_code=404, + detail="Public citations feed is not enabled for this community.", + ) + + try: + stats = get_citation_stats(project=community_id) + except sqlite3.Error: + logger.exception("Failed to query citations for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Knowledge database is temporarily unavailable.", + ) + except Exception: + logger.exception( + "Unexpected error serving citations feed for community %s", community_id + ) + raise HTTPException( + status_code=500, + detail="An unexpected error occurred while building the citations feed.", + ) + + canonical_dois = list(config.citations.dois) if config.citations else [] + labels = dict(config.citations.paper_labels) if config.citations else {} + + response.headers["Cache-Control"] = "public, max-age=3600" + return CitationsFeedResponse( + community_id=community_id, + total=stats.total, + per_year=stats.per_year, + by_paper=stats.by_paper, + canonical_dois=canonical_dois, + labels=labels, + ) + return router diff --git a/src/api/scheduler.py b/src/api/scheduler.py index 555f278..a14e96c 100644 --- a/src/api/scheduler.py +++ 
b/src/api/scheduler.py @@ -137,6 +137,7 @@ def _run_papers_sync_for_community(community_id: str) -> bool: project=community_id, openalex_api_key=settings.openalex_api_key, openalex_email=settings.openalex_email, + aliases=citations.aliases, ) total += citing_count diff --git a/src/assistants/bids/config.yaml b/src/assistants/bids/config.yaml index 17cb046..fa9f670 100644 --- a/src/assistants/bids/config.yaml +++ b/src/assistants/bids/config.yaml @@ -574,6 +574,31 @@ citations: - "10.1038/s41597-025-05543-2" # MRS-BIDS (Bouchard et al., 2025) # Related ecosystem - "10.1371/journal.pcbi.1005209" # BIDS Apps (Gorgolewski et al., 2017) + # Short labels for the public citations dashboard (stacked series legend) + paper_labels: + "10.1038/sdata.2016.44": "BIDS (Gorgolewski 2016)" + "10.1038/s41597-019-0104-8": "EEG-BIDS (Pernet 2019)" + "10.1038/s41597-019-0105-7": "iEEG-BIDS (Holdgraf 2019)" + "10.1038/sdata.2018.110": "MEG-BIDS (Niso 2018)" + "10.1038/s41597-022-01164-1": "PET-BIDS (Norgaard 2021)" + "10.1177/0271678X20905433": "PET guidelines (Knudsen 2020)" + "10.1093/gigascience/giaa104": "Genetics-BIDS (Moreau 2020)" + "10.3389/fnins.2022.871228": "Microscopy-BIDS (Bourget 2022)" + "10.1038/s41597-022-01571-4": "qMRI-BIDS (Karakuzu 2022)" + "10.1038/s41597-022-01615-9": "ASL-BIDS (Clement 2022)" + "10.1038/s41597-024-04136-9": "NIRS-BIDS (Luke 2025)" + "10.1038/s41597-024-03559-8": "Motion-BIDS (Jeung 2024)" + "10.1038/s41597-025-05543-2": "MRS-BIDS (Bouchard 2025)" + "10.1371/journal.pcbi.1005209": "BIDS Apps (Gorgolewski 2017)" + # Merge preprint + published versions so split OpenAlex citations accumulate + aliases: + "10.1371/journal.pcbi.1005209": # BIDS Apps published (PLoS Comp Biol) + - "10.1101/079145" # BIDS Apps bioRxiv preprint (2016) + +# Expose the citation dashboard as a public, read-only JSON feed +# (GET /bids/citations). FAQ feed stays off: BIDS has no FAQ pipeline configured. 
+public_feeds: + citations: true # Discourse forums discourse: diff --git a/src/assistants/eeglab/config.yaml b/src/assistants/eeglab/config.yaml index a51c072..31c0d8b 100644 --- a/src/assistants/eeglab/config.yaml +++ b/src/assistants/eeglab/config.yaml @@ -426,6 +426,23 @@ citations: - "10.1016/j.jneumeth.2003.10.009" # EEGLAB: an open source toolbox (Delorme & Makeig, 2004) - "10.1016/j.neuroimage.2019.05.026" # ICLabel: automated EEG IC classification (Pion-Tonachini et al., 2019) - "10.3389/fninf.2015.00016" # PREP: standardized preprocessing (Bigdely-Shamlo et al., 2015) + - "10.1162/IMAG.a.136" # The lab streaming layer for synchronized multimodal recording (Kothe et al., 2025) + # Short labels for the public citations dashboard (stacked series legend) + paper_labels: + "10.1016/j.jneumeth.2003.10.009": "EEGLAB (Delorme 2004)" + "10.1016/j.neuroimage.2019.05.026": "ICLabel (Pion-Tonachini 2019)" + "10.3389/fninf.2015.00016": "PREP (Bigdely-Shamlo 2015)" + "10.1162/IMAG.a.136": "LSL (Kothe 2025)" + # Merge preprint + published versions so split OpenAlex citations accumulate + aliases: + "10.1162/IMAG.a.136": # LSL published (Imaging Neuroscience) + - "10.1101/2024.02.13.580071" # LSL bioRxiv preprint (2024) + +# Expose generated FAQ entries and citation stats as public, read-only JSON feeds +# (GET /eeglab/faq and GET /eeglab/citations). Off by default platform-wide. 
+public_feeds: + faq: true + citations: true # Mailing list configuration for FAQ generation mailman: diff --git a/src/cli/sync.py b/src/cli/sync.py index ce4503a..a893eed 100644 --- a/src/cli/sync.py +++ b/src/cli/sync.py @@ -99,6 +99,14 @@ def _get_community_paper_dois(community_id: str) -> list[str]: return [] +def _get_community_paper_aliases(community_id: str) -> dict[str, list[str]]: + """Get the primary-DOI -> version-DOIs alias map from the registry.""" + info = registry.get(community_id) + if info and info.community_config and info.community_config.citations: + return info.community_config.citations.aliases + return {} + + def _get_all_community_ids() -> list[str]: """Get all registered community IDs.""" return [info.id for info in registry.list_all()] @@ -364,16 +372,22 @@ def sync_papers( total += count console.print(f" [dim]{src}: {count} papers[/dim]") - # Sync citing papers if DOIs are configured + # Sync citing papers if DOIs are configured. Counts are fetched complete + # (uncapped) from OpenAlex; only the stored sample of recent citing papers + # uses the default cap, independent of the query --limit above. 
if include_citations: dois = _get_community_paper_dois(community) if dois: - console.print(f"\n[dim]Syncing papers citing {len(dois)} DOI(s)...[/dim]") - with console.status("[green]Syncing citing papers...[/green]"): - citing_count = sync_citing_papers(dois, limit, project=community) + console.print(f"\n[dim]Syncing citations for {len(dois)} DOI(s)...[/dim]") + with console.status("[green]Syncing citations...[/green]"): + citing_count = sync_citing_papers( + dois, + project=community, + aliases=_get_community_paper_aliases(community), + ) results_by_source["citing"] = citing_count total += citing_count - console.print(f"[dim]Citing papers: {citing_count}[/dim]") + console.print(f"[dim]Recent citing papers stored: {citing_count}[/dim]") console.print(f"\n[green]Total papers synced for {community}: {total}[/green]") @@ -579,10 +593,15 @@ def sync_all( ) paper_total += sum(paper_results.values()) - # Sync citing papers + # Sync citing papers. Counts are uncapped; the stored sample uses + # sync_citing_papers' own default cap, not the per-query --limit. if dois: with console.status("[green]Syncing citing papers...[/green]"): - citing_count = sync_citing_papers(dois, max_results=limit, project=comm_id) + citing_count = sync_citing_papers( + dois, + project=comm_id, + aliases=_get_community_paper_aliases(comm_id), + ) paper_total += citing_count console.print(f"[green]Papers: {paper_total} items[/green]") diff --git a/src/core/config/community.py b/src/core/config/community.py index 75d01b4..057cdda 100644 --- a/src/core/config/community.py +++ b/src/core/config/community.py @@ -243,6 +243,76 @@ class CitationConfig(BaseModel): OpenAlex anonymously. Communities opt in explicitly, and their prompt should tell the agent to ask the user before running it.""" + paper_labels: dict[str, str] = Field(default_factory=dict) + """Optional human-readable labels for canonical DOIs (DOI -> short label). + + Used to label the stacked series in the public citations dashboard + (e.g. 
'10.1038/s41597-019-0104-8' -> 'EEG-BIDS (Pernet 2019)'). Keys are + normalized like ``dois`` so they match the stored ``cites_doi`` values. + DOIs without a label fall back to the bare DOI in consumers.""" + + @field_validator("paper_labels") + @classmethod + def validate_paper_labels(cls, v: dict[str, str]) -> dict[str, str]: + """Normalize and validate DOI keys so labels line up with stored DOIs. + + Applies the same prefix-stripping and format check as ``dois`` so a + mistyped key fails loudly at config load instead of silently producing + a label that never matches a citation bucket. If two keys normalize to + the same DOI, the last one wins (mirrors ``dois`` dedup behavior). + """ + doi_pattern = re.compile(r"^10\.\d{4,}/[^\s]+$") + normalized: dict[str, str] = {} + for doi, label in v.items(): + clean_doi = re.sub(r"^(https?://)?(dx\.)?doi\.org/", "", doi.strip()) + if not clean_doi: + continue + if not doi_pattern.match(clean_doi): + raise ValueError( + f"Invalid DOI key in paper_labels (expected '10.xxxx/yyyy'): {doi}" + ) + normalized[clean_doi] = label + return normalized + + aliases: dict[str, list[str]] = Field(default_factory=dict) + """Version DOIs to merge into a canonical paper's citation count. + + Maps a primary DOI (from ``dois``) to other DOIs for the *same paper* + (typically a preprint and the published version). OpenAlex splits citations + across version records, so the citation sync queries them together and + deduplicates, attributing the merged per-year counts to the primary DOI. + Example: '10.1162/IMAG.a.136' -> ['10.1101/2024.02.13.580071']. 
Keys and + values are normalized like ``dois``.""" + + @field_validator("aliases") + @classmethod + def validate_aliases(cls, v: dict[str, list[str]]) -> dict[str, list[str]]: + """Normalize and validate primary + alias DOIs (same rules as ``dois``).""" + doi_pattern = re.compile(r"^10\.\d{4,}/[^\s]+$") + + def _clean(doi: str) -> str: + cleaned = re.sub(r"^(https?://)?(dx\.)?doi\.org/", "", doi.strip()) + if cleaned and not doi_pattern.match(cleaned): + raise ValueError(f"Invalid DOI in aliases (expected '10.xxxx/yyyy'): {doi}") + return cleaned + + normalized: dict[str, list[str]] = {} + for primary, versions in v.items(): + clean_primary = _clean(primary) + if not clean_primary: + continue + clean_versions: list[str] = [] + for d in versions: + clean = _clean(d) + if not clean: + # An empty version entry (e.g. `- ""`) is an authoring slip + # that would silently drop a version from the merge. + raise ValueError(f"Empty alias version DOI for primary '{primary}'") + if clean not in clean_versions: + clean_versions.append(clean) + normalized[clean_primary] = clean_versions + return normalized + @field_validator("queries") @classmethod def validate_queries(cls, v: list[str]) -> list[str]: @@ -273,6 +343,18 @@ def validate_dois(cls, v: list[str]) -> list[str]: # Deduplicate return list(dict.fromkeys(normalized)) + @model_validator(mode="after") + def validate_alias_primaries_in_dois(self) -> "CitationConfig": + """Every alias primary DOI must be a tracked DOI, else the merge is a no-op. + + Runs after field validators, so both ``dois`` and ``aliases`` keys are + already normalized and directly comparable. 
+ """ + unknown = set(self.aliases) - set(self.dois) + if unknown: + raise ValueError(f"aliases primary DOIs not present in dois: {sorted(unknown)}") + return self + class DiscourseCategoryConfig(BaseModel): """A Discourse category to sync.""" @@ -637,6 +719,23 @@ def validate_agent_roles(self) -> "FAQGenerationConfig": return self +class PublicFeedsConfig(BaseModel): + """Opt-in flags for exposing community data as public, read-only JSON feeds. + + Both feeds are off by default. Enabling a feed publishes already-synced + data (FAQ entries, citation counts) at unauthenticated endpoints so + communities can build their own frontends on top of it. + """ + + model_config = ConfigDict(extra="forbid") + + faq: bool = False + """Expose generated FAQ entries at GET /{community_id}/faq.""" + + citations: bool = False + """Expose canonical-paper citation counts at GET /{community_id}/citations.""" + + class BudgetConfig(BaseModel): """Budget limits and alert thresholds for a community. @@ -918,6 +1017,9 @@ def validate_id(cls, v: str) -> str: faq_generation: FAQGenerationConfig | None = None """FAQ generation configuration from threaded discussions (mailman, discourse, etc.).""" + public_feeds: PublicFeedsConfig | None = None + """Opt-in flags for exposing FAQ/citation data as public JSON feeds.""" + sync: SyncConfig | None = None """Per-community sync schedule configuration. diff --git a/src/knowledge/db.py b/src/knowledge/db.py index 5c9166d..416b454 100644 --- a/src/knowledge/db.py +++ b/src/knowledge/db.py @@ -132,6 +132,9 @@ def active_mirror_context(mirror_id: str) -> Iterator[None]: url TEXT NOT NULL, created_at TEXT, synced_at TEXT NOT NULL, + -- Canonical DOI this paper cites, when discovered via citation sync. + -- NULL for papers found through keyword search rather than a citation link. 
+ cites_doi TEXT, UNIQUE(source, external_id) ); @@ -171,6 +174,18 @@ def active_mirror_context(mirror_id: str) -> Iterator[None]: UNIQUE(source_type, source_name) ); +-- True per-year citation counts per canonical DOI, fetched from OpenAlex +-- group_by (complete, uncapped). This is the source of truth for the public +-- citations dashboard; the papers table only stores a recent sample of the +-- citing papers themselves for the search tool. +CREATE TABLE IF NOT EXISTS citation_counts ( + cites_doi TEXT NOT NULL, + year INTEGER NOT NULL, + count INTEGER NOT NULL, + synced_at TEXT NOT NULL, + PRIMARY KEY (cites_doi, year) +); + -- Docstrings extracted from source code CREATE TABLE IF NOT EXISTS docstrings ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -409,6 +424,8 @@ def active_mirror_context(mirror_id: str) -> Iterator[None]: CREATE INDEX IF NOT EXISTS idx_github_items_status ON github_items(status); CREATE INDEX IF NOT EXISTS idx_github_items_type ON github_items(item_type); CREATE INDEX IF NOT EXISTS idx_papers_source ON papers(source); +-- idx_papers_cites_doi is created in _migrate_db, after the cites_doi column +-- is ensured, so init_db stays safe on databases predating that column. CREATE INDEX IF NOT EXISTS idx_docstrings_repo ON docstrings(repo); CREATE INDEX IF NOT EXISTS idx_docstrings_language ON docstrings(language); CREATE INDEX IF NOT EXISTS idx_messages_list ON mailing_list_messages(list_name); @@ -507,6 +524,28 @@ def _migrate_db(conn: sqlite3.Connection) -> None: # Table doesn't exist yet - this is fine, schema will create it logger.debug("Docstrings table not found during migration (will be created): %s", e) + # Migration: Add cites_doi column to papers table (added 2026-06-09). + # The index lives here (not in SCHEMA_SQL) so executescript never references + # cites_doi on a database created before the column existed. 
+ try: + cursor = conn.execute("PRAGMA table_info(papers)") + columns = [row[1] for row in cursor.fetchall()] + except sqlite3.OperationalError as e: + # Only the PRAGMA is guarded here: a missing papers table is fine since + # SCHEMA_SQL creates it. DDL errors below (locked DB, I/O fault) must + # propagate rather than be swallowed and leave the table un-indexed. + logger.debug("Papers table not found during migration (will be created): %s", e) + columns = [] + + if columns: # papers table exists; migrate it in place + if "cites_doi" not in columns: + logger.info("Migrating papers table: adding cites_doi column") + conn.execute("ALTER TABLE papers ADD COLUMN cites_doi TEXT") + logger.info("Migration complete: cites_doi column added to papers") + # Ensure the index exists for both new and migrated databases. + conn.execute("CREATE INDEX IF NOT EXISTS idx_papers_cites_doi ON papers(cites_doi)") + conn.commit() + def init_db(project: str = "hed") -> None: """Initialize database schema for a project. @@ -586,6 +625,7 @@ def upsert_paper( first_message: str | None, url: str, created_at: str | None, + cites_doi: str | None = None, ) -> None: """Insert or update a paper. @@ -597,6 +637,14 @@ def upsert_paper( first_message: Abstract (limited to ~2000 chars) url: URL to the paper (DOI or source URL) created_at: Publication date (ISO 8601 or year string) + cites_doi: Canonical DOI this paper cites, when known from a citation + sync. ``None`` for keyword-search results. On conflict the first + recorded link is kept (COALESCE), so a later keyword sync passing + ``None`` never erases an existing citation link, and a re-sync + backfills the link onto rows stored before this column existed. + A single column holds one link: a paper citing two tracked DOIs is + attributed to whichever was synced first (it is still counted once + in the per-year total, only its by-paper bucket is approximate). 
""" # Limit first_message size if first_message and len(first_message) > 2000: @@ -605,14 +653,15 @@ def upsert_paper( conn.execute( """ INSERT INTO papers (source, external_id, title, first_message, - status, url, created_at, synced_at) - VALUES (?, ?, ?, ?, 'published', ?, ?, ?) + status, url, created_at, synced_at, cites_doi) + VALUES (?, ?, ?, ?, 'published', ?, ?, ?, ?) ON CONFLICT(source, external_id) DO UPDATE SET title=excluded.title, first_message=excluded.first_message, - synced_at=excluded.synced_at + synced_at=excluded.synced_at, + cites_doi=COALESCE(papers.cites_doi, excluded.cites_doi) """, - (source, external_id, title, first_message, url, created_at, _now_iso()), + (source, external_id, title, first_message, url, created_at, _now_iso(), cites_doi), ) @@ -715,6 +764,35 @@ def update_sync_metadata( conn.commit() +def replace_citation_counts(cites_doi: str, counts: dict[int, int], project: str = "hed") -> None: + """Replace the stored per-year citation counts for one canonical DOI. + + The counts are an exact, complete histogram from OpenAlex, so the row set + is replaced wholesale (delete + insert) inside one transaction: this keeps + the table an accurate mirror and drops any year that no longer appears. + + Args: + cites_doi: Canonical DOI whose citations these counts describe. + counts: Mapping of publication year to citing-paper count. + project: Assistant/project name. Defaults to 'hed'. + """ + now = _now_iso() + with get_connection(project) as conn: + try: + conn.execute("DELETE FROM citation_counts WHERE cites_doi = ?", (cites_doi,)) + if counts: + conn.executemany( + "INSERT INTO citation_counts (cites_doi, year, count, synced_at) " + "VALUES (?, ?, ?, ?)", + [(cites_doi, year, count, now) for year, count in counts.items()], + ) + conn.commit() + except Exception: + # Keep the delete+insert atomic: never leave a DOI half-replaced. 
+ conn.rollback() + raise + + def upsert_bep_item( conn: sqlite3.Connection, *, diff --git a/src/knowledge/openalex_citations.py b/src/knowledge/openalex_citations.py new file mode 100644 index 0000000..3bdfeb8 --- /dev/null +++ b/src/knowledge/openalex_citations.py @@ -0,0 +1,198 @@ +"""Direct OpenAlex client for citation analysis. + +opencite returns citing papers from a single page (<=200), ordered for its own +ranking, with no pagination and no aggregation exposed. For a citations +dashboard that silently truncates recent citations (the first page skews to +older, highly-cited works). We therefore query OpenAlex directly: + +- ``counts_by_year`` uses ``group_by=publication_year`` for the *exact, + complete* per-year histogram with no cap. +- ``recent_citing_papers`` cursor-paginates ``sort=publication_date:desc`` to + collect the latest N citing papers for the search corpus. + +The client takes an optional injected ``httpx.Client`` so tests can supply an +``httpx.MockTransport`` instead of hitting the network. 
+""" + +import logging +from collections.abc import Sequence +from dataclasses import dataclass + +import httpx + +logger = logging.getLogger(__name__) + +OPENALEX_BASE = "https://api.openalex.org" +_TIMEOUT = 30.0 +_PER_PAGE = 200 # OpenAlex maximum page size + + +@dataclass +class CitingPaper: + """A minimal citing-paper record for the search corpus.""" + + openalex_id: str + doi: str | None + title: str + publication_date: str | None + url: str + + +def _strip_id(value: str | None) -> str: + """Reduce an OpenAlex IRI (https://openalex.org/W123) to its bare id.""" + if not value: + return "" + return value.rstrip("/").rsplit("/", 1)[-1] + + +def _strip_doi(value: str | None) -> str | None: + """Reduce a DOI URL to the bare ``10.xxxx/yyyy`` form.""" + if not value: + return None + cleaned = value.strip() + for prefix in ("https://doi.org/", "http://doi.org/", "https://dx.doi.org/"): + if cleaned.lower().startswith(prefix): + cleaned = cleaned[len(prefix) :] + break + return cleaned or None + + +class OpenAlexCitationClient: + """Queries OpenAlex for citation counts and recent citing papers.""" + + def __init__( + self, + *, + email: str = "", + api_key: str = "", + client: httpx.Client | None = None, + ) -> None: + self._email = email + self._api_key = api_key + self._owns_client = client is None + self._client = client or httpx.Client(timeout=_TIMEOUT) + + def __enter__(self) -> "OpenAlexCitationClient": + return self + + def __exit__(self, *exc: object) -> None: + self.close() + + def close(self) -> None: + if self._owns_client: + self._client.close() + + def _params(self, **extra: object) -> dict[str, object]: + params: dict[str, object] = dict(extra) + # mailto routes to the polite pool; api_key unlocks premium throughput. + if self._email: + params["mailto"] = self._email + if self._api_key: + params["api_key"] = self._api_key + return params + + def resolve_work_id(self, doi: str) -> str | None: + """Resolve a DOI to its OpenAlex work id (e.g. 
``W2128495200``).""" + resp = self._client.get( + f"{OPENALEX_BASE}/works/doi:{doi}", + params=self._params(select="id"), + ) + if resp.status_code == 404: + logger.warning("OpenAlex has no work for DOI %s", doi) + return None + resp.raise_for_status() + work_id = _strip_id(resp.json().get("id")) + return work_id or None + + @staticmethod + def _cites_filter(work_ids: str | Sequence[str]) -> str: + """Build a ``cites:`` filter, OR-joining multiple work ids with ``|``. + + OpenAlex deduplicates across an OR group, so passing every version of a + paper (preprint + published) yields the merged, non-double-counted set. + """ + ids = [work_ids] if isinstance(work_ids, str) else [w for w in work_ids if w] + if not ids: + raise ValueError("work_ids must contain at least one OpenAlex work id") + return "cites:" + "|".join(ids) + + def counts_by_year(self, work_ids: str | Sequence[str]) -> dict[int, int]: + """Return the complete per-year count of works citing ``work_ids``. + + Accepts one work id or several (a version group); multiple ids are + OR-joined and deduplicated by OpenAlex. Uses ``group_by`` so the counts + are exact and uncapped, independent of how many papers are stored. + """ + resp = self._client.get( + f"{OPENALEX_BASE}/works", + params=self._params(filter=self._cites_filter(work_ids), group_by="publication_year"), + ) + resp.raise_for_status() + counts: dict[int, int] = {} + for group in resp.json().get("group_by", []): + try: + year = int(group["key"]) + except (KeyError, TypeError, ValueError): + continue # non-year buckets (e.g. "unknown") are skipped + counts[year] = int(group.get("count", 0)) + return counts + + def recent_citing_papers( + self, work_ids: str | Sequence[str], limit: int = 2000 + ) -> list[CitingPaper]: + """Collect up to ``limit`` most-recent works citing ``work_ids``. + + Accepts one work id or a version group (OR-joined, deduplicated by + OpenAlex). 
Cursor-paginates ``sort=publication_date:desc`` so the stored + sample is the newest citations rather than an arbitrary first page. + """ + cites_filter = self._cites_filter(work_ids) + papers: list[CitingPaper] = [] + cursor: str | None = "*" + # Bound the page count: a highly-cited work may have title-less records + # that never accumulate, so cap pages (with headroom) to avoid spinning. + pages = 0 + max_pages = (limit // _PER_PAGE) + 50 + while cursor and len(papers) < limit and pages < max_pages: + pages += 1 + page_size = min(_PER_PAGE, limit - len(papers)) + resp = self._client.get( + f"{OPENALEX_BASE}/works", + params=self._params( + filter=cites_filter, + sort="publication_date:desc", + select="id,doi,title,publication_date", + cursor=cursor, + **{"per-page": page_size}, + ), + ) + resp.raise_for_status() + data = resp.json() + results = data.get("results", []) + if not results: + break # no more works; a non-null cursor with no rows would spin + for work in results: + title = work.get("title") + if not title: + continue + doi = _strip_doi(work.get("doi")) + papers.append( + CitingPaper( + openalex_id=_strip_id(work.get("id")), + doi=doi, + title=title, + publication_date=work.get("publication_date"), + url=f"https://doi.org/{doi}" if doi else (work.get("id") or ""), + ) + ) + if len(papers) >= limit: + break + cursor = data.get("meta", {}).get("next_cursor") + if pages >= max_pages and cursor: + logger.warning( + "recent_citing_papers hit page cap for %s (%d pages, %d stored)", + cites_filter, + pages, + len(papers), + ) + return papers diff --git a/src/knowledge/papers_sync.py b/src/knowledge/papers_sync.py index a83806b..c0d2f69 100644 --- a/src/knowledge/papers_sync.py +++ b/src/knowledge/papers_sync.py @@ -21,11 +21,16 @@ from typing import Any, TypeVar from opencite import Config, Paper -from opencite.citations import CitationExplorer from opencite.exceptions import APIKeyError, ConfigurationError, OpenCiteError from opencite.search import 
SearchOrchestrator -from src.knowledge.db import get_connection, update_sync_metadata, upsert_paper +from src.knowledge.db import ( + get_connection, + replace_citation_counts, + update_sync_metadata, + upsert_paper, +) +from src.knowledge.openalex_citations import CitingPaper, OpenAlexCitationClient from src.knowledge.search import SearchResult logger = logging.getLogger(__name__) @@ -158,6 +163,7 @@ def _store_papers( project: str, *, force_source: str | None = None, + cites_doi: str | None = None, ) -> dict[str, int]: """Upsert opencite papers into the knowledge DB, returning counts by source. @@ -167,6 +173,8 @@ def _store_papers( force_source: When set (a single-source sync), record this OSA source label using its native identifier; falls back to the priority mapping if that identifier is missing. + cites_doi: Canonical DOI these papers cite, recorded on each row when + storing the results of a citation sync. ``None`` for keyword search. """ counts: dict[str, int] = {} with get_connection(project) as conn: @@ -193,6 +201,7 @@ def _store_papers( first_message=paper.abstract or None, url=_paper_url(paper), created_at=paper.publication_date or (str(paper.year) if paper.year else None), + cites_doi=cites_doi, ) counts[source] = counts.get(source, 0) + 1 conn.commit() @@ -248,27 +257,6 @@ async def _search_queries( return out -async def _citing_for_dois( - config: Config, - dois: list[str], - max_results: int, -) -> list[tuple[str, list[Paper]]]: - """Fetch citing papers for every DOI through one shared CitationExplorer.""" - out: list[tuple[str, list[Paper]]] = [] - async with CitationExplorer(config) as explorer: - for doi in dois: - try: - result = await explorer.citing_papers(doi, max_results=max_results) - out.append((doi, result.papers)) - except (OpenCiteError, TimeoutError) as e: - logger.warning("opencite citation error for DOI %s: %s", doi, e) - out.append((doi, [])) - except Exception: - logger.exception("unexpected error fetching citations for DOI %s", 
doi) - out.append((doi, [])) - return out - - def _sync_single_source( query: str, max_results: int, @@ -385,51 +373,129 @@ def sync_all_papers( return results +def _store_citing_papers(papers: Iterable[CitingPaper], project: str, *, cites_doi: str) -> int: + """Upsert OpenAlex citing-paper records into the papers table. + + Returns the number of rows stored. Each row is labelled with ``cites_doi`` + so it links back to the canonical paper it cites. + """ + stored = 0 + with get_connection(project) as conn: + for paper in papers: + if not paper.openalex_id or not paper.title: + continue + upsert_paper( + conn, + source="openalex", + external_id=paper.openalex_id, + title=paper.title, + first_message=None, + url=paper.url, + created_at=paper.publication_date, + cites_doi=cites_doi, + ) + stored += 1 + conn.commit() + return stored + + def sync_citing_papers( dois: list[str], - max_results: int = 100, + max_results: int = 2000, project: str = "hed", openalex_api_key: str | None = None, openalex_email: str | None = None, + aliases: dict[str, list[str]] | None = None, ) -> int: - """Sync papers that cite the given DOIs using opencite's citation graph. + """Sync citation data for the given canonical DOIs from OpenAlex. + + For each DOI this records two things, queried directly from OpenAlex + (opencite caps citing-paper fetches at one page and exposes no aggregation, + which truncates recent citations): + + 1. The *complete, uncapped* per-year citation histogram, via + ``group_by=publication_year``, stored in ``citation_counts``. This is + the source of truth for the public citations dashboard. + 2. The latest ``max_results`` citing papers (publication date descending), + upserted into the ``papers`` table for the search corpus. + + When a DOI has version ``aliases`` (e.g. 
a preprint plus the published + version), every version is resolved and queried together: OpenAlex splits + citations across version records, so OR-joining and deduplicating them + recovers the true count, attributed to the primary DOI. Args: - dois: List of DOIs to find citations for. Bare format preferred - (e.g. "10.1016/j.neuroimage.2021.118809"); opencite auto-detects - and resolves the identifier. Unresolved DOIs are skipped with a - warning. - max_results: Maximum number of citing papers per DOI. + dois: Canonical (primary) DOIs to track citations for. Unresolvable + DOIs are skipped with a warning. + max_results: Maximum number of recent citing papers stored per DOI. + Does not limit the per-year counts, which are always complete. project: Project/community ID for database isolation. - openalex_api_key: Optional OpenAlex API key for premium access. - openalex_email: Optional email for OpenAlex polite pool. + openalex_api_key: Optional OpenAlex API key for premium throughput. + openalex_email: Optional email for the OpenAlex polite pool. + aliases: Optional map of primary DOI -> additional version DOIs whose + citations merge into the primary. Returns: - Total number of citing papers synced. + Total citing papers stored across all DOIs (counts are uncapped). 
""" if isinstance(dois, str): raise TypeError(f"dois must be a list of strings, not a bare string: {dois!r}") - config = _build_config(openalex_api_key=openalex_api_key, openalex_email=openalex_email) - try: - cited = _run(_citing_for_dois(config, dois, max_results)) - except Exception as e: - logger.warning("opencite citation lookup failed for %s: %s", project, e) - return 0 + email = openalex_email or _OPENALEX_EMAIL or "" + api_key = openalex_api_key or _OPENALEX_API_KEY or "" + aliases = aliases or {} - total = 0 - for doi, papers in cited: - try: - counts = _store_papers(papers, project) - count = sum(counts.values()) - update_sync_metadata("papers", f"citing_{doi}", count, project) - logger.info("Synced %d papers citing %s", count, doi) - total += count - except Exception: - # Isolate per-DOI so one DB failure does not abort the batch. - logger.exception("failed to store citing papers for %s (%s)", doi, project) - - return total + total_stored = 0 + with OpenAlexCitationClient(email=email, api_key=api_key) as client: + for doi in dois: + try: + # Resolve the primary DOI plus any version aliases to a group of + # OpenAlex work ids; citations across the group are merged. + group_dois = [doi, *aliases.get(doi, [])] + work_ids = [wid for d in group_dois if (wid := client.resolve_work_id(d))] + if not work_ids: + logger.warning("Skipping citations: cannot resolve DOI %s", doi) + continue + + # 1. Complete per-year counts (source of truth for the chart). + counts = client.counts_by_year(work_ids) + if not counts: + # A canonical paper with zero citations is implausible; an + # empty histogram almost always means a transient OpenAlex + # gap. Do not wipe existing counts on a likely-bad read. + logger.warning( + "Empty citation histogram for %s (works %s); keeping existing " + "counts and skipping this DOI", + doi, + work_ids, + ) + continue + replace_citation_counts(doi, counts, project) + total_citations = sum(counts.values()) + + # 2. 
Latest citing papers for the search corpus. + papers = client.recent_citing_papers(work_ids, limit=max_results) + stored = _store_citing_papers(papers, project, cites_doi=doi) + + update_sync_metadata("citations", f"citing_{doi}", total_citations, project) + logger.info( + "Citations for %s: %d total across years, stored %d recent papers", + doi, + total_citations, + stored, + ) + total_stored += stored + except Exception as exc: + # Isolate per-DOI so one failure does not abort the batch. + logger.exception( + "citation sync failed for %s (%s): %s: %s", + doi, + project, + type(exc).__name__, + exc, + ) + + return total_stored def _config_from_env() -> Config: diff --git a/src/knowledge/search.py b/src/knowledge/search.py index 61563f2..c3b0222 100644 --- a/src/knowledge/search.py +++ b/src/knowledge/search.py @@ -376,6 +376,75 @@ def search_github_items( return results +@dataclass +class CitationStats: + """Aggregated citation counts for a community's canonical papers.""" + + total: int + """Sum of all per-year citation counts across every canonical DOI.""" + + per_year: dict[str, int] + """Citing-paper count per publication year, summed across canonical DOIs.""" + + by_paper: dict[str, dict[str, int]] + """Per canonical DOI: a mapping of publication year to citing-paper count.""" + + +def get_citation_stats(project: str = "eeglab") -> CitationStats: + """Aggregate citation counts for the public citations dashboard. + + Reads the ``citation_counts`` table, which holds the exact, complete + per-year histogram per canonical DOI fetched from OpenAlex ``group_by`` + (not the capped sample of citing papers in the ``papers`` table). A + community that has not yet had its citations synced (table absent) yields + empty stats rather than an error. + + Args: + project: Community ID for database isolation. Defaults to 'eeglab'. 
+ + Returns: + CitationStats with the overall ``total``, ``per_year`` totals, and the + stacked ``by_paper`` breakdown (canonical DOI -> year -> count). Years + are sorted ascending in every mapping. + """ + sql = "SELECT cites_doi, year, count FROM citation_counts" + + per_year: dict[str, int] = {} + by_paper: dict[str, dict[str, int]] = {} + total = 0 + try: + with get_connection(project) as conn: + for row in conn.execute(sql): + doi = row["cites_doi"] + year = str(row["year"]) + count = row["count"] + per_year[year] = per_year.get(year, 0) + count + by_paper.setdefault(doi, {})[year] = count + total += count + except sqlite3.OperationalError as e: + # The table is created on the first citation sync; before then, treat + # the feed as empty instead of failing the request. + if "no such table" in str(e).lower(): + logger.info("citation_counts not yet present for project %s", project) + return CitationStats(total=0, per_year={}, by_paper={}) + logger.error( + "Database operational error computing citation stats: %s", + e, + exc_info=True, + extra={"project": project}, + ) + raise + except sqlite3.Error as e: + logger.warning("Database error computing citation stats (project=%s): %s", project, e) + raise + + return CitationStats( + total=total, + per_year=dict(sorted(per_year.items())), + by_paper={doi: dict(sorted(years.items())) for doi, years in by_paper.items()}, + ) + + def search_papers( query: str, project: str = "hed", @@ -792,6 +861,28 @@ class FAQResult: first_message_date: str +def _parse_faq_tags(raw: str | None, *, thread_url: str, project: str) -> list[str]: + """Decode a FAQ entry's JSON ``tags`` column, tolerating malformed data. + + The column is written by the summarizer as a JSON array. A corrupt value + should degrade to an empty tag list (and a warning) rather than raise a + ``JSONDecodeError`` that escapes the sqlite handlers and surfaces as an + unlogged 500 at the API layer. 
+ """ + if not raw: + return [] + try: + return json.loads(raw) + except (json.JSONDecodeError, TypeError): + logger.warning( + "Invalid JSON in FAQ tags (thread_url=%s, project=%s): %r", + thread_url, + project, + raw, + ) + return [] + + def search_faq_entries( query: str, project: str = "eeglab", @@ -845,7 +936,7 @@ def search_faq_entries( params[0] = safe_query for row in conn.execute(sql, params): - tags = json.loads(row["tags"]) if row["tags"] else [] + tags = _parse_faq_tags(row["tags"], thread_url=row["thread_url"], project=project) results.append( FAQResult( @@ -876,6 +967,111 @@ def search_faq_entries( return results +def list_faq_entries( + project: str = "eeglab", + limit: int = 50, + offset: int = 0, + query: str | None = None, + list_name: str | None = None, + category: str | None = None, + min_quality: float = 0.0, +) -> tuple[list[FAQResult], int]: + """List FAQ entries for the public feed, with pagination metadata. + + Serves both browse mode (no ``query``) and search mode (``query`` set, via + FTS5). Unlike :func:`search_faq_entries`, this always returns the full + matching ``total`` count computed before LIMIT/OFFSET, so callers can + paginate correctly in either mode. + + Args: + project: Community ID for database isolation. Defaults to 'eeglab'. + limit: Maximum number of entries to return. + offset: Number of entries to skip (for pagination). + query: Optional full-text search phrase. When omitted, all entries + matching the filters are browsed, ordered by quality then recency. + list_name: Filter by mailing list name. + category: Filter by category (e.g., 'troubleshooting', 'how-to'). + min_quality: Minimum quality score (0.0-1.0). + + Returns: + Tuple of (entries, total_count) where total_count is the number of + entries matching the query and filters before limit/offset are applied. 
+ """ + use_fts = bool(query and query.strip()) + + leading_params: list[str | int | float] = [] + if use_fts: + from_clause = "faq_entries_fts fts JOIN faq_entries f ON fts.rowid = f.id" + where_clause = "faq_entries_fts MATCH ?" + order_clause = "f.quality_score DESC, rank" + # Sanitize to prevent FTS5 injection (query is guaranteed non-None here). + leading_params.append(_sanitize_fts5_query(query)) # type: ignore[arg-type] + else: + from_clause = "faq_entries f" + where_clause = "1=1" + order_clause = "f.quality_score DESC, f.first_message_date DESC" + + filters = "" + filter_params: list[str | int | float] = [] + if list_name: + filters += " AND f.list_name = ?" + filter_params.append(list_name) + if category: + filters += " AND f.category = ?" + filter_params.append(category) + if min_quality > 0: + filters += " AND f.quality_score >= ?" + filter_params.append(min_quality) + + base_params = [*leading_params, *filter_params] + count_sql = f"SELECT COUNT(*) FROM {from_clause} WHERE {where_clause}{filters}" + rows_sql = ( + "SELECT f.question, f.answer, f.thread_url, f.tags, f.category, " + "f.quality_score, f.message_count, f.first_message_date " + f"FROM {from_clause} WHERE {where_clause}{filters} " + f"ORDER BY {order_clause} LIMIT ? OFFSET ?" 
+ ) + + results: list[FAQResult] = [] + try: + with get_connection(project) as conn: + total = conn.execute(count_sql, base_params).fetchone()[0] + + for row in conn.execute(rows_sql, [*base_params, limit, offset]): + tags = _parse_faq_tags(row["tags"], thread_url=row["thread_url"], project=project) + results.append( + FAQResult( + question=row["question"], + answer=row["answer"], + thread_url=row["thread_url"], + tags=tags, + category=row["category"], + quality_score=row["quality_score"], + message_count=row["message_count"], + first_message_date=row["first_message_date"] or "", + ) + ) + except sqlite3.OperationalError as e: + logger.error( + "Database operational error listing FAQ entries: %s", + e, + exc_info=True, + extra={"project": project}, + ) + raise + except sqlite3.Error as e: + logger.warning( + "Database error listing FAQ entries (project=%s, limit=%d, offset=%d): %s", + project, + limit, + offset, + e, + ) + raise + + return results, total + + @dataclass class BEPResult: """A BEP search result from the knowledge database.""" diff --git a/src/version.py b/src/version.py index 52c1278..9acce2d 100644 --- a/src/version.py +++ b/src/version.py @@ -1,7 +1,7 @@ """Version information for OSA.""" -__version__ = "0.8.5" -__version_info__ = (0, 8, 5) +__version__ = "0.8.7.dev0" +__version_info__ = (0, 8, 7, "dev") def get_version() -> str: diff --git a/tests/test_api/test_citations_feed.py b/tests/test_api/test_citations_feed.py new file mode 100644 index 0000000..d992e2e --- /dev/null +++ b/tests/test_api/test_citations_feed.py @@ -0,0 +1,198 @@ +"""Tests for the public citations feed endpoint: GET /{community_id}/citations. + +Uses a real registered community, a temporary SQLite knowledge database with +per-year citation counts, and the config gate toggled per test. No business logic is +mocked except in TestCitationsFeedErrors, where get_citation_stats is patched +at the router call boundary to inject DB/unexpected errors and verify the +503/500 responses. 
+""" + +import sqlite3 +from collections.abc import Iterator +from pathlib import Path +from unittest.mock import patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from src.api.routers.community import create_community_router +from src.assistants import discover_assistants, registry +from src.core.config.community import PublicFeedsConfig +from src.knowledge.db import init_db, replace_citation_counts + +COMMUNITY_ID = "eeglab" +DOI_A = "10.1016/j.jneumeth.2003.10.009" +DOI_B = "10.1016/j.neuroimage.2019.05.026" + +discover_assistants() + + +@pytest.fixture +def citations_db(tmp_path: Path) -> Iterator[Path]: + """Temp knowledge DB with per-year citation counts for two canonical DOIs.""" + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db(COMMUNITY_ID) + # DOI_A: 2 in 2019, 1 in 2020 ; DOI_B: 1 in 2020 + replace_citation_counts(DOI_A, {2019: 2, 2020: 1}, project=COMMUNITY_ID) + replace_citation_counts(DOI_B, {2020: 1}, project=COMMUNITY_ID) + yield db_path + + +@pytest.fixture +def citations_enabled() -> Iterator[None]: + info = registry.get(COMMUNITY_ID) + assert info is not None and info.community_config is not None + original = info.community_config.public_feeds + info.community_config.public_feeds = PublicFeedsConfig(citations=True) + try: + yield + finally: + info.community_config.public_feeds = original + + +@pytest.fixture +def citations_disabled_none() -> Iterator[None]: + info = registry.get(COMMUNITY_ID) + assert info is not None and info.community_config is not None + original = info.community_config.public_feeds + info.community_config.public_feeds = None + try: + yield + finally: + info.community_config.public_feeds = original + + +@pytest.fixture +def citations_flag_false() -> Iterator[None]: + info = registry.get(COMMUNITY_ID) + assert info is not None and info.community_config is not None + original = 
info.community_config.public_feeds + info.community_config.public_feeds = PublicFeedsConfig(citations=False) + try: + yield + finally: + info.community_config.public_feeds = original + + +@pytest.fixture +def citations_enabled_no_config() -> Iterator[None]: + """Feed enabled but the community has no citations config block.""" + info = registry.get(COMMUNITY_ID) + assert info is not None and info.community_config is not None + orig_feeds = info.community_config.public_feeds + orig_citations = info.community_config.citations + info.community_config.public_feeds = PublicFeedsConfig(citations=True) + info.community_config.citations = None + try: + yield + finally: + info.community_config.public_feeds = orig_feeds + info.community_config.citations = orig_citations + + +@pytest.fixture +def client() -> TestClient: + app = FastAPI() + app.include_router(create_community_router(COMMUNITY_ID)) + return TestClient(app) + + +class TestCitationsFeedGate: + """The endpoint is opt-in via public_feeds.citations.""" + + @pytest.mark.usefixtures("citations_disabled_none") + def test_disabled_when_public_feeds_none(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + assert resp.status_code == 404 + + @pytest.mark.usefixtures("citations_flag_false") + def test_disabled_when_flag_false(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + assert resp.status_code == 404 + + @pytest.mark.usefixtures("citations_enabled") + def test_enabled_returns_200(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + assert resp.status_code == 200 + + +@pytest.mark.usefixtures("citations_enabled") +class TestCitationsFeedContent: + def test_total_and_per_year(self, client, citations_db): + with 
patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + body = resp.json() + assert body["community_id"] == COMMUNITY_ID + assert body["total"] == 4 # DOI_A: 2 + 1 ; DOI_B: 1 (see citations_db fixture) + assert body["per_year"] == {"2019": 2, "2020": 2} + + def test_by_paper_stacked_breakdown(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + by_paper = resp.json()["by_paper"] + assert by_paper == { + DOI_A: {"2019": 2, "2020": 1}, + DOI_B: {"2020": 1}, + } + + def test_canonical_dois_from_config(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + canonical = resp.json()["canonical_dois"] + # eeglab config tracks these canonical DOIs. + assert DOI_A in canonical + assert DOI_B in canonical + + def test_cache_control_header(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + assert resp.headers["Cache-Control"] == "public, max-age=3600" + + def test_labels_from_config(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + labels = resp.json()["labels"] + # eeglab config defines human-readable labels for its canonical DOIs. + assert labels.get(DOI_A) == "EEGLAB (Delorme 2004)" + assert labels.get(DOI_B) == "ICLabel (Pion-Tonachini 2019)" + # Mixed-case DOI suffix survives the config -> endpoint round-trip. 
+ assert labels.get("10.1162/IMAG.a.136") == "LSL (Kothe 2025)" + + +class TestCitationsFeedNoConfig: + """Feed enabled for a community without a citations config block.""" + + @pytest.mark.usefixtures("citations_enabled_no_config") + def test_canonical_dois_empty_when_no_citations_config(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + body = resp.json() + assert resp.status_code == 200 + assert body["canonical_dois"] == [] + assert body["labels"] == {} + # Stats still come from the DB regardless of config presence. + assert body["total"] == 4 + + +@pytest.mark.usefixtures("citations_enabled") +class TestCitationsFeedErrors: + def test_db_error_returns_503(self, client): + with patch( + "src.api.routers.community.get_citation_stats", + side_effect=sqlite3.OperationalError("db is locked"), + ): + resp = client.get(f"/{COMMUNITY_ID}/citations") + assert resp.status_code == 503 + + def test_unexpected_error_returns_500(self, client): + with patch( + "src.api.routers.community.get_citation_stats", + side_effect=RuntimeError("boom"), + ): + resp = client.get(f"/{COMMUNITY_ID}/citations") + assert resp.status_code == 500 diff --git a/tests/test_api/test_dashboard.py b/tests/test_api/test_dashboard.py index cc14100..ba52901 100644 --- a/tests/test_api/test_dashboard.py +++ b/tests/test_api/test_dashboard.py @@ -77,6 +77,18 @@ def test_has_period_toggle(self) -> None: assert "weekly" in content assert "monthly" in content + def test_references_citations_api(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + # Community view fetches the public citations feed. + assert "/citations" in content + + def test_has_citations_chart(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + assert "renderCitationsChart" in content + assert "citationsChart" in content + # Uses the configured labels for the stacked series legend. 
+ assert "citations.labels" in content + def test_api_base_configurable(self) -> None: content = DASHBOARD_HTML_PATH.read_text() # Should support ?api= query param or window.OSA_API_BASE override diff --git a/tests/test_api/test_faq_feed.py b/tests/test_api/test_faq_feed.py new file mode 100644 index 0000000..9408a9e --- /dev/null +++ b/tests/test_api/test_faq_feed.py @@ -0,0 +1,256 @@ +"""Tests for the public FAQ feed endpoint: GET /{community_id}/faq. + +Uses a real registered community, a temporary SQLite knowledge database +populated with FAQ rows, and the config gate toggled per test. No business +logic is mocked; only the database path and the opt-in flag are controlled. +""" + +import sqlite3 +from collections.abc import Iterator +from pathlib import Path +from unittest.mock import patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from src.api.routers.community import create_community_router +from src.assistants import discover_assistants, registry +from src.core.config.community import PublicFeedsConfig +from src.knowledge.db import get_connection, init_db, upsert_faq_entry + +COMMUNITY_ID = "eeglab" + +discover_assistants() + + +@pytest.fixture +def faq_db(tmp_path: Path) -> Iterator[Path]: + """Temp knowledge DB populated with FAQ entries, including one with an email.""" + db_path = tmp_path / "knowledge" / "test.db" + # Write through the same project the endpoint reads (COMMUNITY_ID) so the + # test does not rely on get_db_path ignoring its project argument. 
+ with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db(COMMUNITY_ID) + with get_connection(COMMUNITY_ID) as conn: + upsert_faq_entry( + conn, + list_name="eeglablist", + thread_id="t1", + thread_url="https://example.org/t1", + question="How do I run ICA in EEGLAB?", + answer="Use runica from the Tools menu.", + tags=["ica"], + category="how-to", + message_count=3, + participant_count=2, + first_message_date="2020-01-01", + quality_score=0.95, + summary_model="test-model", + ) + # t2 carries an email in the question, the answer, and a tag so the + # endpoint's redaction can be verified across all three fields. + upsert_faq_entry( + conn, + list_name="eeglablist", + thread_id="t2", + thread_url="https://example.org/t2", + question="Who do I contact (e.g. sales@brainproducts.com) for support?", + answer="Email support@brainproducts.com for hardware questions.", + tags=["hardware", "contact:info@vendor.com"], + category="reference", + message_count=2, + participant_count=2, + first_message_date="2021-01-01", + quality_score=0.70, + summary_model="test-model", + ) + conn.commit() + yield db_path + + +@pytest.fixture +def feeds_enabled() -> Iterator[None]: + """Enable public_feeds.faq on the community config, restoring it afterward.""" + info = registry.get(COMMUNITY_ID) + assert info is not None and info.community_config is not None + original = info.community_config.public_feeds + info.community_config.public_feeds = PublicFeedsConfig(faq=True) + try: + yield + finally: + info.community_config.public_feeds = original + + +@pytest.fixture +def feeds_disabled() -> Iterator[None]: + """Force public_feeds off (None), restoring the original afterward.""" + info = registry.get(COMMUNITY_ID) + assert info is not None and info.community_config is not None + original = info.community_config.public_feeds + info.community_config.public_feeds = None + try: + yield + finally: + info.community_config.public_feeds = original + + +@pytest.fixture +def 
feeds_faq_false() -> Iterator[None]: + """public_feeds present but faq disabled (the non-None gate branch).""" + info = registry.get(COMMUNITY_ID) + assert info is not None and info.community_config is not None + original = info.community_config.public_feeds + info.community_config.public_feeds = PublicFeedsConfig(faq=False) + try: + yield + finally: + info.community_config.public_feeds = original + + +@pytest.fixture +def client() -> TestClient: + app = FastAPI() + app.include_router(create_community_router(COMMUNITY_ID)) + return TestClient(app) + + +class TestFAQFeedGate: + """The endpoint is opt-in via public_feeds.faq.""" + + @pytest.mark.usefixtures("feeds_disabled") + def test_disabled_when_public_feeds_none(self, client, faq_db): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + resp = client.get(f"/{COMMUNITY_ID}/faq") + assert resp.status_code == 404 + + @pytest.mark.usefixtures("feeds_faq_false") + def test_disabled_when_faq_flag_false(self, client, faq_db): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + resp = client.get(f"/{COMMUNITY_ID}/faq") + assert resp.status_code == 404 + + @pytest.mark.usefixtures("feeds_enabled") + def test_enabled_returns_200(self, client, faq_db): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + resp = client.get(f"/{COMMUNITY_ID}/faq") + assert resp.status_code == 200 + + +@pytest.mark.usefixtures("feeds_enabled") +class TestFAQFeedContent: + """Response shape and filtering when enabled.""" + + def test_returns_all_entries(self, client, faq_db): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + resp = client.get(f"/{COMMUNITY_ID}/faq") + body = resp.json() + assert body["community_id"] == COMMUNITY_ID + assert body["total"] == 2 + assert len(body["entries"]) == 2 + # Ordered by quality descending + assert body["entries"][0]["quality_score"] == 0.95 + + def test_exposed_fields_only(self, client, faq_db): + with 
patch("src.knowledge.db.get_db_path", return_value=faq_db): + resp = client.get(f"/{COMMUNITY_ID}/faq") + entry = resp.json()["entries"][0] + assert set(entry.keys()) == { + "question", + "answer", + "tags", + "category", + "quality_score", + "message_count", + "first_message_date", + "thread_url", + } + + def test_emails_are_redacted(self, client, faq_db): + """Emails are stripped from question, answer, and tags alike.""" + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + resp = client.get(f"/{COMMUNITY_ID}/faq") + entries = resp.json()["entries"] + blob = " ".join( + e["question"] + " " + e["answer"] + " " + " ".join(e["tags"]) for e in entries + ) + assert "support@brainproducts.com" not in blob + assert "sales@brainproducts.com" not in blob + assert "info@vendor.com" not in blob + assert "[email redacted]" in blob + # Redaction reached all three field types on the t2 entry. + t2 = next(e for e in entries if e["category"] == "reference") + assert "[email redacted]" in t2["question"] + assert "[email redacted]" in t2["answer"] + assert any("[email redacted]" in tag for tag in t2["tags"]) + + def test_category_filter(self, client, faq_db): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + resp = client.get(f"/{COMMUNITY_ID}/faq", params={"category": "how-to"}) + body = resp.json() + assert body["total"] == 1 + assert body["entries"][0]["category"] == "how-to" + + def test_min_quality_filter(self, client, faq_db): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + resp = client.get(f"/{COMMUNITY_ID}/faq", params={"min_quality": 0.9}) + body = resp.json() + assert body["total"] == 1 + assert body["entries"][0]["quality_score"] >= 0.9 + + def test_search_query(self, client, faq_db): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + resp = client.get(f"/{COMMUNITY_ID}/faq", params={"q": "ICA"}) + body = resp.json() + # Only the t1 entry mentions ICA; total is the real match count. 
+ assert body["total"] == 1 + assert len(body["entries"]) == 1 + assert "ICA" in body["entries"][0]["question"] + + def test_pagination(self, client, faq_db): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + resp = client.get(f"/{COMMUNITY_ID}/faq", params={"limit": 1, "offset": 0}) + body = resp.json() + assert body["total"] == 2 + assert len(body["entries"]) == 1 + assert body["limit"] == 1 + assert body["offset"] == 0 + + def test_cache_control_header(self, client, faq_db): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + resp = client.get(f"/{COMMUNITY_ID}/faq") + assert resp.headers["Cache-Control"] == "public, max-age=3600" + + +@pytest.mark.usefixtures("feeds_enabled", "faq_db") +class TestFAQFeedValidation: + """Query parameter bounds are enforced (rejected before DB access).""" + + def test_invalid_min_quality_rejected(self, client): + resp = client.get(f"/{COMMUNITY_ID}/faq", params={"min_quality": 5}) + assert resp.status_code == 422 + + def test_limit_upper_bound_enforced(self, client): + resp = client.get(f"/{COMMUNITY_ID}/faq", params={"limit": 9999}) + assert resp.status_code == 422 + + +@pytest.mark.usefixtures("feeds_enabled") +class TestFAQFeedErrors: + """Database failures surface as 503, not silent empty responses.""" + + def test_browse_db_error_returns_503(self, client): + with patch( + "src.api.routers.community.list_faq_entries", + side_effect=sqlite3.OperationalError("db is locked"), + ): + resp = client.get(f"/{COMMUNITY_ID}/faq") + assert resp.status_code == 503 + + def test_search_db_error_returns_503(self, client): + with patch( + "src.api.routers.community.list_faq_entries", + side_effect=sqlite3.OperationalError("db is locked"), + ): + resp = client.get(f"/{COMMUNITY_ID}/faq", params={"q": "ICA"}) + assert resp.status_code == 503 diff --git a/tests/test_core/test_config/test_community.py b/tests/test_core/test_config/test_community.py index ab5f8f8..de942d5 100644 --- 
a/tests/test_core/test_config/test_community.py +++ b/tests/test_core/test_config/test_community.py @@ -189,6 +189,69 @@ def test_deduplicates_dois(self) -> None: assert "10.1234/example" in config.dois assert "10.5678/other" in config.dois + def test_paper_labels_default_empty(self) -> None: + """paper_labels defaults to an empty dict.""" + assert CitationConfig().paper_labels == {} + + def test_paper_labels_keys_normalized(self) -> None: + """DOI keys in paper_labels are normalized like dois so they match.""" + config = CitationConfig( + dois=["10.1234/example"], + paper_labels={ + "https://doi.org/10.1234/example": "Example (Author 2020)", + "doi.org/10.9012/paper": "Paper (Author 2019)", + "10.5678/other": "Other (Author 2021)", + }, + ) + assert config.paper_labels["10.1234/example"] == "Example (Author 2020)" + assert config.paper_labels["10.9012/paper"] == "Paper (Author 2019)" + assert config.paper_labels["10.5678/other"] == "Other (Author 2021)" + for key in config.paper_labels: + assert not key.startswith("http") + assert not key.startswith("doi.org") + + def test_paper_labels_rejects_invalid_doi_key(self) -> None: + """A malformed DOI key fails loudly rather than silently dropping the label.""" + with pytest.raises(ValidationError, match="Invalid DOI key in paper_labels"): + CitationConfig(paper_labels={"not-a-doi": "Label"}) + + def test_paper_labels_dedup_last_wins(self) -> None: + """Two keys that normalize to the same DOI collapse to one (last wins).""" + config = CitationConfig( + paper_labels={ + "https://doi.org/10.1234/x": "Label B", + "10.1234/x": "Label B", + } + ) + assert config.paper_labels == {"10.1234/x": "Label B"} + + def test_aliases_default_empty(self) -> None: + assert CitationConfig().aliases == {} + + def test_aliases_normalizes_primary_and_versions(self) -> None: + config = CitationConfig( + dois=["10.1234/primary"], + aliases={ + "https://doi.org/10.1234/primary": [ + "https://doi.org/10.1101/preprint", + "10.1101/preprint", # 
duplicate after normalization + ] + }, + ) + assert config.aliases == {"10.1234/primary": ["10.1101/preprint"]} + + def test_aliases_rejects_invalid_doi(self) -> None: + with pytest.raises(ValidationError, match="Invalid DOI in aliases"): + CitationConfig(dois=["10.1234/primary"], aliases={"10.1234/primary": ["not-a-doi"]}) + + def test_aliases_rejects_empty_version(self) -> None: + with pytest.raises(ValidationError, match="Empty alias version DOI"): + CitationConfig(dois=["10.1234/primary"], aliases={"10.1234/primary": [""]}) + + def test_aliases_primary_must_be_in_dois(self) -> None: + with pytest.raises(ValidationError, match="not present in dois"): + CitationConfig(dois=["10.1234/a"], aliases={"10.1234/b": ["10.1101/x"]}) + def test_deduplicates_queries(self) -> None: """Should deduplicate queries.""" config = CitationConfig(queries=["query 1", "query 1", "query 2"]) diff --git a/tests/test_knowledge/test_citation_stats.py b/tests/test_knowledge/test_citation_stats.py new file mode 100644 index 0000000..2178e68 --- /dev/null +++ b/tests/test_knowledge/test_citation_stats.py @@ -0,0 +1,191 @@ +"""Tests for citation stats aggregation and the cites_doi linkage column. + +Uses a real temporary SQLite database (only the DB path is redirected); no +business logic is mocked. 
+""" + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from src.knowledge.db import ( + get_connection, + init_db, + replace_citation_counts, + upsert_paper, +) +from src.knowledge.search import CitationStats, get_citation_stats + +DOI_A = "10.1016/j.jneumeth.2003.10.009" +DOI_B = "10.1016/j.neuroimage.2019.05.026" + + +def _add_paper(conn, external_id, *, created_at, cites_doi=None, source="openalex"): + upsert_paper( + conn, + source=source, + external_id=external_id, + title=f"Citing paper {external_id}", + first_message=None, + url=f"https://doi.org/10.test/{external_id}", + created_at=created_at, + cites_doi=cites_doi, + ) + + +@pytest.fixture +def counts_db(tmp_path: Path): + """Temp DB with per-year citation counts for two canonical DOIs.""" + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + replace_citation_counts(DOI_A, {2019: 2, 2020: 1}, project="eeglab") + replace_citation_counts(DOI_B, {2020: 1, 2021: 1}, project="eeglab") + yield db_path + + +class TestGetCitationStats: + def test_returns_citation_stats_object(self, counts_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=counts_db): + stats = get_citation_stats(project="eeglab") + assert isinstance(stats, CitationStats) + + def test_total_sums_all_counts(self, counts_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=counts_db): + stats = get_citation_stats(project="eeglab") + assert stats.total == 5 # 2+1 + 1+1 + + def test_per_year_aggregates_across_dois(self, counts_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=counts_db): + stats = get_citation_stats(project="eeglab") + assert stats.per_year == {"2019": 2, "2020": 2, "2021": 1} + + def test_per_year_is_sorted_ascending(self, counts_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=counts_db): + stats = get_citation_stats(project="eeglab") + assert 
list(stats.per_year.keys()) == sorted(stats.per_year.keys()) + + def test_by_paper_stacked_breakdown(self, counts_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=counts_db): + stats = get_citation_stats(project="eeglab") + assert stats.by_paper == { + DOI_A: {"2019": 2, "2020": 1}, + DOI_B: {"2020": 1, "2021": 1}, + } + + def test_replace_overwrites_previous_counts(self, counts_db: Path): + """A re-sync replaces a DOI's histogram wholesale (no stale years).""" + with patch("src.knowledge.db.get_db_path", return_value=counts_db): + replace_citation_counts(DOI_A, {2025: 9}, project="eeglab") + stats = get_citation_stats(project="eeglab") + assert stats.by_paper[DOI_A] == {"2025": 9} + assert "2019" not in stats.per_year # old DOI_A years gone + + def test_empty_database(self, tmp_path: Path): + db_path = tmp_path / "knowledge" / "empty.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + stats = get_citation_stats(project="eeglab") + assert stats.total == 0 + assert stats.per_year == {} + assert stats.by_paper == {} + + def test_missing_table_returns_empty(self, tmp_path: Path): + """Before any citation sync (table absent), stats are empty, not an error.""" + db_path = tmp_path / "knowledge" / "noinit.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + # Create the DB file with a connection but never run init_db, so + # citation_counts does not exist. 
+ with get_connection() as conn: + conn.execute("CREATE TABLE placeholder (id INTEGER)") + conn.commit() + stats = get_citation_stats(project="eeglab") + assert stats.total == 0 + assert stats.by_paper == {} + + +class TestCitesDoiUpsert: + def test_backfill_sets_link_on_existing_row(self, tmp_path: Path): + """A row first stored without a link gets it on a later citation sync.""" + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + with get_connection() as conn: + _add_paper(conn, "p1", created_at="2020", cites_doi=None) + _add_paper(conn, "p1", created_at="2020", cites_doi=DOI_A) + conn.commit() + row = conn.execute( + "SELECT cites_doi FROM papers WHERE external_id = 'p1'" + ).fetchone() + assert row["cites_doi"] == DOI_A + + def test_first_link_wins_over_later_link(self, tmp_path: Path): + """COALESCE keeps the first recorded canonical DOI for overlapping papers.""" + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + with get_connection() as conn: + _add_paper(conn, "p1", created_at="2020", cites_doi=DOI_A) + _add_paper(conn, "p1", created_at="2020", cites_doi=DOI_B) + conn.commit() + row = conn.execute( + "SELECT cites_doi FROM papers WHERE external_id = 'p1'" + ).fetchone() + assert row["cites_doi"] == DOI_A + + def test_keyword_sync_does_not_erase_link(self, tmp_path: Path): + """A later keyword sync (cites_doi=None) must not clobber an existing link.""" + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + with get_connection() as conn: + _add_paper(conn, "p1", created_at="2020", cites_doi=DOI_A) + _add_paper(conn, "p1", created_at="2020", cites_doi=None) + conn.commit() + row = conn.execute( + "SELECT cites_doi FROM papers WHERE external_id = 'p1'" + ).fetchone() + assert row["cites_doi"] == DOI_A + + +class TestCitesDoiMigration: + 
def test_migration_adds_column_to_legacy_papers_table(self, tmp_path: Path): + """A papers table created before cites_doi gains the column via init_db.""" + db_path = tmp_path / "knowledge" / "legacy.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + # Simulate a pre-migration schema: papers without cites_doi. + with get_connection() as conn: + conn.execute( + """ + CREATE TABLE papers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source TEXT NOT NULL, + external_id TEXT NOT NULL, + title TEXT NOT NULL, + first_message TEXT, + status TEXT NOT NULL DEFAULT 'published', + url TEXT NOT NULL, + created_at TEXT, + synced_at TEXT NOT NULL, + UNIQUE(source, external_id) + ) + """ + ) + conn.commit() + cols_before = [r[1] for r in conn.execute("PRAGMA table_info(papers)")] + assert "cites_doi" not in cols_before + + # Running init_db must migrate the existing table in place. + init_db() + with get_connection() as conn: + cols_after = [r[1] for r in conn.execute("PRAGMA table_info(papers)")] + # The new column is usable for inserts after migration. + _add_paper(conn, "p1", created_at="2020", cites_doi=DOI_A) + conn.commit() + row = conn.execute( + "SELECT cites_doi FROM papers WHERE external_id = 'p1'" + ).fetchone() + + assert "cites_doi" in cols_after + assert row["cites_doi"] == DOI_A diff --git a/tests/test_knowledge/test_faq_feed.py b/tests/test_knowledge/test_faq_feed.py new file mode 100644 index 0000000..e1436f9 --- /dev/null +++ b/tests/test_knowledge/test_faq_feed.py @@ -0,0 +1,213 @@ +"""Tests for the public FAQ feed listing helper. + +Uses a temporary SQLite database populated with real FAQ rows (no mocks of +business logic; only the database path is redirected to a temp file). 
+""" + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from src.knowledge.db import get_connection, init_db, upsert_faq_entry +from src.knowledge.search import FAQResult, list_faq_entries + + +@pytest.fixture +def faq_db(tmp_path: Path): + """Create a test database populated with FAQ entries.""" + db_path = tmp_path / "knowledge" / "test.db" + + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + + with get_connection() as conn: + entries = [ + { + "thread_id": "t1", + "question": "How do I run ICA in EEGLAB?", + "answer": "Use runica via the Tools menu.", + "tags": ["ica", "eeglab"], + "category": "how-to", + "quality_score": 0.95, + "first_message_date": "2020-01-01", + }, + { + "thread_id": "t2", + "question": "Why does my dataset fail to load?", + "answer": "Check the file path and channel locations.", + "tags": ["loading"], + "category": "troubleshooting", + "quality_score": 0.80, + "first_message_date": "2021-06-15", + }, + { + "thread_id": "t3", + "question": "What is a reference electrode?", + "answer": "Contact support@brainproducts.com for hardware details.", + "tags": ["reference"], + "category": "reference", + "quality_score": 0.60, + "first_message_date": "2019-03-20", + }, + ] + for e in entries: + upsert_faq_entry( + conn, + list_name="eeglablist", + thread_id=e["thread_id"], + thread_url=f"https://example.org/{e['thread_id']}", + question=e["question"], + answer=e["answer"], + tags=e["tags"], + category=e["category"], + message_count=3, + participant_count=2, + first_message_date=e["first_message_date"], + quality_score=e["quality_score"], + summary_model="test-model", + ) + conn.commit() + + yield db_path + + +class TestListFAQEntries: + """Tests for list_faq_entries (browse mode, no FTS query).""" + + def test_returns_all_entries_and_total(self, faq_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + entries, total = list_faq_entries(project="eeglab") + + assert 
total == 3 + assert len(entries) == 3 + assert all(isinstance(e, FAQResult) for e in entries) + + def test_ordered_by_quality_descending(self, faq_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + entries, _ = list_faq_entries(project="eeglab") + + scores = [e.quality_score for e in entries] + assert scores == sorted(scores, reverse=True) + assert entries[0].quality_score == 0.95 + + def test_min_quality_filter(self, faq_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + entries, total = list_faq_entries(project="eeglab", min_quality=0.85) + + assert total == 1 + assert len(entries) == 1 + assert entries[0].quality_score >= 0.85 + + def test_category_filter(self, faq_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + entries, total = list_faq_entries(project="eeglab", category="troubleshooting") + + assert total == 1 + assert entries[0].category == "troubleshooting" + + def test_pagination_limit_and_offset(self, faq_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + page1, total1 = list_faq_entries(project="eeglab", limit=2, offset=0) + page2, total2 = list_faq_entries(project="eeglab", limit=2, offset=2) + + # total is the full count regardless of pagination window + assert total1 == 3 + assert total2 == 3 + assert len(page1) == 2 + assert len(page2) == 1 + # No overlap between pages + page1_questions = {e.question for e in page1} + page2_questions = {e.question for e in page2} + assert page1_questions.isdisjoint(page2_questions) + + def test_empty_database_returns_zero(self, tmp_path: Path): + db_path = tmp_path / "knowledge" / "empty.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + entries, total = list_faq_entries(project="eeglab") + + assert total == 0 + assert entries == [] + + def test_list_name_filter(self, tmp_path: Path): + """list_name filter restricts results to a single mailing list.""" + db_path = 
tmp_path / "knowledge" / "lists.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + with get_connection() as conn: + for list_name, thread_id in [ + ("list-a", "a1"), + ("list-a", "a2"), + ("list-b", "b1"), + ]: + upsert_faq_entry( + conn, + list_name=list_name, + thread_id=thread_id, + thread_url=f"https://example.org/{thread_id}", + question=f"Question {thread_id}?", + answer="An answer.", + tags=["t"], + category="how-to", + message_count=2, + participant_count=2, + first_message_date="2020-01-01", + quality_score=0.8, + summary_model="test-model", + ) + conn.commit() + + entries, total = list_faq_entries(project="eeglab", list_name="list-a") + + assert total == 2 + assert len(entries) == 2 + assert {e.question for e in entries} == {"Question a1?", "Question a2?"} + + +class TestListFAQEntriesSearch: + """Search mode of list_faq_entries (query set, via FTS5).""" + + def test_query_matches_entries(self, faq_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + entries, total = list_faq_entries(project="eeglab", query="ICA") + + assert total >= 1 + assert any("ICA" in e.question for e in entries) + + def test_query_no_match_returns_empty(self, faq_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=faq_db): + entries, total = list_faq_entries(project="eeglab", query="zzzznomatchterm") + + assert total == 0 + assert entries == [] + + def test_query_total_is_full_count_not_page_size(self, tmp_path: Path): + """total reflects all FTS matches, independent of the page limit.""" + db_path = tmp_path / "knowledge" / "search.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + with get_connection() as conn: + for i in range(3): + upsert_faq_entry( + conn, + list_name="eeglablist", + thread_id=f"c{i}", + thread_url=f"https://example.org/c{i}", + question=f"How do I handle channels in case {i}?", + answer="Inspect the channel locations.", + tags=["channels"], + 
category="how-to", + message_count=2, + participant_count=2, + first_message_date="2020-01-01", + quality_score=0.8, + summary_model="test-model", + ) + conn.commit() + + page, total = list_faq_entries(project="eeglab", query="channels", limit=1) + + assert len(page) == 1 + assert total == 3 + assert total > len(page) diff --git a/tests/test_knowledge/test_openalex_citations.py b/tests/test_knowledge/test_openalex_citations.py new file mode 100644 index 0000000..dd2f921 --- /dev/null +++ b/tests/test_knowledge/test_openalex_citations.py @@ -0,0 +1,282 @@ +"""Tests for the direct OpenAlex citation client. + +Uses httpx.MockTransport to serve canned OpenAlex responses at the transport +layer (an HTTP fixture, not a mock of business logic) so the client's parsing, +pagination, and error handling are exercised without network access. +""" + +import httpx +import pytest + +from src.knowledge.openalex_citations import ( + CitingPaper, + OpenAlexCitationClient, + _strip_doi, + _strip_id, +) + + +class TestCitesFilter: + def test_single_work_id(self): + assert OpenAlexCitationClient._cites_filter("W1") == "cites:W1" + + def test_multiple_work_ids_or_joined(self): + assert OpenAlexCitationClient._cites_filter(["W1", "W2", "W3"]) == "cites:W1|W2|W3" + + def test_filters_empty_ids(self): + assert OpenAlexCitationClient._cites_filter(["W1", "", "W2"]) == "cites:W1|W2" + + def test_empty_raises(self): + with pytest.raises(ValueError, match="at least one"): + OpenAlexCitationClient._cites_filter([]) + + +def _client(handler) -> OpenAlexCitationClient: + transport = httpx.MockTransport(handler) + return OpenAlexCitationClient(email="t@example.org", client=httpx.Client(transport=transport)) + + +class TestHelpers: + def test_strip_id(self): + assert _strip_id("https://openalex.org/W123") == "W123" + assert _strip_id("W123") == "W123" + assert _strip_id(None) == "" + + def test_strip_doi(self): + assert _strip_doi("https://doi.org/10.1/x") == "10.1/x" + assert _strip_doi("10.1/x") 
== "10.1/x" + assert _strip_doi(None) is None + + +class TestResolveWorkId: + def test_resolves_doi_to_work_id(self): + def handler(request: httpx.Request) -> httpx.Response: + assert "/works/doi:10.1/x" in str(request.url) + return httpx.Response(200, json={"id": "https://openalex.org/W999"}) + + with _client(handler) as c: + assert c.resolve_work_id("10.1/x") == "W999" + + def test_unresolved_doi_returns_none(self): + def handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response(404, json={"error": "not found"}) + + with _client(handler) as c: + assert c.resolve_work_id("10.1/missing") is None + + def test_includes_mailto_param(self): + seen = {} + + def handler(request: httpx.Request) -> httpx.Response: + seen["mailto"] = request.url.params.get("mailto") + return httpx.Response(200, json={"id": "https://openalex.org/W1"}) + + with _client(handler) as c: + c.resolve_work_id("10.1/x") + assert seen["mailto"] == "t@example.org" + + +class TestCountsByYear: + def test_parses_group_by_counts(self): + def handler(request: httpx.Request) -> httpx.Response: + assert request.url.params.get("group_by") == "publication_year" + assert request.url.params.get("filter") == "cites:W1" + return httpx.Response( + 200, + json={ + "meta": {"count": 17}, + "group_by": [ + {"key": "2024", "count": 10}, + {"key": "2023", "count": 5}, + {"key": "2022", "count": 2}, + ], + }, + ) + + with _client(handler) as c: + counts = c.counts_by_year("W1") + assert counts == {2024: 10, 2023: 5, 2022: 2} + + def test_version_group_uses_or_joined_filter(self): + seen = {} + + def handler(request: httpx.Request) -> httpx.Response: + seen["filter"] = request.url.params.get("filter") + return httpx.Response(200, json={"group_by": [{"key": "2024", "count": 5}]}) + + with _client(handler) as c: + counts = c.counts_by_year(["W1", "W2"]) + assert seen["filter"] == "cites:W1|W2" + assert counts == {2024: 5} + + def test_skips_non_year_buckets(self): + def handler(_request: httpx.Request) -> 
httpx.Response: + return httpx.Response( + 200, + json={ + "group_by": [ + {"key": "2024", "count": 3}, + {"key": "unknown", "count": 9}, + {"key": None, "count": 1}, + ] + }, + ) + + with _client(handler) as c: + counts = c.counts_by_year("W1") + assert counts == {2024: 3} + + +class TestRecentCitingPapers: + def test_paginates_with_cursor(self): + # Two pages: cursor "*" -> two works + next_cursor "p2"; "p2" -> one work, end. + def handler(request: httpx.Request) -> httpx.Response: + cursor = request.url.params.get("cursor") + assert request.url.params.get("sort") == "publication_date:desc" + if cursor == "*": + return httpx.Response( + 200, + json={ + "meta": {"next_cursor": "p2"}, + "results": [ + { + "id": "https://openalex.org/W10", + "doi": "https://doi.org/10.1/a", + "title": "Newest", + "publication_date": "2026-01-01", + }, + { + "id": "https://openalex.org/W11", + "doi": None, + "title": "Second", + "publication_date": "2025-06-01", + }, + ], + }, + ) + return httpx.Response( + 200, + json={ + "meta": {"next_cursor": None}, + "results": [ + { + "id": "https://openalex.org/W12", + "doi": "10.1/c", + "title": "Third", + "publication_date": "2025-01-01", + } + ], + }, + ) + + with _client(handler) as c: + papers = c.recent_citing_papers("W1", limit=100) + + assert [p.openalex_id for p in papers] == ["W10", "W11", "W12"] + assert all(isinstance(p, CitingPaper) for p in papers) + assert papers[0].doi == "10.1/a" # url-form DOI normalized + assert papers[1].doi is None + assert papers[0].url == "https://doi.org/10.1/a" + + def test_respects_limit_across_pages(self): + def handler(request: httpx.Request) -> httpx.Response: + # Always offer a next cursor; the client must stop at the limit. 
+ return httpx.Response( + 200, + json={ + "meta": {"next_cursor": "more"}, + "results": [ + { + "id": f"https://openalex.org/W{request.url.params.get('cursor')}", + "doi": None, + "title": "P", + "publication_date": "2025-01-01", + } + ], + }, + ) + + with _client(handler) as c: + papers = c.recent_citing_papers("W1", limit=3) + assert len(papers) == 3 + + def test_stops_on_empty_results_page(self): + # A non-null cursor with no results must not spin forever. + calls = {"n": 0} + + def handler(request: httpx.Request) -> httpx.Response: + calls["n"] += 1 + if request.url.params.get("cursor") == "*": + return httpx.Response( + 200, + json={ + "meta": {"next_cursor": "p2"}, + "results": [ + { + "id": "https://openalex.org/W1", + "doi": None, + "title": "P", + "publication_date": "2025-01-01", + } + ], + }, + ) + # Second page: cursor still present but no results -> must stop. + return httpx.Response(200, json={"meta": {"next_cursor": "p3"}, "results": []}) + + with _client(handler) as c: + papers = c.recent_citing_papers("W1", limit=100) + assert len(papers) == 1 + assert calls["n"] == 2 # stopped at the empty page, did not continue + + def test_absent_meta_stops_pagination(self): + def handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={ + "results": [ + { + "id": "https://openalex.org/W1", + "doi": "10.1/x", + "title": "P", + "publication_date": "2025-01-01", + } + ] + }, + ) + + with _client(handler) as c: + papers = c.recent_citing_papers("W1", limit=100) + assert len(papers) == 1 + assert papers[0].url == "https://doi.org/10.1/x" # url built from stripped doi + + def test_skips_titleless_works(self): + def handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={ + "meta": {"next_cursor": None}, + "results": [ + {"id": "https://openalex.org/W1", "title": None, "doi": None}, + { + "id": "https://openalex.org/W2", + "title": "Has title", + "doi": None, + "publication_date": "2025-01-01", + 
}, + ], + }, + ) + + with _client(handler) as c: + papers = c.recent_citing_papers("W1", limit=10) + assert [p.openalex_id for p in papers] == ["W2"] + + +class TestErrorPropagation: + def test_http_error_raises(self): + def handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response(500, json={"error": "server"}) + + with _client(handler) as c, pytest.raises(httpx.HTTPStatusError): + c.counts_by_year("W1") diff --git a/tests/test_knowledge/test_papers_sync.py b/tests/test_knowledge/test_papers_sync.py index b23740c..ad6ce47 100644 --- a/tests/test_knowledge/test_papers_sync.py +++ b/tests/test_knowledge/test_papers_sync.py @@ -9,11 +9,13 @@ from pathlib import Path from unittest.mock import patch +import httpx import pytest from opencite import IDSet, Paper import src.knowledge.papers_sync as ps -from src.knowledge.db import get_connection, init_db +from src.knowledge.db import get_connection, init_db, replace_citation_counts +from src.knowledge.openalex_citations import OpenAlexCitationClient from src.knowledge.papers_sync import ( _cache_papers_async, _paper_source_and_id, @@ -26,6 +28,7 @@ sync_citing_papers, sync_openalex_papers, ) +from src.knowledge.search import get_citation_stats @pytest.fixture @@ -165,6 +168,21 @@ def test_upsert_deduplicates_same_paper(self, temp_db: Path): count = conn.execute("SELECT COUNT(*) AS c FROM papers").fetchone()["c"] assert count == 1 + def test_stores_cites_doi_on_each_row(self, temp_db: Path): + # A citation sync threads the canonical DOI through to each stored row. 
+ papers = [ + Paper(title="Citing A", ids=IDSet(openalex_id="https://openalex.org/W1"), year=2023), + Paper(title="Citing B", ids=IDSet(openalex_id="https://openalex.org/W2"), year=2024), + ] + with patch("src.knowledge.db.get_db_path", return_value=temp_db): + _store_papers(papers, "test", cites_doi="10.1/canonical") + with get_connection("test") as conn: + links = { + r["external_id"]: r["cites_doi"] + for r in conn.execute("SELECT external_id, cites_doi FROM papers") + } + assert links == {"W1": "10.1/canonical", "W2": "10.1/canonical"} + def test_force_source_uses_native_id(self, temp_db: Path): # A PubMed-restricted sync should label the row 'pubmed' using the PMID, # even though the paper also carries an OpenAlex id. @@ -314,3 +332,151 @@ def test_sync_all_papers_rejects_bare_string(self) -> None: def test_sync_citing_papers_rejects_bare_string(self) -> None: with pytest.raises(TypeError, match="must be a list of strings"): sync_citing_papers(dois="10.3389/fnins.2013.00267") # type: ignore[arg-type] + + +class TestSyncCitingPapers: + """End-to-end sync via a mock OpenAlex transport (real client + real DB).""" + + def _handler(self, request: httpx.Request) -> httpx.Response: + url = str(request.url) + if "/works/doi:" in url: + return httpx.Response(200, json={"id": "https://openalex.org/W1"}) + if request.url.params.get("group_by") == "publication_year": + return httpx.Response( + 200, + json={"group_by": [{"key": "2024", "count": 3}, {"key": "2025", "count": 7}]}, + ) + # recent citing papers page (single page) + return httpx.Response( + 200, + json={ + "meta": {"next_cursor": None}, + "results": [ + { + "id": "https://openalex.org/W2", + "doi": "10.1/citing-a", + "title": "Citing paper A", + "publication_date": "2025-03-01", + }, + { + "id": "https://openalex.org/W3", + "doi": None, + "title": "Citing paper B", + "publication_date": "2024-09-01", + }, + ], + }, + ) + + def test_stores_true_counts_and_recent_papers(self, tmp_path: Path, monkeypatch) -> 
None: + def factory(**_kwargs): + transport = httpx.MockTransport(self._handler) + return OpenAlexCitationClient(client=httpx.Client(transport=transport)) + + monkeypatch.setattr(ps, "OpenAlexCitationClient", factory) + + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db("test") + stored = sync_citing_papers(["10.1/canon"], project="test") + stats = get_citation_stats("test") + with get_connection("test") as conn: + rows = conn.execute( + "SELECT external_id, cites_doi FROM papers WHERE cites_doi IS NOT NULL" + ).fetchall() + + # Counts come from the (uncapped) group_by histogram, not the stored rows. + assert stats.by_paper == {"10.1/canon": {"2024": 3, "2025": 7}} + assert stats.total == 10 + # Two recent citing papers stored and linked to the canonical DOI. + assert stored == 2 + assert {r["external_id"] for r in rows} == {"W2", "W3"} + assert all(r["cites_doi"] == "10.1/canon" for r in rows) + + def test_unresolved_doi_skipped(self, tmp_path: Path, monkeypatch) -> None: + def handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response(404, json={"error": "not found"}) + + def factory(**_kwargs): + return OpenAlexCitationClient( + client=httpx.Client(transport=httpx.MockTransport(handler)) + ) + + monkeypatch.setattr(ps, "OpenAlexCitationClient", factory) + + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db("test") + stored = sync_citing_papers(["10.1/missing"], project="test") + stats = get_citation_stats("test") + + assert stored == 0 + assert stats.total == 0 + + def test_version_aliases_merge_into_primary(self, tmp_path: Path, monkeypatch) -> None: + # Primary + preprint resolve to W1/W2; counts are queried as a group and + # attributed to the primary DOI. 
+ seen = {} + + def handler(request: httpx.Request) -> httpx.Response: + url = str(request.url) + if "/works/doi:10.1/primary" in url: + return httpx.Response(200, json={"id": "https://openalex.org/W1"}) + if "/works/doi:10.1/preprint" in url: + return httpx.Response(200, json={"id": "https://openalex.org/W2"}) + if request.url.params.get("group_by"): + seen["filter"] = request.url.params.get("filter") + return httpx.Response(200, json={"group_by": [{"key": "2024", "count": 12}]}) + return httpx.Response(200, json={"meta": {"next_cursor": None}, "results": []}) + + def factory(**_kwargs): + return OpenAlexCitationClient( + client=httpx.Client(transport=httpx.MockTransport(handler)) + ) + + monkeypatch.setattr(ps, "OpenAlexCitationClient", factory) + + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db("test") + sync_citing_papers( + ["10.1/primary"], + project="test", + aliases={"10.1/primary": ["10.1/preprint"]}, + ) + stats = get_citation_stats("test") + + # Both work ids were OR-joined into one cites filter... + assert seen["filter"] == "cites:W1|W2" + # ...and the merged count is attributed to the primary DOI. + assert stats.by_paper == {"10.1/primary": {"2024": 12}} + + def test_empty_counts_does_not_wipe_existing(self, tmp_path: Path, monkeypatch) -> None: + # An empty histogram (likely a transient API gap) must not erase the + # previously stored counts for that canonical DOI. 
+ def handler(request: httpx.Request) -> httpx.Response: + if "/works/doi:" in str(request.url): + return httpx.Response(200, json={"id": "https://openalex.org/W1"}) + if request.url.params.get("group_by"): + return httpx.Response(200, json={"group_by": []}) # transient gap + return httpx.Response(200, json={"meta": {"next_cursor": None}, "results": []}) + + def factory(**_kwargs): + return OpenAlexCitationClient( + client=httpx.Client(transport=httpx.MockTransport(handler)) + ) + + monkeypatch.setattr(ps, "OpenAlexCitationClient", factory) + + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db("test") + # Seed good counts as if a prior healthy sync ran. + replace_citation_counts("10.1/canon", {2024: 50, 2025: 80}, project="test") + + stored = sync_citing_papers(["10.1/canon"], project="test") + stats = get_citation_stats("test") + + assert stored == 0 + # Existing histogram is preserved, not wiped to empty. + assert stats.by_paper == {"10.1/canon": {"2024": 50, "2025": 80}}