diff --git a/Makefile b/Makefile index 1bd511bf..dcb75764 100644 --- a/Makefile +++ b/Makefile @@ -143,10 +143,10 @@ metaculus-update-questions: infer: infer-fetch infer-update-questions infer-fetch: - $(MAKE) -C src/questions/infer/fetch || echo "* $@" >> $(MAKE_FAILURE_LOG) + $(MAKE) -C src/orchestration/func_infer_fetch || echo "* $@" >> $(MAKE_FAILURE_LOG) infer-update-questions: - $(MAKE) -C src/questions/infer/update_questions || echo "* $@" >> $(MAKE_FAILURE_LOG) + $(MAKE) -C src/orchestration/func_infer_update || echo "* $@" >> $(MAKE_FAILURE_LOG) acled: acled-fetch acled-update-questions diff --git a/src/_fb_types.py b/src/_fb_types.py index 4e91f36d..99586587 100644 --- a/src/_fb_types.py +++ b/src/_fb_types.py @@ -45,3 +45,26 @@ class SourceQuestionBank: QuestionBank = dict[str, SourceQuestionBank] + + +@dataclass +class UpdateResult: + """Return value of a source's update() method. + + Validates contents on construction: dfq must be a valid QuestionFrame, + each resolution file must be a valid ResolutionFrame. + """ + + dfq: pd.DataFrame + resolution_files: dict[str, pd.DataFrame] | None = None + hash_mapping: dict[str, dict] | None = None + + def __post_init__(self): + """Validate schema constraints.""" + from _schemas import QuestionFrame, ResolutionFrame + + self.dfq = QuestionFrame.validate(self.dfq) + if self.resolution_files: + self.resolution_files = { + qid: ResolutionFrame.validate(df) for qid, df in self.resolution_files.items() + } diff --git a/src/_schemas.py b/src/_schemas.py index 866ec435..3c670723 100644 --- a/src/_schemas.py +++ b/src/_schemas.py @@ -84,6 +84,14 @@ class ResolveReadyFrame(ExplodedQuestionSetFrame): market_value_on_due_date: Series[float] = pa.Field(nullable=True) +class InferFetchFrame(QuestionFrame): + """Output of InferSource.fetch(). 
QuestionFrame plus transient fields for update().""" + + fetch_datetime: Series[str] + probability: Series[object] = pa.Field(nullable=True) + nullify_question: Series[bool] + + class AcledResolutionFrame(pa.DataFrameModel): """ACLED-specific: aggregated events by country and date. diff --git a/src/helpers/acled.py b/src/helpers/acled.py index 7ad73624..c0ba7905 100644 --- a/src/helpers/acled.py +++ b/src/helpers/acled.py @@ -6,8 +6,13 @@ import numpy as np import pandas as pd +from sources._metadata import SOURCE_METADATA + from . import data_utils +SOURCE_INTRO = SOURCE_METADATA["acled"]["source_intro"] +RESOLUTION_CRITERIA = SOURCE_METADATA["acled"]["resolution_criteria"] + source = "acled" # Lazy import to avoid circular imports at module level @@ -17,9 +22,9 @@ def _get_source(): global _source if _source is None: - from sources import SOURCES + from sources.acled import AcledSource - _source = SOURCES[source] + _source = AcledSource() return _source @@ -84,17 +89,6 @@ def upload_hash_mapping(): https://acleddata.com/knowledge-base/codebook/#acled-events """ -SOURCE_INTRO = ( - "The Armed Conflict Location & Event Data Project (ACLED) collects real-time data on the " - "locations, dates, actors, fatalities, and types of all reported political violence and " - "protest events around the world. You're going to predict how questions based on this data " - "will resolve." -) - -RESOLUTION_CRITERIA = ( - "Resolves to the value calculated from the ACLED dataset once the data is published." 
-) - def read_dff(local_question_bank_dir=None) -> pd.DataFrame: """ diff --git a/src/helpers/dbnomics.py b/src/helpers/dbnomics.py index da3cfa9b..8403c065 100644 --- a/src/helpers/dbnomics.py +++ b/src/helpers/dbnomics.py @@ -1,5 +1,10 @@ """DBnomics-specific variables.""" +from sources._metadata import SOURCE_METADATA + +SOURCE_INTRO = SOURCE_METADATA["dbnomics"]["source_intro"] +RESOLUTION_CRITERIA = SOURCE_METADATA["dbnomics"]["resolution_criteria"] + FETCH_COLUMN_DTYPE = { "id": str, "period": str, @@ -9,15 +14,6 @@ } FETCH_COLUMNS = list(FETCH_COLUMN_DTYPE.keys()) -SOURCE_INTRO = ( - "DBnomics collects data on topics such as population and living conditions, " - "environment and energy, agriculture, finance, trade and others from publicly available " - "resources, for example national and international statistical institutions, researchers and " - "private companies. You're going to predict how questions based on this data will resolve." -) - -RESOLUTION_CRITERIA = "Resolves to the value found at {url} once the data is published." - METEOFRANCE_STATIONS = [ {"id": "07005", "station": "Abbeville"}, {"id": "07015", "station": "Lille Airport"}, diff --git a/src/helpers/fred.py b/src/helpers/fred.py index e96fa698..5bf31a11 100644 --- a/src/helpers/fred.py +++ b/src/helpers/fred.py @@ -5,15 +5,12 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from sources.fred import NULLIFIED_IDS # noqa: F401, E402 +from sources._metadata import SOURCE_METADATA # noqa: E402 -SOURCE_INTRO = ( - "The Federal Reserve Economic Data database (FRED) provides economic data from national, " - "international, public, and private sources.You're going to predict how questions based on " - "this data will resolve." -) - -RESOLUTION_CRITERIA = "Resolves to the value found at {url} once the data is published." 
+_META = SOURCE_METADATA["fred"] +SOURCE_INTRO = _META["source_intro"] +RESOLUTION_CRITERIA = _META["resolution_criteria"] +NULLIFIED_IDS = [nq.id for nq in _META["nullified_questions"]] # flake8: noqa: B950 diff --git a/src/helpers/infer.py b/src/helpers/infer.py index 0853c2be..cc536d17 100644 --- a/src/helpers/infer.py +++ b/src/helpers/infer.py @@ -1,9 +1,6 @@ -"""Infer-specific variables.""" +"""Infer-specific variables. Delegates to sources._metadata.""" -SOURCE_INTRO = ( - "We would like you to predict the outcome of a prediction market. A prediction market, in this " - "context, is the aggregate of predictions submitted by users on the website INFER Public. " - "You're going to predict the probability that the market will resolve as 'Yes'." -) +from sources._metadata import SOURCE_METADATA -RESOLUTION_CRITERIA = "Resolves to the outcome of the question found at {url}." +SOURCE_INTRO = SOURCE_METADATA["infer"]["source_intro"] +RESOLUTION_CRITERIA = SOURCE_METADATA["infer"]["resolution_criteria"] diff --git a/src/helpers/manifold.py b/src/helpers/manifold.py index beafa081..d7921ec5 100644 --- a/src/helpers/manifold.py +++ b/src/helpers/manifold.py @@ -1,9 +1,6 @@ -"""Manifold-specific variables.""" +"""Manifold-specific variables. Delegates to sources._metadata.""" -SOURCE_INTRO = ( - "We would like you to predict the outcome of a prediction market. A prediction market, in this " - "context, is the aggregate of predictions submitted by users on the website Manifold. " - "You're going to predict the probability that the market will resolve as 'Yes'." -) +from sources._metadata import SOURCE_METADATA -RESOLUTION_CRITERIA = "Resolves to the outcome of the question found at {url}." 
+SOURCE_INTRO = SOURCE_METADATA["manifold"]["source_intro"] +RESOLUTION_CRITERIA = SOURCE_METADATA["manifold"]["resolution_criteria"] diff --git a/src/helpers/metaculus.py b/src/helpers/metaculus.py index b94e50cf..2037f1a0 100644 --- a/src/helpers/metaculus.py +++ b/src/helpers/metaculus.py @@ -1,5 +1,10 @@ """Metaculus-specific variables.""" +from sources._metadata import SOURCE_METADATA + +SOURCE_INTRO = SOURCE_METADATA["metaculus"]["source_intro"] +RESOLUTION_CRITERIA = SOURCE_METADATA["metaculus"]["resolution_criteria"] + CATEGORIES = [ "artificial-intelligence", "computing-and-math", @@ -16,12 +21,3 @@ "sports-entertainment", "technology", ] - -SOURCE_INTRO = ( - "We would like you to predict the outcome of a prediction market. A prediction market, in this " - "context, is the aggregate of predictions submitted by users on the website Metaculus. " - "You're going to predict the probability that the market will resolve as 'Yes'." -) - - -RESOLUTION_CRITERIA = "Resolves to the outcome of the question found at {url}." diff --git a/src/helpers/polymarket.py b/src/helpers/polymarket.py index 468e5bcd..f35aea64 100644 --- a/src/helpers/polymarket.py +++ b/src/helpers/polymarket.py @@ -5,12 +5,9 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from sources.polymarket import NULLIFIED_QUESTION_IDS # noqa: F401, E402 +from sources._metadata import SOURCE_METADATA # noqa: E402 -SOURCE_INTRO = ( - "We would like you to predict the outcome of a prediction market. A prediction market, in this " - "context, is the aggregate of predictions submitted by users on the website Polymarket. " - "You're going to predict the probability that the market will resolve as 'Yes'." -) - -RESOLUTION_CRITERIA = "Resolves to the outcome of the question found at {url}." 
+_META = SOURCE_METADATA["polymarket"] +SOURCE_INTRO = _META["source_intro"] +RESOLUTION_CRITERIA = _META["resolution_criteria"] +NULLIFIED_QUESTION_IDS = {nq.id for nq in _META["nullified_questions"]} diff --git a/src/helpers/wikipedia.py b/src/helpers/wikipedia.py index 3a6f2da7..234cf103 100644 --- a/src/helpers/wikipedia.py +++ b/src/helpers/wikipedia.py @@ -12,6 +12,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from sources._metadata import SOURCE_METADATA # noqa: E402 from sources.wikipedia import _IDS_TO_NULLIFY as IDS_TO_NULLIFY # noqa: F401, E402 from sources.wikipedia import ( # noqa: F401, E402 _TRANSFORM_ID_MAPPING as transform_id_mapping, @@ -20,6 +21,9 @@ from . import constants # noqa: E402 +SOURCE_INTRO = SOURCE_METADATA["wikipedia"]["source_intro"] +RESOLUTION_CRITERIA = SOURCE_METADATA["wikipedia"]["resolution_criteria"] + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -43,20 +47,12 @@ def _get_source(): global _source if _source is None: - from sources import SOURCES + from sources.wikipedia import WikipediaSource - _source = SOURCES[source] + _source = WikipediaSource() return _source -SOURCE_INTRO = ( - "Wikipedia is an online encyclopedia created and edited by volunteers. You're going to predict " - "how questions based on data sourced from Wikipedia will resolve." -) - -RESOLUTION_CRITERIA = "Resolves to the value calculated from {url} on the resolution date." - - def transform_id(wid): """Transform old id to new id.""" return _get_source()._transform_id(wid) diff --git a/src/helpers/yfinance.py b/src/helpers/yfinance.py index 1d159f44..a0516027 100644 --- a/src/helpers/yfinance.py +++ b/src/helpers/yfinance.py @@ -1,13 +1,6 @@ -"""Yfinance-specific variables.""" +"""Yfinance-specific variables. 
Delegates to sources._metadata.""" -SOURCE_INTRO = ( - "Yahoo Finance provides financial data on stocks, bonds, and currencies and also offers news, " - "commentary and tools for personal financial management. You're going to predict how questions " - "based on this data will resolve." -) +from sources._metadata import SOURCE_METADATA -RESOLUTION_CRITERIA = ( - "Resolves to the market close price at {url} for the resolution date. If the resolution date " - "coincides with a day the market is closed (weekend, holiday, etc.) the previous market close " - "price is used." -) +SOURCE_INTRO = SOURCE_METADATA["yfinance"]["source_intro"] +RESOLUTION_CRITERIA = SOURCE_METADATA["yfinance"]["resolution_criteria"] diff --git a/src/orchestration/__init__.py b/src/orchestration/__init__.py index 5a5b6a64..d6108c51 100644 --- a/src/orchestration/__init__.py +++ b/src/orchestration/__init__.py @@ -1 +1 @@ -"""Orchestration layer for resolve pipeline.""" +"""Orchestration layer.""" diff --git a/src/orchestration/_source_io.py b/src/orchestration/_source_io.py new file mode 100644 index 00000000..10a7d436 --- /dev/null +++ b/src/orchestration/_source_io.py @@ -0,0 +1,108 @@ +"""Shared IO helpers for source fetch/update orchestration.""" + +from __future__ import annotations + +import json +import logging +import os +from typing import Iterable + +import pandas as pd + +from helpers import constants, data_utils, env +from utils import gcp + +logger = logging.getLogger(__name__) + + +def write_fetch_output(source: str, dff: pd.DataFrame) -> None: + """Write fetch DataFrame to _fetch.jsonl and upload. + + Args: + source (str): Source name (e.g. "infer"). + dff (pd.DataFrame): Fetched data to write. 
+ """ + filenames = data_utils.generate_filenames(source) + local = filenames["local_fetch"] + with open(local, "w", encoding="utf-8") as f: + for record in dff.to_dict(orient="records"): + f.write(json.dumps(record, ensure_ascii=False) + "\n") + logger.info(f"Uploading {filenames['jsonl_fetch']} to GCP...") + gcp.storage.upload( + bucket_name=env.QUESTION_BANK_BUCKET, + local_filename=local, + ) + + +def load_existing_resolution_files( + source: str, + ids: Iterable[str] | None = None, +) -> dict[str, pd.DataFrame]: + """Download /.jsonl resolution files. + + If ids is given, download only those. If ids is None, list the bucket and + download every .jsonl under / — use sparingly, scales with backlog. + + Args: + source (str): Source name (e.g. "infer"). + ids (Iterable[str] | None): Specific question IDs to load. If None, + load every resolution file present in the bucket for this source. + + Returns: + dict mapping question_id to its resolution DataFrame. + """ + if ids is None: + paths = gcp.storage.list_with_prefix( + bucket_name=env.QUESTION_BANK_BUCKET, prefix=f"{source}/" + ) + question_ids = [ + os.path.basename(p).removesuffix(".jsonl") for p in paths if p.endswith(".jsonl") + ] + else: + question_ids = [str(qid) for qid in ids] + + result: dict[str, pd.DataFrame] = {} + for question_id in question_ids: + basename = f"{question_id}.jsonl" + remote_path = f"{source}/{basename}" + local_filename = f"/tmp/{source}_{basename}" + + gcp.storage.download_no_error_message_on_404( + bucket_name=env.QUESTION_BANK_BUCKET, + filename=remote_path, + local_filename=local_filename, + ) + if os.path.exists(local_filename): + df = pd.read_json( + local_filename, + lines=True, + dtype=constants.RESOLUTION_FILE_COLUMN_DTYPE, + convert_dates=False, + ) + if not df.empty: + result[question_id] = df + logger.info(f"Loaded {len(result)} existing resolution files for {source}.") + return result + + +def upload_resolution_files(source: str, resolution_files: dict[str, 
pd.DataFrame]) -> None: + """Upload per-question resolution files to /.jsonl. + + Args: + source (str): Source name (e.g. "infer"). + resolution_files (dict): Mapping of question_id to resolution DataFrame. + """ + for question_id, df in resolution_files.items(): + basename = f"{question_id}.jsonl" + remote_filename = f"{source}/{basename}" + local_filename = f"/tmp/{basename}" + + df[["id", "date", "value"]].to_json( + local_filename, orient="records", lines=True, date_format="iso" + ) + gcp.storage.upload( + bucket_name=env.QUESTION_BANK_BUCKET, + local_filename=local_filename, + filename=remote_filename, + ) + logger.info(f"Uploaded {len(resolution_files)} resolution files for {source}.") diff --git a/src/questions/infer/fetch/Makefile b/src/orchestration/func_infer_fetch/Makefile similarity index 57% rename from src/questions/infer/fetch/Makefile rename to src/orchestration/func_infer_fetch/Makefile index 60f1c197..2542fa0e 100644 --- a/src/questions/infer/fetch/Makefile +++ b/src/orchestration/func_infer_fetch/Makefile @@ -9,20 +9,26 @@ UPLOAD_DIR = upload .gcloudignore: cp -r $(ROOT_DIR)src/helpers/.gcloudignore . -Procfile: - cp -r $(ROOT_DIR)src/helpers/Procfile . 
+Dockerfile: $(ROOT_DIR)src/helpers/Dockerfile.template + sed \ + -e 's/REGION/$(CLOUD_DEPLOY_REGION)/g' \ + -e 's/STACK/google-22-full/g' \ + -e 's/PYTHON_VERSION/python312/g' \ + $< > Dockerfile -deploy : main.py .gcloudignore requirements.txt Procfile +deploy : .gcloudignore requirements.txt Dockerfile mkdir -p $(UPLOAD_DIR) cp -r $(ROOT_DIR)utils $(UPLOAD_DIR)/ - cp -r $(ROOT_DIR)src/helpers $(UPLOAD_DIR)/ - cp -r $(ROOT_DIR)src/sources $(UPLOAD_DIR)/ - cp $(ROOT_DIR)src/_fb_types.py $(UPLOAD_DIR)/ - cp $(ROOT_DIR)src/_schemas.py $(UPLOAD_DIR)/ + cp -r $(ROOT_DIR)src/helpers $(UPLOAD_DIR)/helpers + cp -r $(ROOT_DIR)src/sources $(UPLOAD_DIR)/sources mkdir -p $(UPLOAD_DIR)/orchestration cp $(ROOT_DIR)src/orchestration/__init__.py $(UPLOAD_DIR)/orchestration/ - cp $(ROOT_DIR)src/orchestration/_io.py $(UPLOAD_DIR)/orchestration/ - cp $^ $(UPLOAD_DIR)/ + cp $(ROOT_DIR)src/orchestration/_source_io.py $(UPLOAD_DIR)/orchestration/ + cp $(ROOT_DIR)src/_fb_types.py $(UPLOAD_DIR)/ + cp $(ROOT_DIR)src/_schemas.py $(UPLOAD_DIR)/ + cp main.py $(UPLOAD_DIR)/main.py + cp requirements.txt $(UPLOAD_DIR)/requirements.txt + cp Dockerfile $(UPLOAD_DIR)/ gcloud run jobs deploy \ func-data-infer-fetch \ --project $(CLOUD_PROJECT) \ @@ -37,4 +43,4 @@ deploy : main.py .gcloudignore requirements.txt Procfile --source $(UPLOAD_DIR) clean : - rm -rf $(UPLOAD_DIR) .gcloudignore Procfile + rm -rf $(UPLOAD_DIR) .gcloudignore Dockerfile diff --git a/src/orchestration/func_infer_fetch/main.py b/src/orchestration/func_infer_fetch/main.py new file mode 100644 index 00000000..f39baf1e --- /dev/null +++ b/src/orchestration/func_infer_fetch/main.py @@ -0,0 +1,37 @@ +"""INFER fetch entry point.""" + +from __future__ import annotations + +import logging +from typing import Any + +from helpers import data_utils, decorator, env, keys +from orchestration import _source_io +from sources.infer import InferSource +from utils import gcp + +logging.basicConfig(level=logging.INFO) +logger = 
logging.getLogger(__name__) + +SOURCE = "infer" + + +@decorator.log_runtime +def driver(_: Any) -> None: + """Fetch INFER questions and upload to question bank.""" + source = InferSource() + source.api_key = keys.API_KEY_INFER + + dfq = data_utils.get_data_from_cloud_storage(SOURCE, return_question_data=True) + files_in_storage = gcp.storage.list_with_prefix( + bucket_name=env.QUESTION_BANK_BUCKET, prefix=SOURCE + ) + + dff = source.fetch(dfq=dfq, files_in_storage=files_in_storage) + + _source_io.write_fetch_output(SOURCE, dff) + logger.info("Done.") + + +if __name__ == "__main__": + driver(None) diff --git a/src/questions/infer/fetch/requirements.txt b/src/orchestration/func_infer_fetch/requirements.txt similarity index 94% rename from src/questions/infer/fetch/requirements.txt rename to src/orchestration/func_infer_fetch/requirements.txt index 026bdf64..37337e71 100644 --- a/src/questions/infer/fetch/requirements.txt +++ b/src/orchestration/func_infer_fetch/requirements.txt @@ -1,8 +1,9 @@ google-cloud-storage -requests -certifi google-cloud-secret-manager pandas>=2.2.2,<3.0 -backoff pandera termcolor +requests +certifi +backoff +numpy diff --git a/src/questions/infer/update_questions/Makefile b/src/orchestration/func_infer_update/Makefile similarity index 57% rename from src/questions/infer/update_questions/Makefile rename to src/orchestration/func_infer_update/Makefile index 2c9c83da..c03e6e91 100644 --- a/src/questions/infer/update_questions/Makefile +++ b/src/orchestration/func_infer_update/Makefile @@ -9,20 +9,26 @@ UPLOAD_DIR = upload .gcloudignore: cp -r $(ROOT_DIR)src/helpers/.gcloudignore . -Procfile: - cp -r $(ROOT_DIR)src/helpers/Procfile . 
+Dockerfile: $(ROOT_DIR)src/helpers/Dockerfile.template + sed \ + -e 's/REGION/$(CLOUD_DEPLOY_REGION)/g' \ + -e 's/STACK/google-22-full/g' \ + -e 's/PYTHON_VERSION/python312/g' \ + $< > Dockerfile -deploy : main.py .gcloudignore requirements.txt Procfile +deploy : .gcloudignore requirements.txt Dockerfile mkdir -p $(UPLOAD_DIR) cp -r $(ROOT_DIR)utils $(UPLOAD_DIR)/ - cp -r $(ROOT_DIR)src/helpers $(UPLOAD_DIR)/ - cp -r $(ROOT_DIR)src/sources $(UPLOAD_DIR)/ - cp $(ROOT_DIR)src/_fb_types.py $(UPLOAD_DIR)/ - cp $(ROOT_DIR)src/_schemas.py $(UPLOAD_DIR)/ + cp -r $(ROOT_DIR)src/helpers $(UPLOAD_DIR)/helpers + cp -r $(ROOT_DIR)src/sources $(UPLOAD_DIR)/sources mkdir -p $(UPLOAD_DIR)/orchestration cp $(ROOT_DIR)src/orchestration/__init__.py $(UPLOAD_DIR)/orchestration/ - cp $(ROOT_DIR)src/orchestration/_io.py $(UPLOAD_DIR)/orchestration/ - cp $^ $(UPLOAD_DIR)/ + cp $(ROOT_DIR)src/orchestration/_source_io.py $(UPLOAD_DIR)/orchestration/ + cp $(ROOT_DIR)src/_fb_types.py $(UPLOAD_DIR)/ + cp $(ROOT_DIR)src/_schemas.py $(UPLOAD_DIR)/ + cp main.py $(UPLOAD_DIR)/main.py + cp requirements.txt $(UPLOAD_DIR)/requirements.txt + cp Dockerfile $(UPLOAD_DIR)/ gcloud run jobs deploy \ func-data-infer-update-questions \ --project $(CLOUD_PROJECT) \ @@ -37,4 +43,4 @@ deploy : main.py .gcloudignore requirements.txt Procfile --source $(UPLOAD_DIR) clean : - rm -rf $(UPLOAD_DIR) .gcloudignore Procfile + rm -rf $(UPLOAD_DIR) .gcloudignore Dockerfile diff --git a/src/orchestration/func_infer_update/main.py b/src/orchestration/func_infer_update/main.py new file mode 100644 index 00000000..0189a413 --- /dev/null +++ b/src/orchestration/func_infer_update/main.py @@ -0,0 +1,41 @@ +"""INFER update entry point.""" + +from __future__ import annotations + +import logging +from typing import Any + +from helpers import data_utils, decorator, keys +from orchestration import _source_io +from sources.infer import InferSource + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + 
+SOURCE = "infer" + + +@decorator.log_runtime +def driver(_: Any) -> None: + """Update INFER questions and resolution files.""" + source = InferSource() + source.api_key = keys.API_KEY_INFER + + dfq, dff = data_utils.get_data_from_cloud_storage( + SOURCE, return_question_data=True, return_fetch_data=True + ) + existing_resolution_files = _source_io.load_existing_resolution_files( + SOURCE, ids=dff["id"].astype(str).tolist() + ) + + result = source.update(dfq, dff, existing_resolution_files=existing_resolution_files) + + logger.info("Uploading to GCP...") + data_utils.upload_questions(result.dfq, SOURCE) + if result.resolution_files: + _source_io.upload_resolution_files(SOURCE, result.resolution_files) + logger.info("Done.") + + +if __name__ == "__main__": + driver(None) diff --git a/src/questions/infer/update_questions/requirements.txt b/src/orchestration/func_infer_update/requirements.txt similarity index 73% rename from src/questions/infer/update_questions/requirements.txt rename to src/orchestration/func_infer_update/requirements.txt index 2fdeb8d2..37337e71 100644 --- a/src/questions/infer/update_questions/requirements.txt +++ b/src/orchestration/func_infer_update/requirements.txt @@ -3,3 +3,7 @@ google-cloud-secret-manager pandas>=2.2.2,<3.0 pandera termcolor +requests +certifi +backoff +numpy diff --git a/src/orchestration/func_resolve/main.py b/src/orchestration/func_resolve/main.py index ceaf7809..685abca0 100644 --- a/src/orchestration/func_resolve/main.py +++ b/src/orchestration/func_resolve/main.py @@ -17,7 +17,8 @@ from resolve._prepare import check_and_prepare_forecast_file, set_resolution_dates from resolve.explode_question_set import explode_question_set from resolve.resolve_all import resolve_all -from sources import DATASET_SOURCE_NAMES, MARKET_SOURCE_NAMES, SOURCES +from sources import DATASET_SOURCE_NAMES, MARKET_SOURCE_NAMES +from sources.registry import SOURCES logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) diff 
--git a/src/orchestration/func_resolve/requirements.txt b/src/orchestration/func_resolve/requirements.txt index 42a601dc..4c503990 100644 --- a/src/orchestration/func_resolve/requirements.txt +++ b/src/orchestration/func_resolve/requirements.txt @@ -10,3 +10,4 @@ slack_sdk pandera pytz python-dateutil +backoff \ No newline at end of file diff --git a/src/questions/infer/fetch/main.py b/src/questions/infer/fetch/main.py deleted file mode 100644 index 4a643942..00000000 --- a/src/questions/infer/fetch/main.py +++ /dev/null @@ -1,256 +0,0 @@ -"""INFER fetch new questions script.""" - -import json -import logging -import os -import sys - -import backoff -import certifi -import pandas as pd -import requests - -sys.path.append(os.path.join(os.path.dirname(__file__), "../../..")) -from helpers import data_utils, dates, decorator, env, keys # noqa: E402 - -sys.path.append(os.path.join(os.path.dirname(__file__), "../../../..")) -from utils import gcp # noqa: E402 - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -SOURCE = "infer" -INFER_URL = "https://www.randforecastinginitiative.org" - - -@backoff.on_exception( - backoff.expo, - requests.exceptions.RequestException, - max_time=300, - on_backoff=data_utils.print_error_info_handler, -) -def fetch_questions(potentially_closed_ids=None): - """ - Fetch all questions from a specified API endpoint. - - Iterates over pages of questions from the given base URL, authenticating - with the provided headers. Continues fetching until no more questions are - available. - - Parameters: - - potentially_closed_ids (dict): ids for questions that may or may not have been closed. - - Returns: - - list: A list of all questions fetched from the API. 
- """ - endpoint = INFER_URL + "/api/v1/questions" - headers = {"Authorization": f"Bearer {keys.API_KEY_INFER}"} - params = { - "page": 0, - "status": "active", - } - if potentially_closed_ids is not None: - params.update( - { - "status": "all", - "ids": ",".join(sorted(potentially_closed_ids)), - } - ) - - questions = [] - seen_ids = set() - while True: - response = requests.get(endpoint, params=params, headers=headers, verify=certifi.where()) - if not response.ok: - logger.error(f"Request to Infer questions endpoint failed with params: {params}") - response.raise_for_status() - - new_questions = response.json().get("questions", []) - if not new_questions: - break - - for q in new_questions: - if q["id"] not in seen_ids: - questions.append(q) - seen_ids.add(q["id"]) - - params["page"] += 1 - - return questions - - -def get_data(dfq): - """ - Fetch and prepare question data for processing. - - This function performs several key operations: - - Retrieves IDs of unresolved question and resolved questions without resolution files - from the current data. - - Fetches all active, unresolved, and binary questions using provided API endpoints. - - Filters out duplicated questions. - - Augments and restructures question data for further processing. - - Parameters: - - dfq (list of dict): A list of dictionaries containing data on questions, - where each dictionary represents a question and must have keys 'id' and 'resolved'. - - Returns: - - DataFrame: Each row representing a question ready for processing. - This includes a mix of newly fetched binary questions and existing unresolved questions, - with additional metadata and reformatted fields for consistency. 
- """ - resolved_ids = dfq[dfq["resolved"]]["id"].tolist() if not dfq.empty else [] - unresolved_ids = dfq[~dfq["resolved"]]["id"].tolist() if not dfq.empty else [] - logger.info(f"Number resolved_ids: {len(resolved_ids)}") - logger.info(f"Number unresolved_ids: {len(unresolved_ids)}") - - files_in_storage = gcp.storage.list_with_prefix( - bucket_name=env.QUESTION_BANK_BUCKET, prefix=SOURCE - ) - - resolved_ids_without_files_in_storage = [ - id for id in resolved_ids if f"{SOURCE}/{id}.jsonl" not in files_in_storage - ] - logger.info(f"resolved_ids_without_resolution_files: {resolved_ids_without_files_in_storage}") - - all_existing_ids_to_fetch = unresolved_ids + resolved_ids_without_files_in_storage - all_existing_questions = ( - fetch_questions(potentially_closed_ids=all_existing_ids_to_fetch) - if all_existing_ids_to_fetch - else [] - ) - - all_active_binary_questions = [ - q - for q in fetch_questions() - if q["state"] == "active" - and q["type"] == "Forecast::YesNoQuestion" - and q["answers"][0]["predictions_count"] > 0 - ] - - # Convert all_new_questions to a set of IDs for faster lookup - all_active_binary_question_ids = set(q["id"] for q in all_active_binary_questions) - - # Filter out questions from all_existing_questions if their IDs are in all_new_questions_ids - all_existing_questions = [ - q for q in all_existing_questions if q["id"] not in all_active_binary_question_ids - ] - - all_questions_to_add = all_active_binary_questions + all_existing_questions - - logger.info(f"Number of questions fetched: {len(all_questions_to_add)}") - current_time = dates.get_datetime_now() - questions_to_add = [] - for q in all_questions_to_add: - # There was a bug that pulled questions that we do not want to include in the question set. - # This field nullifies those questions and ensures the questions will not be resolved, even though - # they were included in the 2024-07-21 question set. 
- nullify_question = q["type"] != "Forecast::YesNoQuestion" - - # We use 'scoring_end_time_str' to ensure the closure time reflects when forecasts were - # actually scored. This is crucial because sometimes an administrator may resolve - # questions after the actual resolution is happened. - scoring_end_time_str = ( - dates.convert_datetime_str_to_iso_utc(q["scoring_end_time"]) - if q["scoring_end_time"] - else "N/A" - ) - ended_at_str = dates.convert_zulu_to_iso(q["ends_at"]) if q["ends_at"] else "N/A" - final_closed_at_str = ( - "N/A" - if scoring_end_time_str == "N/A" and ended_at_str == "N/A" - else ( - ended_at_str - if scoring_end_time_str == "N/A" - else ( - scoring_end_time_str - if ended_at_str == "N/A" - else min(scoring_end_time_str, ended_at_str) - ) - ) - ) - - scoring_start_time_str = ( - dates.convert_datetime_str_to_iso_utc(q["scoring_start_time"]) - if q["scoring_start_time"] - else "N/A" - ) - resolved_at_str = dates.convert_zulu_to_iso(q["resolved_at"]) if q["resolved_at"] else "N/A" - final_resolved_str = ( - "N/A" - if resolved_at_str == "N/A" and final_closed_at_str == "N/A" - else ( - final_closed_at_str - if resolved_at_str == "N/A" - else ( - resolved_at_str - if final_closed_at_str == "N/A" - else min(resolved_at_str, final_closed_at_str) - ) - ) - ) - - forecast_yes = "N/A" - if len(q["answers"]) == 2 and not nullify_question: - yes_index = 0 if q["answers"][0]["name"].lower() == "yes" else 1 - forecast_yes = q["answers"][yes_index]["probability"] - - questions_to_add.append( - { - "id": str(q["id"]), - "question": q["name"], - "background": q["description"], - "market_info_resolution_criteria": ( - " ".join([content["content"] for content in q["clarifications"]]) - if q["clarifications"] - else "N/A" - ), - "market_info_open_datetime": scoring_start_time_str, - "market_info_close_datetime": final_closed_at_str, - "url": f"{INFER_URL}/questions/{q['id']}", - "resolved": q.get("resolved?", False), - "market_info_resolution_datetime": ( - 
"N/A" if not q.get("resolved?", False) else final_resolved_str - ), - "fetch_datetime": current_time, - "probability": forecast_yes, - "forecast_horizons": "N/A", - "freeze_datetime_value": forecast_yes, - "freeze_datetime_value_explanation": "The crowd forecast.", - "nullify_question": nullify_question, - } - ) - - return pd.DataFrame(questions_to_add) - - -@decorator.log_runtime -def driver(_): - """Execute the main workflow of fetching, processing, and uploading questions.""" - # Download existing questions from cloud storage - dfq = data_utils.get_data_from_cloud_storage(SOURCE, return_question_data=True) - - filenames = data_utils.generate_filenames(SOURCE) - - # Get the latest data - all_questions_to_add = get_data(dfq) - - # Save and upload - with open(filenames["local_fetch"], "w", encoding="utf-8") as f: - # can't use `dfq.to_json` because we don't want escape chars - for record in all_questions_to_add.to_dict(orient="records"): - json_str = json.dumps(record, ensure_ascii=False) - f.write(json_str + "\n") - - # Upload - logger.info("Uploading to GCP...") - gcp.storage.upload( - bucket_name=env.QUESTION_BANK_BUCKET, - local_filename=filenames["local_fetch"], - ) - - logger.info("Done.") - - -if __name__ == "__main__": - driver(None) diff --git a/src/questions/infer/update_questions/main.py b/src/questions/infer/update_questions/main.py deleted file mode 100644 index ba46ddb8..00000000 --- a/src/questions/infer/update_questions/main.py +++ /dev/null @@ -1,296 +0,0 @@ -"""INFER update question script.""" - -import logging -import os -import sys -import time -from datetime import timedelta, timezone - -import certifi -import numpy as np -import pandas as pd -import requests - -sys.path.append(os.path.join(os.path.dirname(__file__), "../../..")) # noqa: E402 -from helpers import constants, data_utils, dates, decorator, env, keys # noqa: E402 - -sys.path.append(os.path.join(os.path.dirname(__file__), "../../../..")) -from utils import gcp # noqa: E402 - 
-logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) -endpoint = "https://www.randforecastinginitiative.org/api/v1/prediction_sets" -SOURCE = "infer" - - -def get_historical_forecasts(current_df, id): - """ - Fetch historical forecasts from a specified API endpoint and integrate them with a given DataFrame. - - This function retrieves all forecast records for a specific question identified by 'id' from the - API until it reaches a record before the latest date in 'current_df'. It processes these forecasts - and merges them with the existing DataFrame, handling cases where the DataFrame is initially empty. - - Parameters: - current_df (pd.DataFrame): The existing DataFrame containing previous forecast data. - id (int): The unique identifier for the forecast question. - - Returns: - pd.DataFrame: A DataFrame containing combined old and new forecast data, sorted. - """ - params = {"question_id": id, "page": 0} - headers = {"Authorization": f"Bearer {keys.API_KEY_INFER}"} - all_responses = [] - current_time = dates.get_datetime_today_midnight() - - # Check if 'current_df' is not empty and contains the 'datetime' column - last_date = ( - pd.to_datetime(current_df["date"].iloc[-1]).tz_localize("UTC") - if not current_df.empty and "date" in current_df.columns - else constants.BENCHMARK_START_DATE_DATETIME.replace(tzinfo=timezone.utc) - ) - - while True: - try: - logger.info(f"Fetched page: {params['page']}, for question ID: {id}") - response = requests.get( - endpoint, params=params, headers=headers, verify=certifi.where() - ) - response.raise_for_status() - new_responses = response.json().get("prediction_sets", []) - all_responses.extend(new_responses) - if ( - not new_responses - or pd.to_datetime(new_responses[-1]["created_at"], utc=True) <= last_date - ): - break - - params["page"] += 1 - except requests.exceptions.HTTPError as e: - if e.response.status_code != 429: - raise - logger.error("Rate limit reached, waiting 10s before retrying...") 
- time.sleep(10) - - all_forecasts = [] - for forecast in all_responses: - if current_df.empty or pd.to_datetime(forecast["created_at"], utc=True) > last_date: - if len(forecast["predictions"]) == 2: - forecast_yes = forecast["predictions"][0] - if forecast_yes["answer_name"] == "No": - forecast_yes = forecast["predictions"][1] - elif len(forecast["predictions"]) == 1: - forecast_yes = forecast["predictions"][0] - - all_forecasts.append( - ( - dates.convert_zulu_to_iso(forecast["created_at"]), - forecast_yes["final_probability"], - ) - ) - - df = pd.DataFrame(all_forecasts, columns=["date", "value"]) - df["date"] = pd.to_datetime(df["date"]) - - df = df[df["date"].dt.date < current_time.date()] - - df["value"] = df["value"].astype(float) - df["id"] = id - - # Sort by datetime first - df_sorted = df.sort_values("date") - - # Reset index after sorting - df_sorted.reset_index(drop=True, inplace=True) - - # Convert datetime to date only after sorting - df_sorted["date"] = df_sorted["date"].dt.date - df_final = df_sorted[["id", "date", "value"]] - - # Check if the existing dataframe is empty - if current_df.empty: - # Directly return if there's no existing data to merge - result_df = df_final.drop_duplicates(subset=["id", "date"], keep="last") - else: - # Process current dataframe similarly - current_df["date"] = pd.to_datetime(current_df["date"]).dt.date - current_df_final = current_df[["id", "date", "value"]] - # Concatenate and remove duplicates - result_df = ( - pd.concat([current_df_final, df_final], axis=0) - .sort_values(by=["date"], ascending=True) # Ensure sorting by date for consistency - .drop_duplicates(subset=["id", "date"], keep="last") - .reset_index(drop=True) - ) - - # fill in mising date with previous date's value - result_df.loc[:, "date"] = pd.to_datetime(result_df["date"]).dt.tz_localize("UTC") - result_df = result_df.infer_objects() - result_df = result_df.sort_values(by="date") - # Reindex to fill in missing dates - all_dates = pd.date_range( - 
start=result_df["date"].min(), end=current_time - timedelta(days=1), freq="D" - ) - result_df = result_df.set_index("date").reindex(all_dates, method="ffill").reset_index() - - result_df["id"] = id - result_df.reset_index(inplace=True) - result_df.rename(columns={"index": "date"}, inplace=True) - - return result_df[["id", "date", "value"]] - - -def create_resolution_file(question, resolved): - """ - Create or update a resolution file based on the question ID provided. Download the existing file, if any. - - Check the last entry date, and update with new data if there's no entry for today. Upload the updated file - back to the specified Google Cloud Platform bucket. - - Args: - - question (dict): A dictionary containing at least the 'id' of the question. - - Returns: - - DataFrame: Return the current state of the resolution file as a DataFrame if no update is needed. - If an update occurs, the function returns None after uploading the updated file. - """ - basename = f"{question['id']}.jsonl" - remote_filename = f"{SOURCE}/{basename}" - local_filename = "/tmp/tmp.jsonl" - yesterday = dates.get_datetime_today_midnight() - timedelta(days=1) - - def write_and_upload(df): - df["date"] = pd.to_datetime(df["date"]) - df = df[df["date"].dt.date >= constants.BENCHMARK_START_DATE_DATETIME_DATE] - df = df[["id", "date", "value"]].astype(dtype=constants.RESOLUTION_FILE_COLUMN_DTYPE) - df.to_json(local_filename, orient="records", lines=True, date_format="iso") - gcp.storage.upload( - bucket_name=env.QUESTION_BANK_BUCKET, - local_filename=local_filename, - filename=remote_filename, - ) - - if os.path.exists(local_filename): - os.remove(local_filename) - gcp.storage.download_no_error_message_on_404( - bucket_name=env.QUESTION_BANK_BUCKET, - filename=remote_filename, - local_filename=local_filename, - ) - - if os.path.exists(local_filename): - df = pd.read_json( - local_filename, - lines=True, - dtype=constants.RESOLUTION_FILE_COLUMN_DTYPE, - convert_dates=False, - ) - else: - df 
= pd.DataFrame( - { - col: pd.Series( - dtype=( - constants.RESOLUTION_FILE_COLUMN_DTYPE[col] - if col in constants.RESOLUTION_FILE_COLUMN_DTYPE - else "object" - ) - ) - for col in constants.RESOLUTION_FILE_COLUMNS - } - ) - - if question["nullify_question"]: - logger.warning( - f"Nullifying question {question['id']}. Pushing np.nan values to resolution file." - ) - if df.empty: - df = pd.DataFrame(columns=constants.RESOLUTION_FILE_COLUMNS) - df = df[["id", "date", "value"]].astype(dtype=constants.RESOLUTION_FILE_COLUMN_DTYPE) - df.loc[0] = [question["id"], yesterday.date(), np.nan] - else: - df["value"] = np.nan - - write_and_upload(df) - return df - - if not df.empty and pd.to_datetime(df["date"].iloc[-1]).tz_localize("UTC") >= yesterday: - logger.info(f"{question['id']} is skipped because it's already up-to-date!") - # Check last date to see if we've already gotten the resolution value for today - return df - - df = get_historical_forecasts(df, question["id"]) - df.date = df["date"].dt.date - - if resolved: - df = df[df.date < pd.to_datetime(question["market_info_resolution_datetime"][:10]).date()] - resolution_row = pd.DataFrame( - { - "id": [question["id"]], - "date": [question["market_info_resolution_datetime"][:10]], - "value": [question["probability"]], - } - ) - df = pd.concat([df, resolution_row], ignore_index=True) - - write_and_upload(df) - - -def update_questions(dfq, dff): - """ - Update the dataframes with new or modified question data and new community predictions. - - Parameters: - - dfq (pd.DataFrame): DataFrame containing existing questions. - - dff (pd.DataFrame): DataFrame containing newly fetched questions. - - Returns: - - dfq (pd.DataFrame): DataFrame containing updated questions. - - The function updates dfq by either replacing existing questions with new data or adding new questions. - It also appends new community predictions to dfr for each question in all_questions_to_add. 
- """ - for question in dff.to_dict("records"): - create_resolution_file(question, question["resolved"]) - - # Marke nullified questions as resolved so they're not selected for the question set. - if question["nullify_question"]: - question["resolved"] = True - - del question["fetch_datetime"] - del question["probability"] - del question["nullify_question"] - - # Check if the question exists in dfq - if question["id"] in dfq["id"].values: - # Case 1: Update existing question - dfq_index = dfq.index[dfq["id"] == question["id"]].tolist()[0] - for key, value in question.items(): - dfq.at[dfq_index, key] = value - else: - # Case 2: Append new question - new_q_row = pd.DataFrame([question]) - dfq = pd.concat([dfq, new_q_row], ignore_index=True) - - return dfq - - -@decorator.log_runtime -def driver(_): - """Execute the main workflow of fetching, processing, and uploading questions.""" - # Download existing questions from cloud storage - dfq, dff = data_utils.get_data_from_cloud_storage( - SOURCE, return_question_data=True, return_fetch_data=True - ) - - # Update the existing questions - dfq = update_questions(dfq, dff) - - logger.info("Uploading to GCP...") - - # Save and upload - data_utils.upload_questions(dfq, SOURCE) - logger.info("Done.") - - -if __name__ == "__main__": - driver(None) diff --git a/src/questions/yfinance/update_questions/main.py b/src/questions/yfinance/update_questions/main.py index 91edcd37..9d953eb6 100644 --- a/src/questions/yfinance/update_questions/main.py +++ b/src/questions/yfinance/update_questions/main.py @@ -12,13 +12,14 @@ from helpers import constants, data_utils, dates, decorator, env # noqa: E402 sys.path.append(os.path.join(os.path.dirname(__file__), "../../../..")) -from sources.yfinance import TICKER_RENAMES # noqa: E402 +from sources._metadata import SOURCE_METADATA # noqa: E402 from utils import gcp # noqa: E402 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) SOURCE = "yfinance" +TICKER_RENAMES = 
SOURCE_METADATA["yfinance"]["ticker_renames"] def select_time_range(days_difference): diff --git a/src/sources/__init__.py b/src/sources/__init__.py index c7c71004..3571cab3 100644 --- a/src/sources/__init__.py +++ b/src/sources/__init__.py @@ -1,50 +1,14 @@ -"""Source registry — one instance per question source.""" - -from _fb_types import SourceType - -from .acled import AcledSource -from .dbnomics import DbnomicsSource -from .fred import FredSource -from .infer import InferSource -from .manifold import ManifoldSource -from .metaculus import MetaculusSource -from .polymarket import PolymarketSource -from .wikipedia import WikipediaSource -from .yfinance import YfinanceSource - -# Singletons -_acled = AcledSource() -_dbnomics = DbnomicsSource() -_fred = FredSource() -_infer = InferSource() -_manifold = ManifoldSource() -_metaculus = MetaculusSource() -_polymarket = PolymarketSource() -_wikipedia = WikipediaSource() -_yfinance = YfinanceSource() - -SOURCES = { - s.name: s - for s in [ - _acled, - _dbnomics, - _fred, - _infer, - _manifold, - _metaculus, - _polymarket, - _wikipedia, - _yfinance, - ] -} - -DATASET_SOURCES = { - name: src for name, src in sorted(SOURCES.items()) if src.source_type == SourceType.DATASET -} -MARKET_SOURCES = { - name: src for name, src in sorted(SOURCES.items()) if src.source_type == SourceType.MARKET -} - -ALL_SOURCE_NAMES = sorted(SOURCES.keys()) -DATASET_SOURCE_NAMES = sorted(DATASET_SOURCES.keys()) -MARKET_SOURCE_NAMES = sorted(MARKET_SOURCES.keys()) +"""Source metadata re-exports. Lightweight — no concrete source imports. + +For source instances: +- import from ``sources.registry`` for all sources (heavyweight, use cautiously) +- import from the specific source module for a single source (lighter weight) + e.g. 
``from sources.infer import InferSource`` +""" + +from ._metadata import ( # noqa: F401 + ALL_SOURCE_NAMES, + DATASET_SOURCE_NAMES, + MARKET_SOURCE_NAMES, + SOURCE_METADATA, +) diff --git a/src/sources/_base.py b/src/sources/_base.py index e69523c4..7f1a8ee7 100644 --- a/src/sources/_base.py +++ b/src/sources/_base.py @@ -5,14 +5,16 @@ import logging from abc import ABC, abstractmethod from datetime import date -from typing import TYPE_CHECKING, ClassVar, Union +from typing import TYPE_CHECKING, Any, ClassVar, Union import numpy as np import pandas as pd -from _fb_types import NullifiedQuestion, SourceType +from _fb_types import NullifiedQuestion, SourceType, UpdateResult from _schemas import ResolutionFrame +from ._metadata import SOURCE_METADATA + if TYPE_CHECKING: from pandera.typing import DataFrame @@ -28,25 +30,55 @@ class BaseSource(ABC): """ name: ClassVar[str] - display_name: ClassVar[str] source_type: ClassVar[SourceType] + source_intro: ClassVar[str] + resolution_criteria: ClassVar[str] nullified_questions: ClassVar[list[NullifiedQuestion]] = [] resolution_schema: ClassVar[type] = ResolutionFrame def __init__(self) -> None: """Initialize with empty hash mapping.""" self.hash_mapping: dict[str, dict] = {} + self.api_key: str | None = None + + def _require_api_key(self) -> str: + """Return api_key or raise if not set.""" + if not self.api_key: + raise RuntimeError( + f"{self.__class__.__name__}.api_key must be set before calling " + "fetch() or update(). Set it in the orchestration layer." + ) + return self.api_key def __init_subclass__(cls, **kwargs): - """Enforce required ClassVars on concrete (non-intermediate) subclasses.""" + """Enforce required ClassVars and auto-populate from _metadata. + + Concrete subclasses must define ``name`` and ``source_type``. After + enforcement, any keys present in ``SOURCE_METADATA[cls.name]`` are set + as class attributes automatically (source_intro, resolution_criteria, + nullified_questions, etc.) 
so source files only need ``name``. + """ super().__init_subclass__(**kwargs) # Skip enforcement for DatasetSource / MarketSource (they're still abstract) if cls.__name__ in ("DatasetSource", "MarketSource"): return - for attr in ("name", "display_name", "source_type"): + for attr in ("name", "source_type"): if not hasattr(cls, attr) or getattr(cls, attr) is getattr(BaseSource, attr, None): raise TypeError(f"Concrete source {cls.__name__} must define ClassVar '{attr}'") + # Auto-populate from metadata + _REQUIRED_METADATA_KEYS = {"source_type", "source_intro", "resolution_criteria"} + name = getattr(cls, "name", None) + if name and name in SOURCE_METADATA: + meta = SOURCE_METADATA[name] + missing = _REQUIRED_METADATA_KEYS - meta.keys() + if missing: + raise TypeError( + f"SOURCE_METADATA['{name}'] missing required keys: {sorted(missing)}" + ) + for key, value in meta.items(): + setattr(cls, key, value) + # ------------------------------------------------------------------ # Public resolve interface # ------------------------------------------------------------------ @@ -155,6 +187,26 @@ def check_id(mid): df["id"].apply(check_id) + @abstractmethod + def fetch(self, **kwargs: Any) -> pd.DataFrame: + """Fetch raw data from external API. Return shape is source-specific.""" + ... + + @abstractmethod + def update( + self, + dfq: DataFrame[QuestionFrame], + dff: pd.DataFrame, + **kwargs: Any, + ) -> UpdateResult: + """Process fetched data into questions and resolution files. + + Args: + dfq (DataFrame[QuestionFrame]): Existing questions. + dff (pd.DataFrame): Freshly fetched data (source-specific schema). + """ + ... 
+ # ------------------------------------------------------------------ # Static utility methods # ------------------------------------------------------------------ diff --git a/src/sources/_metadata.py b/src/sources/_metadata.py new file mode 100644 index 00000000..33fb582e --- /dev/null +++ b/src/sources/_metadata.py @@ -0,0 +1,366 @@ +"""Lightweight source metadata — no heavy deps. + +Single source of truth for source identity strings and name lists. +Importable by any Cloud Run Job without triggering source-specific dependencies. +""" + +from datetime import date + +from _fb_types import NullifiedQuestion, SourceType +from helpers.constants import BENCHMARK_START_DATE_DATETIME_DATE + +SOURCE_METADATA = { + "acled": { + "source_type": SourceType.DATASET, + "source_intro": ( + "The Armed Conflict Location & Event Data Project (ACLED) collects real-time data on " + "the locations, dates, actors, fatalities, and types of all reported political violence " + "and protest events around the world. You're going to predict how questions based on " + "this data will resolve." + ), + "resolution_criteria": ( + "Resolves to the value calculated from the ACLED dataset once the data is published." + ), + }, + "dbnomics": { + "source_type": SourceType.DATASET, + "source_intro": ( + "DBnomics collects data on topics such as population and living conditions, " + "environment and energy, agriculture, finance, trade and others from publicly " + "available resources, for example national and international statistical institutions, " + "researchers and private companies. You're going to predict how questions based on " + "this data will resolve." 
+ ), + "resolution_criteria": "Resolves to the value found at {url} once the data is published.", + }, + "fred": { + "source_type": SourceType.DATASET, + "source_intro": ( + "The Federal Reserve Economic Data database (FRED) provides economic data from " + "national, international, public, and private sources.You're going to predict how " + "questions based on this data will resolve." + ), + "resolution_criteria": "Resolves to the value found at {url} once the data is published.", + "nullified_questions": [ + NullifiedQuestion( + id="AMERIBOR", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), + ], + }, + "infer": { + "source_type": SourceType.MARKET, + "source_intro": ( + "We would like you to predict the outcome of a prediction market. A prediction " + "market, in this context, is the aggregate of predictions submitted by users on the " + "website INFER Public. You're going to predict the probability that the market will " + "resolve as 'Yes'." + ), + "resolution_criteria": "Resolves to the outcome of the question found at {url}.", + }, + "manifold": { + "source_type": SourceType.MARKET, + "source_intro": ( + "We would like you to predict the outcome of a prediction market. A prediction " + "market, in this context, is the aggregate of predictions submitted by users on the " + "website Manifold. You're going to predict the probability that the market will " + "resolve as 'Yes'." + ), + "resolution_criteria": "Resolves to the outcome of the question found at {url}.", + }, + "metaculus": { + "source_type": SourceType.MARKET, + "source_intro": ( + "We would like you to predict the outcome of a prediction market. A prediction " + "market, in this context, is the aggregate of predictions submitted by users on the " + "website Metaculus. You're going to predict the probability that the market will " + "resolve as 'Yes'." 
+ ), + "resolution_criteria": "Resolves to the outcome of the question found at {url}.", + }, + "polymarket": { + "source_type": SourceType.MARKET, + "source_intro": ( + "We would like you to predict the outcome of a prediction market. A prediction " + "market, in this context, is the aggregate of predictions submitted by users on the " + "website Polymarket. You're going to predict the probability that the market will " + "resolve as 'Yes'." + ), + "resolution_criteria": "Resolves to the outcome of the question found at {url}.", + # IDs for which it is no longer possible to fetch data on Polymarket + # (though it was once possible) + "nullified_questions": [ + NullifiedQuestion(id=nid, nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE) + for nid in sorted( + { + "0x525820c5314f4143091d05079a8d810ecc07c8d5c8954ec2e6b6e163e40de9cb", + "0x9b46e4d85db0b2cd29acc36b836e1dad6cd2ac4fe495643cca64f7b962b6ab24", + "0x1e4d38c9b9e4aa154e350099216f4d86d94f1277eaa0d22fd33f48c0402155d5", + "0x738a551b7e2680669ea268911b2dc2079d156c350e40dc847d2a00eb0c57cfc2", + "0x0edd688013e4d08dd5367b9171bf85c6df73f2a4f561ed3c8ce004271c8278b7", + "0x42b4e02c1e95ca7b5e8610c3c1fad1dff6c0a46d01de6ae12565df026e3fc5a6", + "0x4afb076c5d9dfe1c33bf300cfd9fb93a5a8d9bfce8fe2beaeccbde5f8c269fc1", + "0x5642824719fa2e4d164de9a9ddaa1b5ca4f6fc57483eb222bec54082ad0bb57c", + "0xd8bf9a22e052cc97b14047a48552f3bd0e2605654e4fe580f48fa65e98d8487f", + } + ) + ], + }, + "wikipedia": { + "source_type": SourceType.DATASET, + "source_intro": ( + "Wikipedia is an online encyclopedia created and edited by volunteers. You're going " + "to predict how questions based on data sourced from Wikipedia will resolve." + ), + "resolution_criteria": "Resolves to the value calculated from {url} on the resolution date.", + "nullified_questions": [ + # Name changed: "R. 
Vaishali" --> "Vaishali Rameshbabu" + NullifiedQuestion( + id="149b5a465d9640ee10afcd1c6dde90627a4b58918111c14455d369f304aae454", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + NullifiedQuestion( + id="98e72a2d4c6daa0b0d8aee1d02a8628bbacf713f0e44b02f80a12b1dae1c618f", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + # Name changed: "Erigaisi Arjun" --> "Arjun Erigaisi" + NullifiedQuestion( + id="b70970a0440d1b7dedde9220fb60ffe3f2ed8b00ef12b45341772046caa12092", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + # Rameshbabu Praggnanandhaa — too many repeated name changes + # NB: _not_ nullifying ff153a13... (first asked 2025-05-25) or a987eef3... (2025-03-30) + NullifiedQuestion( + id="7687186d5e0807f8925a694beafb3d6e057978a9a01f0d1a3e0eaf1a49959e78", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + NullifiedQuestion( + id="479a40c45087510f72ee43a77aaccf78d563361728151ed3aab9b2b186db0b72", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + NullifiedQuestion( + id="4b9175c88f855ee0d0fc54640158fc7da10b7b2dcc4fe1053bd180ac1a72bf39", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + # Virus common name changed from "Monkeypox" to "Mpox" + NullifiedQuestion( + id="f9323386a651ce67fc0da31285bee22a4ec53b8a2ea5220431ecb4560fb44c77", + nullification_start_date=date(2022, 8, 21), + ), # noqa: B950 + NullifiedQuestion( + id="3f04d0cfccd38b26e86c0939516c483eb31edf6aaa3a1eaaabe38a48f7a0996a", + nullification_start_date=date(2022, 8, 21), + ), # noqa: B950 + # Leinier Domínguez Pérez — too many repeated name changes + NullifiedQuestion( + id="c8cc0816ce50a7fc018eccb7e6ed19628dc1f56e1cda26aca4b8f09c4edc7beb", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + NullifiedQuestion( + id="21f7534aaa7292ba1e71ed0d1ce0fc350febe64414083b4b60d35765781eab35", + 
nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + NullifiedQuestion( + id="9ab6734c6bf88f28a8c71b9d73995541b351f2663a7d8331a2c56dd5116d78a3", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + NullifiedQuestion( + id="a9783d8184c3f43668cc21417788be00fd4ff70eec91064c5539ed5ebb0019e8", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + NullifiedQuestion( + id="fa118e263e1218af8bb24cf7f6dd1c68e179d430584adf5b9b37d1b8488932d8", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + NullifiedQuestion( + id="60d86f26a5b1e6576d218076ae7a66bf0fadc0bfe042ff1adf875918cc8d2781", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + NullifiedQuestion( + id="6f8a3d10d39d69ecbdb10db2fabb66d852af39b95ce1af9f48ce5d9fd0175d87", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + NullifiedQuestion( + id="dfa2dc6d7511437365132459a03e4d7bc10632ffd78c145fb98496699647f968", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + # Resolved keys from _TRANSFORM_ID_MAPPING — old erroneous IDs superseded + # Tatjana Schoenmaker, lost swimming WR + NullifiedQuestion( + id="25891a351e97154028edc8075558470a6ec21d6d37dbd75f74268ee1b48253bf", + nullification_start_date=date(2023, 7, 5), + ), # noqa: B950 + NullifiedQuestion( + id="94297b75a6d18445c35a179a860b810bf0be7b6f296c502cec7caab24c8c1775", + nullification_start_date=date(2023, 7, 5), + ), # noqa: B950 + # Anthony Ervin, lost swimming WR + NullifiedQuestion( + id="cf02d516cc8b14b7b2880baae0ca4d520b167fe271123e6adfeedaefb83a3ec5", + nullification_start_date=date(2023, 8, 8), + ), # noqa: B950 + NullifiedQuestion( + id="6358ab9dab0aa4b6fc2abe8aacf1b31c8cbed08d54557eb4982c230fe19fe774", + nullification_start_date=date(2023, 8, 8), + ), # noqa: B950 + # Michael Phelps, lost swimming WR + NullifiedQuestion( + 
id="eea4cb0741c001c18ec28a58f64fb02bfba72e776f2d9ef2257309269b119526", + nullification_start_date=date(2023, 8, 29), + ), # noqa: B950 + NullifiedQuestion( + id="234175128275d109b5ffe5f8a30f863f150051e892e56566f88936b961be1f2f", + nullification_start_date=date(2023, 8, 29), + ), # noqa: B950 + # Benedetta Pilato, lost swimming WR + NullifiedQuestion( + id="e4afa18eb3d8d08fbc37c114f876a93ddceac453da415512ef5d73c7d26f391d", + nullification_start_date=date(2023, 8, 29), + ), # noqa: B950 + NullifiedQuestion( + id="747aa3406023deab8175b051bac64b55c061d38c2aebc73c1ded759de7b0477a", + nullification_start_date=date(2023, 8, 29), + ), # noqa: B950 + # Zac Stubblety-Cook, lost swimming WR + NullifiedQuestion( + id="5b078ec5a0d0a51c3668c62fe93441bd177ad4c58a1ff1d50b62a8bf6bc609fe", + nullification_start_date=date(2023, 8, 29), + ), # noqa: B950 + NullifiedQuestion( + id="afd040f28eb27f973ba1dc2cfeb3f613a7c29a543b14cbab4ba8d44ca8eb0d36", + nullification_start_date=date(2023, 8, 29), + ), # noqa: B950 + # Federica Pellegrini, lost swimming WR + NullifiedQuestion( + id="6e295dc29db5dce0672097160d432e7a3af469317298cb3153d745b2270041f1", + nullification_start_date=date(2023, 8, 29), + ), # noqa: B950 + NullifiedQuestion( + id="f0054684e6c6c24c5595e5cdf8498ffc5479e82d26a8b0318af35a26cd9b9ce7", + nullification_start_date=date(2023, 8, 29), + ), # noqa: B950 + # Liu Xiang, lost swimming WR + NullifiedQuestion( + id="245eb0146484bad467bbdb3d0c871f30390fb1a902105f86c85ec4637c52a9f4", + nullification_start_date=date(2023, 10, 20), + ), # noqa: B950 + NullifiedQuestion( + id="e222aa0998ad2e53a4cbfbdb11f3d80dfd13a263b4748e4a6cd8f4b965f0506f", + nullification_start_date=date(2023, 10, 20), + ), # noqa: B950 + # Hunter Armstrong, lost swimming WR + NullifiedQuestion( + id="851337578d0bf07dc60b233f5ef2a49d0309c1728621dd7b4ac0724414887fde", + nullification_start_date=date(2023, 11, 13), + ), # noqa: B950 + NullifiedQuestion( + 
id="56e00c66d9d2bfa3dd3ad0656c81701e04033438f90320ba96a63b62e61a4ea5", + nullification_start_date=date(2023, 11, 13), + ), # noqa: B950 + # David Popovici, lost swimming WR + NullifiedQuestion( + id="646cd3619a16c273007816e559834682e19754dcaf7d0ecb6ffebe64d351f177", + nullification_start_date=date(2024, 3, 21), + ), # noqa: B950 + NullifiedQuestion( + id="0e0f5a6cf1ac926657d43b909af4d2fb27ba975dfe3a274fbe0930dcf667d499", + nullification_start_date=date(2024, 3, 21), + ), # noqa: B950 + # Mollie O'Callaghan, lost swimming WR + NullifiedQuestion( + id="ebb4e1e85bed81266e94dda8e84eafe1479d5697f850792d84b5fab7251f483f", + nullification_start_date=date(2024, 7, 18), + ), # noqa: B950 + NullifiedQuestion( + id="b4c4989ac25edfbb8510e8ffa9aeee70c0de0d82e22a360faac590304f67c575", + nullification_start_date=date(2024, 7, 18), + ), # noqa: B950 + # Sun Yang, lost swimming WR + NullifiedQuestion( + id="7558c5b4f539cc922552c4f18a9a5cdaccbc100d6108acf117e886bd9dc67857", + nullification_start_date=date(2024, 8, 4), + ), # noqa: B950 + NullifiedQuestion( + id="04bfcc27745a1813367fcb5aad43423db616dccff54c1cc929bd32de3f43a38a", + nullification_start_date=date(2024, 8, 4), + ), # noqa: B950 + # Kate Douglass, lost swimming WR + NullifiedQuestion( + id="eaf10e98fdc5ddd2227b212f1e446a1937a2e0529b8f89c9a2528cb469e7cc27", + nullification_start_date=date(2024, 11, 2), + ), # noqa: B950 + NullifiedQuestion( + id="c539c3ef6d2534204b4fc67a94b14eebc7c51f141fea3c30f337cb3ede390b11", + nullification_start_date=date(2024, 11, 2), + ), # noqa: B950 + # Katinka Hosszú, lost swimming WR + NullifiedQuestion( + id="2e88b046538e239140043da9471c2b4894615a12173c3a52ee707321acf2ed8d", + nullification_start_date=date(2025, 6, 10), + ), # noqa: B950 + NullifiedQuestion( + id="c4db6cf85ef3ef4165705b863f1491f2903df3a2534e2d4e25f57edcbdfaac4b", + nullification_start_date=date(2025, 6, 10), + ), # noqa: B950 + # Vaccine was created in 2023 but Wikipedia table had not been updated + NullifiedQuestion( + 
id="242926fea271734ef8d4920e532414b38dbfdf301516fd9f0c988abd0ce777dd", + nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE, + ), # noqa: B950 + # Ryan Lochte, lost swimming WR + NullifiedQuestion( + id="12486c21df689124f8fdad70760247dffe2b7696599748bcb5c7a738735285d5", + nullification_start_date=date(2025, 7, 30), + ), # noqa: B950 + NullifiedQuestion( + id="c6ee39b4504603aa5ddbe73f378d48d94ab128406e5dd1bbb70ead0207a43840", + nullification_start_date=date(2025, 7, 30), + ), # noqa: B950 + ], + }, + "yfinance": { + "source_type": SourceType.DATASET, + "source_intro": ( + "Yahoo Finance provides financial data on stocks, bonds, and currencies and also " + "offers news, commentary and tools for personal financial management. You're going " + "to predict how questions based on this data will resolve." + ), + "resolution_criteria": ( + "Resolves to the market close price at {url} for the resolution date. If the " + "resolution date coincides with a day the market is closed (weekend, holiday, etc.) " + "the previous market close price is used." + ), + # Stocks that were delisted (via acquisition, merger, or going private) while still in the + # question pool. nullification_start_date is the first calendar day after the last trading + # session so that question sets whose forecast_due_date falls on or after this date are + # nullified, while earlier sets continue to resolve to the final close price. 
+ "nullified_questions": [ + NullifiedQuestion(id="MRO", nullification_start_date=date(2024, 11, 22)), + NullifiedQuestion(id="CTLT", nullification_start_date=date(2024, 12, 18)), + NullifiedQuestion(id="DFS", nullification_start_date=date(2025, 5, 19)), + NullifiedQuestion(id="JNPR", nullification_start_date=date(2025, 7, 2)), + NullifiedQuestion(id="ANSS", nullification_start_date=date(2025, 7, 17)), + NullifiedQuestion(id="HES", nullification_start_date=date(2025, 7, 18)), + NullifiedQuestion(id="PARA", nullification_start_date=date(2025, 8, 7)), + NullifiedQuestion(id="WBA", nullification_start_date=date(2025, 8, 28)), + NullifiedQuestion(id="K", nullification_start_date=date(2025, 12, 11)), + NullifiedQuestion(id="DAY", nullification_start_date=date(2026, 2, 4)), + ], + # Tickers that were renamed on yfinance while still in the question pool. yfinance serves + # all price history under the replacement ticker; the original ticker returns no data. The + # update_questions code fetches data under the replacement ticker and writes it to the + # original ticker's resolution file so that existing questions resolve correctly. 
+ "ticker_renames": [ + {"original_ticker": "FI", "replacement_ticker": "FISV"}, + {"original_ticker": "MMC", "replacement_ticker": "MRSH"}, + ], + }, +} + +ALL_SOURCE_NAMES = sorted(SOURCE_METADATA.keys()) +DATASET_SOURCE_NAMES = sorted( + name for name, m in SOURCE_METADATA.items() if m["source_type"] == SourceType.DATASET +) +MARKET_SOURCE_NAMES = sorted( + name for name, m in SOURCE_METADATA.items() if m["source_type"] == SourceType.MARKET +) diff --git a/src/sources/acled.py b/src/sources/acled.py index a2a7726c..217c5433 100644 --- a/src/sources/acled.py +++ b/src/sources/acled.py @@ -10,7 +10,6 @@ import numpy as np import pandas as pd -from _fb_types import SourceType from _schemas import AcledResolutionFrame from ._dataset import DatasetSource @@ -22,8 +21,6 @@ class AcledSource(DatasetSource): """Armed Conflict Location & Event Data source with custom resolution logic.""" name: ClassVar[str] = "acled" - display_name: ClassVar[str] = "ACLED" - source_type: ClassVar[SourceType] = SourceType.DATASET resolution_schema: ClassVar[type] = AcledResolutionFrame def _resolve(self, df: pd.DataFrame, dfq: pd.DataFrame, dfr: pd.DataFrame) -> pd.DataFrame: @@ -148,3 +145,15 @@ def _id_hash(self, d: dict) -> str: def _id_unhash(self, hash_key: str): """Look up the original question dict from a hash key.""" return self.hash_mapping.get(hash_key) + + # ------------------------------------------------------------------ + # Fetch / update (not yet implemented) + # ------------------------------------------------------------------ + + def fetch(self, **kwargs): + """Fetch ACLED data from external API.""" + raise NotImplementedError + + def update(self, dfq, dff, **kwargs): + """Process fetched ACLED data into questions and resolution files.""" + raise NotImplementedError diff --git a/src/sources/dbnomics.py b/src/sources/dbnomics.py index 32d2f103..3294be82 100644 --- a/src/sources/dbnomics.py +++ b/src/sources/dbnomics.py @@ -4,8 +4,6 @@ from typing import ClassVar -from 
class FredSource(DatasetSource):
    """Federal Reserve Economic Data (FRED) dataset source.

    Only the source name is declared on the class; ``fetch`` and ``update`` are
    placeholders that have not been implemented yet.
    """

    name: ClassVar[str] = "fred"

    def fetch(self, **kwargs):
        """Fetch FRED data from the external API.

        Args:
            **kwargs: Accepted for interface compatibility; currently unused.

        Raises:
            NotImplementedError: Always; fetching is not implemented for this source yet.
        """
        raise NotImplementedError()

    def update(self, dfq, dff, **kwargs):
        """Process fetched FRED data into questions and resolution files.

        Args:
            dfq: Existing questions (currently unused).
            dff: Freshly fetched data (currently unused).
            **kwargs: Accepted for interface compatibility; currently unused.

        Raises:
            NotImplementedError: Always; updating is not implemented for this source yet.
        """
        raise NotImplementedError()
class InferSource(MarketSource):
    """INFER Public prediction market source.

    ``fetch`` pulls questions from the INFER REST API and ``update`` folds the fetched rows
    into the question bank while building one resolution file per question.
    """

    name: ClassVar[str] = "infer"

    # Upper bound on consecutive HTTP 429 retries (10s apart) in _get_historical_forecasts so a
    # persistent rate limit surfaces as an error instead of looping forever. 30 x 10s mirrors
    # the 300s backoff budget used by _fetch_questions_from_api.
    _MAX_RATE_LIMIT_RETRIES: ClassVar[int] = 30

    # ------------------------------------------------------------------
    # Public: fetch
    # ------------------------------------------------------------------

    @pa.check_types
    def fetch(
        self,
        *,
        dfq: DataFrame[QuestionFrame] | None = None,
        files_in_storage: list[str] | None = None,
    ) -> DataFrame[InferFetchFrame]:
        """Fetch questions from the INFER API.

        Re-fetches every unresolved question already in ``dfq`` plus every resolved question
        that has no ``infer/<id>.jsonl`` entry in ``files_in_storage``, then adds all currently
        active binary questions that have at least one prediction.

        Args:
            dfq (DataFrame[QuestionFrame] | None): Existing question bank.
            files_in_storage (list[str] | None): Existing resolution file paths.

        Returns:
            DataFrame[InferFetchFrame]: One row per fetched question, as produced by
            ``_transform_question``.
        """
        self._require_api_key()
        # A set gives O(1) membership checks when looking for missing resolution files.
        stored_files = set(files_in_storage or [])

        # Determine which existing questions need re-fetching
        resolved_ids: list[str] = []
        unresolved_ids: list[str] = []
        if dfq is not None and not dfq.empty:
            resolved_ids = dfq[dfq["resolved"]]["id"].tolist()
            unresolved_ids = dfq[~dfq["resolved"]]["id"].tolist()

        logger.info(f"Number resolved_ids: {len(resolved_ids)}")
        logger.info(f"Number unresolved_ids: {len(unresolved_ids)}")

        resolved_ids_without_files = [
            question_id
            for question_id in resolved_ids
            if f"{self.name}/{question_id}.jsonl" not in stored_files
        ]
        logger.info(f"resolved_ids_without_resolution_files: {resolved_ids_without_files}")

        all_existing_ids_to_fetch = unresolved_ids + resolved_ids_without_files

        # Fetch existing (potentially closed) questions; skip the call when there is nothing
        # to ask for.
        all_existing_questions = (
            self._fetch_questions_from_api(status="all", question_ids=all_existing_ids_to_fetch)
            if all_existing_ids_to_fetch
            else []
        )

        # Fetch all active questions
        all_active_questions = self._fetch_questions_from_api()

        # Filter active to binary questions with predictions
        all_active_binary_questions = [
            q
            for q in all_active_questions
            if q["state"] == "active"
            and q["type"] == "Forecast::YesNoQuestion"
            and q["answers"][0]["predictions_count"] > 0
        ]

        # Deduplicate: active takes precedence
        active_ids = {q["id"] for q in all_active_binary_questions}
        all_existing_questions = [q for q in all_existing_questions if q["id"] not in active_ids]

        all_questions = all_active_binary_questions + all_existing_questions
        logger.info(f"Number of questions fetched: {len(all_questions)}")

        # Transform to InferFetchFrame schema
        current_time = dates.get_datetime_now()
        rows = [self._transform_question(q, current_time) for q in all_questions]

        return pd.DataFrame(rows)

    # ------------------------------------------------------------------
    # Public: update
    # ------------------------------------------------------------------

    @pa.check_types
    def update(
        self,
        dfq: DataFrame[QuestionFrame],
        dff: DataFrame[InferFetchFrame],
        *,
        existing_resolution_files: dict[str, DataFrame[ResolutionFrame]] | None = None,
    ) -> UpdateResult:
        """Process fetched data into updated questions and resolution files.

        Args:
            dfq (DataFrame[QuestionFrame]): Existing questions. Not modified; the updated
                question bank is returned in the result.
            dff (DataFrame[InferFetchFrame]): Freshly fetched data.
            existing_resolution_files (dict | None): Per-question existing resolution data,
                keyed by question ID.

        Returns:
            UpdateResult: The upserted question bank and one resolution file per row of
            ``dff``.
        """
        self._require_api_key()
        existing_resolution_files = existing_resolution_files or {}
        resolution_files: dict[str, pd.DataFrame] = {}

        # Work on a copy so the in-place `.at` writes below never leak into the caller's frame.
        dfq = dfq.copy()

        for question in dff.to_dict("records"):
            question_id = str(question["id"])

            # Build/update resolution file. This uses the fetched `resolved` flag, i.e. before
            # nullification forces it to True below.
            resolution_files[question_id] = self._build_resolution_file(
                question=question,
                resolved=question["resolved"],
                existing_df=existing_resolution_files.get(question_id),
            )

            # Mark nullified questions as resolved
            if question["nullify_question"]:
                question["resolved"] = True

            # Strip transient fields (not part of QuestionFrame)
            for transient_key in ("fetch_datetime", "probability", "nullify_question"):
                del question[transient_key]

            # Upsert into dfq
            matching_labels = dfq.index[dfq["id"] == question["id"]]
            if len(matching_labels) > 0:
                row_label = matching_labels[0]
                for key, value in question.items():
                    dfq.at[row_label, key] = value
            else:
                dfq = pd.concat([dfq, pd.DataFrame([question])], ignore_index=True)

        return UpdateResult(
            dfq=dfq,
            resolution_files=resolution_files,
        )

    # ------------------------------------------------------------------
    # Private: API calls
    # ------------------------------------------------------------------

    @backoff.on_exception(
        backoff.expo,
        requests.exceptions.RequestException,
        max_time=300,
        on_backoff=data_utils.print_error_info_handler,
    )
    def _fetch_questions_from_api(
        self,
        *,
        status: str = "active",
        question_ids: list[str] | None = None,
    ) -> list[dict]:
        """Fetch paginated questions from the INFER API.

        Pages are requested starting at 0 until a page comes back without questions.
        Questions whose ID was already seen on an earlier page are dropped.

        Args:
            status (str): "active" or "all".
            question_ids (list[str] | None): If provided, fetch these specific IDs. The
                status is then forced to "all".

        Returns:
            list[dict]: Raw question dicts in the order the API returned them.

        Raises:
            requests.exceptions.RequestException: Re-raised once the backoff decorator gives
                up retrying.
        """
        api_key = self._require_api_key()
        endpoint = _INFER_URL + "/api/v1/questions"
        headers = {"Authorization": f"Bearer {api_key}"}
        params: dict[str, Any] = {"page": 0, "status": status}
        if question_ids is not None:
            params.update({"status": "all", "ids": ",".join(sorted(question_ids))})

        questions: list[dict] = []
        seen_ids: set = set()
        while True:
            response = requests.get(
                endpoint, params=params, headers=headers, verify=certifi.where()
            )
            if not response.ok:
                logger.error(f"Request to Infer questions endpoint failed with params: {params}")
                response.raise_for_status()

            new_questions = response.json().get("questions", [])
            if not new_questions:
                break

            for q in new_questions:
                if q["id"] not in seen_ids:
                    questions.append(q)
                    seen_ids.add(q["id"])

            params["page"] += 1

        return questions

    def _get_historical_forecasts(
        self,
        current_df: DataFrame[ResolutionFrame] | None,
        question_id: str,
    ) -> DataFrame[ResolutionFrame]:
        """Fetch historical prediction time series for a question.

        Downloads prediction sets page by page until a page ends at or before the last date
        already stored, keeps the last forecast per day (today excluded), merges with
        ``current_df`` and forward-fills every missing day up to yesterday.

        Args:
            current_df (DataFrame[ResolutionFrame] | None): Existing resolution data.
            question_id (str): INFER question ID.

        Returns:
            DataFrame[ResolutionFrame]: Columns ``id``, ``date`` (UTC timestamps), ``value``.
            Empty when there is neither existing data nor any forecast before today.

        Raises:
            requests.exceptions.HTTPError: For non-429 errors, or when the API keeps answering
                429 for more than ``_MAX_RATE_LIMIT_RETRIES`` consecutive attempts.
        """
        api_key = self._require_api_key()
        endpoint = _INFER_URL + "/api/v1/prediction_sets"
        params = {"question_id": question_id, "page": 0}
        headers = {"Authorization": f"Bearer {api_key}"}
        all_responses: list[dict] = []
        current_time = dates.get_datetime_today_midnight()

        # Determine cutoff: only fetch predictions newer than what we have
        has_existing = current_df is not None and not current_df.empty
        last_date = (
            pd.to_datetime(current_df["date"].iloc[-1]).tz_localize("UTC")
            if has_existing
            else constants.BENCHMARK_START_DATE_DATETIME.replace(tzinfo=timezone.utc)
        )

        consecutive_rate_limits = 0
        while True:
            try:
                response = requests.get(
                    endpoint, params=params, headers=headers, verify=certifi.where()
                )
                response.raise_for_status()
            except requests.exceptions.HTTPError as e:
                if e.response is None or e.response.status_code != 429:
                    raise
                consecutive_rate_limits += 1
                if consecutive_rate_limits > self._MAX_RATE_LIMIT_RETRIES:
                    # Give up instead of spinning forever on a persistent rate limit.
                    raise
                logger.error("Rate limit reached, waiting 10s before retrying...")
                time.sleep(10)
                continue

            consecutive_rate_limits = 0
            logger.info(f"Fetched page: {params['page']}, for question ID: {question_id}")
            new_responses = response.json().get("prediction_sets", [])
            all_responses.extend(new_responses)
            if (
                not new_responses
                or pd.to_datetime(new_responses[-1]["created_at"], utc=True) <= last_date
            ):
                break
            params["page"] += 1

        # Extract (date, probability) from each prediction set
        all_forecasts: list[tuple] = []
        for forecast in all_responses:
            if has_existing and pd.to_datetime(forecast["created_at"], utc=True) <= last_date:
                continue

            predictions = forecast["predictions"]
            if len(predictions) == 2:
                forecast_yes = predictions[0]
                if forecast_yes["answer_name"] == "No":
                    forecast_yes = predictions[1]
            elif len(predictions) == 1:
                forecast_yes = predictions[0]
            else:
                # Neither a binary nor a single-answer set: there is no "yes" probability to
                # record, and reusing the previous iteration's value would be wrong.
                logger.warning(
                    f"Skipping prediction set with {len(predictions)} predictions "
                    f"for question ID: {question_id}"
                )
                continue

            all_forecasts.append(
                (
                    dates.convert_zulu_to_iso(forecast["created_at"]),
                    forecast_yes["final_probability"],
                )
            )

        df = pd.DataFrame(all_forecasts, columns=["date", "value"])
        df["date"] = pd.to_datetime(df["date"])
        # Today's forecasts are excluded; .copy() keeps the assignments below off a slice.
        df = df[df["date"].dt.date < current_time.date()].copy()
        df["value"] = df["value"].astype(float)
        df["id"] = question_id

        # Sort and convert to date-only
        df_sorted = df.sort_values("date").reset_index(drop=True)
        df_sorted["date"] = df_sorted["date"].dt.date
        df_final = df_sorted[["id", "date", "value"]]

        # Merge with existing data; the newest entry per (id, date) wins
        if not has_existing:
            result_df = df_final.drop_duplicates(subset=["id", "date"], keep="last")
        else:
            current_df = current_df.copy()
            current_df["date"] = pd.to_datetime(current_df["date"]).dt.date
            current_df_final = current_df[["id", "date", "value"]]
            result_df = (
                pd.concat([current_df_final, df_final], axis=0)
                .sort_values(by=["date"], ascending=True)
                .drop_duplicates(subset=["id", "date"], keep="last")
                .reset_index(drop=True)
            )

        if result_df.empty:
            # Nothing to forward-fill: date_range() would raise on a NaT start.
            logger.warning(f"No historical forecasts found for question ID: {question_id}")
            empty_df = result_df[["id", "date", "value"]].copy()
            empty_df["date"] = pd.to_datetime(empty_df["date"]).dt.tz_localize("UTC")
            return empty_df

        # Forward-fill missing dates
        result_df = result_df.assign(
            date=pd.to_datetime(result_df["date"]).dt.tz_localize("UTC")
        ).infer_objects()
        result_df = result_df.sort_values(by="date")
        all_dates = pd.date_range(
            start=result_df["date"].min(),
            end=current_time - timedelta(days=1),
            freq="D",
        )
        result_df = (
            result_df.set_index("date")
            .reindex(all_dates, method="ffill")
            .rename_axis("date")
            .reset_index()
        )
        result_df["id"] = question_id

        return result_df[["id", "date", "value"]]

    # ------------------------------------------------------------------
    # Private: resolution file building
    # ------------------------------------------------------------------

    def _build_resolution_file(
        self,
        question: dict,
        resolved: bool,
        existing_df: DataFrame[ResolutionFrame] | None = None,
    ) -> DataFrame[ResolutionFrame]:
        """Build or update a resolution file for a single question.

        Args:
            question (dict): Must have 'id', 'nullify_question'. If resolved, must also
                have 'market_info_resolution_datetime' and 'probability'.
            resolved (bool): Whether the question has resolved.
            existing_df (DataFrame[ResolutionFrame] | None): Existing resolution data.

        Returns:
            DataFrame[ResolutionFrame]: NaN values for nullified questions, ``existing_df``
            unchanged when its last date is yesterday or later, otherwise the refreshed
            forecast history (ending with a resolution row when ``resolved``).
        """
        yesterday = dates.get_datetime_today_midnight() - timedelta(days=1)
        has_existing = existing_df is not None and not existing_df.empty

        # --- Nullification ---
        if question["nullify_question"]:
            logger.warning(
                f"Nullifying question {question['id']}. "
                "Pushing np.nan values to resolution file."
            )
            if not has_existing:
                return pd.DataFrame(
                    {
                        "id": [question["id"]],
                        "date": [str(yesterday.date())],
                        "value": [np.nan],
                    }
                )
            df = existing_df.copy()
            df["value"] = np.nan
            return self._finalize_resolution_df(df)

        # --- Already up-to-date check ---
        if (
            has_existing
            and pd.to_datetime(existing_df["date"].iloc[-1]).tz_localize("UTC") >= yesterday
        ):
            logger.info(f"{question['id']} is skipped because it's already up-to-date!")
            return existing_df

        # --- Fetch historical forecasts ---
        df = self._get_historical_forecasts(existing_df, question["id"])
        # Drop the timezone by reducing tz-aware timestamps to plain dates.
        df["date"] = df["date"].dt.date if hasattr(df["date"].dtype, "tz") else df["date"]

        # --- Handle resolved questions ---
        if resolved:
            resolution_date_str = question["market_info_resolution_datetime"][:10]
            resolution_date = pd.to_datetime(resolution_date_str)
            df["date"] = pd.to_datetime(df["date"])
            # Keep forecasts strictly before the resolution date, then append the outcome.
            df = df[df["date"] < resolution_date]
            resolution_row = pd.DataFrame(
                {
                    "id": [question["id"]],
                    "date": [resolution_date_str],
                    "value": [question["probability"]],
                }
            )
            df = pd.concat([df, resolution_row], ignore_index=True)

        return self._finalize_resolution_df(df)

    @staticmethod
    def _finalize_resolution_df(df: pd.DataFrame) -> DataFrame[ResolutionFrame]:
        """Apply date filtering and return as validated ResolutionFrame.

        Args:
            df (pd.DataFrame): Raw resolution data with id, date, value columns. Not modified.

        Returns:
            DataFrame[ResolutionFrame]: Rows dated on or after the benchmark start date, with
            ``date`` parsed by ``pd.to_datetime``.
        """
        df = df.assign(date=pd.to_datetime(df["date"]))
        on_or_after_start = df["date"].dt.date >= constants.BENCHMARK_START_DATE_DATETIME_DATE
        return ResolutionFrame.validate(df.loc[on_or_after_start, ["id", "date", "value"]])

    # ------------------------------------------------------------------
    # Private: question transformation
    # ------------------------------------------------------------------

    @staticmethod
    def _earliest_timestamp(*timestamps: str) -> str:
        """Return the smallest timestamp string, ignoring "N/A"; "N/A" if none remain."""
        known = [ts for ts in timestamps if ts != "N/A"]
        return min(known) if known else "N/A"

    @staticmethod
    def _transform_question(q: dict, current_time: str) -> dict:
        """Transform a single INFER API response to InferFetchFrame row.

        Args:
            q (dict): Raw question dict from the INFER API.
            current_time (str): ISO timestamp for fetch_datetime.

        Returns:
            dict: Question fields plus the transient ``fetch_datetime``, ``probability`` and
            ``nullify_question`` keys. Missing timestamps and probabilities are "N/A".
        """
        # Anything that is not a yes/no question cannot be resolved as binary.
        nullify_question = q["type"] != "Forecast::YesNoQuestion"
        is_resolved = q.get("resolved?", False)

        # --- Close datetime: min(scoring_end_time, ends_at) ---
        scoring_end_time_str = (
            dates.convert_datetime_str_to_iso_utc(q["scoring_end_time"])
            if q["scoring_end_time"]
            else "N/A"
        )
        ended_at_str = dates.convert_zulu_to_iso(q["ends_at"]) if q["ends_at"] else "N/A"
        final_closed_at_str = InferSource._earliest_timestamp(scoring_end_time_str, ended_at_str)

        # --- Open datetime ---
        scoring_start_time_str = (
            dates.convert_datetime_str_to_iso_utc(q["scoring_start_time"])
            if q["scoring_start_time"]
            else "N/A"
        )

        # --- Resolution datetime: min(resolved_at, close_datetime) ---
        resolved_at_str = dates.convert_zulu_to_iso(q["resolved_at"]) if q["resolved_at"] else "N/A"
        final_resolved_str = InferSource._earliest_timestamp(resolved_at_str, final_closed_at_str)

        # --- Probability ---
        forecast_yes: Any = "N/A"
        if len(q["answers"]) == 2 and not nullify_question:
            yes_index = 0 if q["answers"][0]["name"].lower() == "yes" else 1
            forecast_yes = q["answers"][yes_index]["probability"]

        return {
            "id": str(q["id"]),
            "question": q["name"],
            "background": q["description"],
            "market_info_resolution_criteria": (
                " ".join(content["content"] for content in q["clarifications"])
                if q["clarifications"]
                else "N/A"
            ),
            "market_info_open_datetime": scoring_start_time_str,
            "market_info_close_datetime": final_closed_at_str,
            "url": f"{_INFER_URL}/questions/{q['id']}",
            "resolved": is_resolved,
            "market_info_resolution_datetime": final_resolved_str if is_resolved else "N/A",
            "fetch_datetime": current_time,
            "probability": forecast_yes,
            "forecast_horizons": "N/A",
            "freeze_datetime_value": forecast_yes,
            "freeze_datetime_value_explanation": "The crowd forecast.",
            "nullify_question": nullify_question,
        }
market source.""" name: ClassVar[str] = "metaculus" - display_name: ClassVar[str] = "Metaculus" - source_type: ClassVar[SourceType] = SourceType.MARKET + + def fetch(self, **kwargs): + """Fetch Metaculus data from external API.""" + raise NotImplementedError + + def update(self, dfq, dff, **kwargs): + """Process fetched Metaculus data into questions and resolution files.""" + raise NotImplementedError diff --git a/src/sources/polymarket.py b/src/sources/polymarket.py index 0665957f..1a9e847f 100644 --- a/src/sources/polymarket.py +++ b/src/sources/polymarket.py @@ -4,33 +4,18 @@ from typing import ClassVar -from _fb_types import NullifiedQuestion, SourceType -from helpers.constants import BENCHMARK_START_DATE_DATETIME_DATE - from ._market import MarketSource -# These are question IDs for which it is no longer possible to fetch data on Polymarket -# (though it was once possible) -NULLIFIED_QUESTION_IDS = { - "0x525820c5314f4143091d05079a8d810ecc07c8d5c8954ec2e6b6e163e40de9cb", - "0x9b46e4d85db0b2cd29acc36b836e1dad6cd2ac4fe495643cca64f7b962b6ab24", - "0x1e4d38c9b9e4aa154e350099216f4d86d94f1277eaa0d22fd33f48c0402155d5", - "0x738a551b7e2680669ea268911b2dc2079d156c350e40dc847d2a00eb0c57cfc2", - "0x0edd688013e4d08dd5367b9171bf85c6df73f2a4f561ed3c8ce004271c8278b7", - "0x42b4e02c1e95ca7b5e8610c3c1fad1dff6c0a46d01de6ae12565df026e3fc5a6", - "0x4afb076c5d9dfe1c33bf300cfd9fb93a5a8d9bfce8fe2beaeccbde5f8c269fc1", - "0x5642824719fa2e4d164de9a9ddaa1b5ca4f6fc57483eb222bec54082ad0bb57c", - "0xd8bf9a22e052cc97b14047a48552f3bd0e2605654e4fe580f48fa65e98d8487f", -} - class PolymarketSource(MarketSource): """Polymarket prediction market source.""" name: ClassVar[str] = "polymarket" - display_name: ClassVar[str] = "Polymarket" - source_type: ClassVar[SourceType] = SourceType.MARKET - nullified_questions: ClassVar[list[NullifiedQuestion]] = [ - NullifiedQuestion(id=nid, nullification_start_date=BENCHMARK_START_DATE_DATETIME_DATE) - for nid in sorted(NULLIFIED_QUESTION_IDS) - ] + + def 
"""Source registry: instantiates every source class, which pulls in heavy dependencies.

Use sparingly. Import from this module only when actual source instances are needed
(e.g. func_resolve). For metadata, use `from sources import xxx`; for a single source,
import its class directly, e.g. `from sources.infer import InferSource`.
"""

from _fb_types import SourceType

from .acled import AcledSource
from .dbnomics import DbnomicsSource
from .fred import FredSource
from .infer import InferSource
from .manifold import ManifoldSource
from .metaculus import MetaculusSource
from .polymarket import PolymarketSource
from .wikipedia import WikipediaSource
from .yfinance import YfinanceSource

# One shared instance per source.
_acled = AcledSource()
_dbnomics = DbnomicsSource()
_fred = FredSource()
_infer = InferSource()
_manifold = ManifoldSource()
_metaculus = MetaculusSource()
_polymarket = PolymarketSource()
_wikipedia = WikipediaSource()
_yfinance = YfinanceSource()

_ALL_INSTANCES = (
    _acled,
    _dbnomics,
    _fred,
    _infer,
    _manifold,
    _metaculus,
    _polymarket,
    _wikipedia,
    _yfinance,
)

# Every source instance keyed by its `name`, in the order listed above.
SOURCES = {instance.name: instance for instance in _ALL_INSTANCES}


def _sources_of_type(source_type):
    """Return the SOURCES entries whose source_type equals *source_type*, ordered by name."""
    return {
        source_name: SOURCES[source_name]
        for source_name in sorted(SOURCES)
        if SOURCES[source_name].source_type == source_type
    }


DATASET_SOURCES = _sources_of_type(SourceType.DATASET)
MARKET_SOURCES = _sources_of_type(SourceType.MARKET)
# Dict view of the Wikipedia nullified-question metadata: one {"id", "nullify_start_date"}
# entry per NullifiedQuestion listed under SOURCE_METADATA["wikipedia"]["nullified_questions"].
_IDS_TO_NULLIFY = [
    dict(id=nullified.id, nullify_start_date=nullified.nullification_start_date)
    for nullified in SOURCE_METADATA["wikipedia"]["nullified_questions"]
]
import annotations -from datetime import date from typing import ClassVar -from _fb_types import NullifiedQuestion, SourceType - from ._dataset import DatasetSource -# Stocks that were delisted (via acquisition, merger, or going private) while still in the question -# pool. nullification_start_date is the first calendar day after the last trading session so that -# question sets whose forecast_due_date falls on or after this date are nullified, while earlier -# sets continue to resolve to the final close price. -DELISTED_STOCKS = [ - NullifiedQuestion(id="MRO", nullification_start_date=date(2024, 11, 22)), - NullifiedQuestion(id="CTLT", nullification_start_date=date(2024, 12, 18)), - NullifiedQuestion(id="DFS", nullification_start_date=date(2025, 5, 19)), - NullifiedQuestion(id="JNPR", nullification_start_date=date(2025, 7, 2)), - NullifiedQuestion(id="ANSS", nullification_start_date=date(2025, 7, 17)), - NullifiedQuestion(id="HES", nullification_start_date=date(2025, 7, 18)), - NullifiedQuestion(id="PARA", nullification_start_date=date(2025, 8, 7)), - NullifiedQuestion(id="WBA", nullification_start_date=date(2025, 8, 28)), - NullifiedQuestion(id="K", nullification_start_date=date(2025, 12, 11)), - NullifiedQuestion(id="DAY", nullification_start_date=date(2026, 2, 4)), -] - -# Tickers that were renamed on yfinance while still in the question pool. yfinance serves all price -# history under the replacement ticker; the original ticker returns no data. The update_questions -# code fetches data under the replacement ticker and writes it to the original ticker's resolution -# file so that existing questions resolve correctly. 
-TICKER_RENAMES = [ - {"original_ticker": "FI", "replacement_ticker": "FISV"}, - {"original_ticker": "MMC", "replacement_ticker": "MRSH"}, -] - class YfinanceSource(DatasetSource): """Yahoo Finance financial data source.""" name: ClassVar[str] = "yfinance" - display_name: ClassVar[str] = "Yahoo Finance" - source_type: ClassVar[SourceType] = SourceType.DATASET - nullified_questions: ClassVar[list[NullifiedQuestion]] = DELISTED_STOCKS + + def fetch(self, **kwargs): + """Fetch Yahoo Finance data from external API.""" + raise NotImplementedError + + def update(self, dfq, dff, **kwargs): + """Process fetched Yahoo Finance data into questions and resolution files.""" + raise NotImplementedError diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 89a07fed..779e5b1e 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -9,6 +9,7 @@ from sources.acled import AcledSource from sources.fred import FredSource +from sources.infer import InferSource from sources.metaculus import MetaculusSource # --------------------------------------------------------------------------- @@ -68,6 +69,14 @@ def acled_source(): return AcledSource() +@pytest.fixture() +def infer_source(): + """Return an InferSource instance with a fake API key.""" + src = InferSource() + src.api_key = "test-key" + return src + + # --------------------------------------------------------------------------- # DataFrame factories # --------------------------------------------------------------------------- @@ -146,3 +155,100 @@ def make_acled_resolution_df(rows, event_columns=None): def make_question_set_df(rows): """Build a DataFrame with [id, source, resolution_dates] for explode_question_set.""" return pd.DataFrame(rows) + + +# --------------------------------------------------------------------------- +# INFER-specific factories +# --------------------------------------------------------------------------- + + +def make_infer_api_question(**overrides): + """Build a realistic INFER API 
question dict. Override specific fields as needed.""" + base = { + "id": 9999, + "name": "Will X happen by end of 2026?", + "description": "

Background text.

", + "clarifications": [], + "state": "active", + "type": "Forecast::YesNoQuestion", + "active?": True, + "binary?": False, + "resolved?": False, + "resolved_at": None, + "ends_at": "2026-06-01T04:00:00.000Z", + "starts_at": "2026-01-01T20:00:00.000Z", + "scoring_start_time": "2026-01-01T15:00:00.000-05:00", + "scoring_end_time": "2026-06-01T00:00:00.000-05:00", + "created_at": "2026-01-01T18:00:00.000Z", + "closed_at": None, + "voided_at": None, + "answers": [ + { + "id": 9001, + "name": "Yes", + "probability": 0.65, + "display_probability": "65%", + "predictions_count": 50, + "answer_name": "Yes", + }, + { + "id": 9002, + "name": "No", + "probability": 0.35, + "display_probability": "35%", + "predictions_count": 50, + "answer_name": "No", + }, + ], + } + base.update(overrides) + return base + + +def make_infer_prediction_set(created_at, yes_prob): + """Build a realistic INFER prediction set dict.""" + return { + "id": 999999, + "type": "Forecast::OpinionPoolPredictionSet", + "question_id": 9999, + "created_at": created_at, + "predictions": [ + { + "answer_name": "Yes", + "final_probability": yes_prob, + "forecasted_probability": yes_prob, + "starting_probability": yes_prob, + }, + { + "answer_name": "No", + "final_probability": round(1 - yes_prob, 4), + "forecasted_probability": round(1 - yes_prob, 4), + "starting_probability": round(1 - yes_prob, 4), + }, + ], + } + + +def make_infer_fetch_df(rows): + """Build a DataFrame matching InferFetchFrame schema.""" + defaults = { + "question": "N/A", + "background": "N/A", + "url": "N/A", + "resolved": False, + "forecast_horizons": "N/A", + "freeze_datetime_value": "N/A", + "freeze_datetime_value_explanation": "N/A", + "market_info_resolution_criteria": "N/A", + "market_info_open_datetime": "N/A", + "market_info_close_datetime": "N/A", + "market_info_resolution_datetime": "N/A", + "fetch_datetime": "2026-01-15T00:00:00+00:00", + "probability": 0.5, + "nullify_question": False, + } + df = pd.DataFrame(rows) + for col, 
default in defaults.items(): + if col not in df.columns: + df[col] = default + return df diff --git a/src/tests/test_base_source.py b/src/tests/test_base_source.py index dbbe46e4..edafc05c 100644 --- a/src/tests/test_base_source.py +++ b/src/tests/test_base_source.py @@ -19,7 +19,6 @@ class _StubSource(BaseSource): """Minimal concrete subclass for testing BaseSource.""" name = "stub" - display_name = "Stub" source_type = SourceType.DATASET def _resolve(self, df, dfq, dfr): @@ -27,12 +26,17 @@ def _resolve(self, df, dfq, dfr): df["resolved"] = True return df, [] + def fetch(self, **kwargs): + raise NotImplementedError + + def update(self, dfq, dff, **kwargs): + raise NotImplementedError + class _StubSourceWithNullified(BaseSource): """Concrete subclass with nullified questions.""" name = "stub_null" - display_name = "StubNull" source_type = SourceType.DATASET nullified_questions = [ NullifiedQuestion(id="null_q1", nullification_start_date=date(2024, 6, 1)), @@ -44,6 +48,12 @@ def _resolve(self, df, dfq, dfr): df["resolved"] = True return df, [] + def fetch(self, **kwargs): + raise NotImplementedError + + def update(self, dfq, dff, **kwargs): + raise NotImplementedError + # --------------------------------------------------------------------------- # __init_subclass__ @@ -57,17 +67,6 @@ def test_missing_name_raises(self): with pytest.raises(TypeError, match="must define ClassVar 'name'"): class _BadSource(BaseSource): - display_name = "Bad" - source_type = SourceType.DATASET - - def _resolve(self, df, dfq, dfr): - return df - - def test_missing_display_name_raises(self): - with pytest.raises(TypeError, match="must define ClassVar 'display_name'"): - - class _BadSource(BaseSource): - name = "bad" source_type = SourceType.DATASET def _resolve(self, df, dfq, dfr): @@ -77,7 +76,6 @@ def test_valid_concrete_source_ok(self): # Should not raise class _GoodSource(BaseSource): name = "good" - display_name = "Good" source_type = SourceType.MARKET def _resolve(self, df, dfq, 
dfr): diff --git a/src/tests/test_infer.py b/src/tests/test_infer.py new file mode 100644 index 00000000..e382c146 --- /dev/null +++ b/src/tests/test_infer.py @@ -0,0 +1,556 @@ +"""Tests for InferSource fetch/update logic.""" + +from datetime import date +from unittest.mock import Mock, patch + +import numpy as np +import pandas as pd +import pytest + +from _schemas import InferFetchFrame, QuestionFrame, ResolutionFrame +from sources.infer import InferSource + +from .conftest import ( + make_infer_api_question, + make_infer_fetch_df, + make_infer_prediction_set, + make_question_df, + make_resolution_df, +) + +# --------------------------------------------------------------------------- +# _transform_question (pure, no mocking) +# --------------------------------------------------------------------------- + + +class TestTransformQuestion: + """Tests for InferSource._transform_question static method.""" + + CURRENT_TIME = "2026-01-15T00:00:00+00:00" + + def test_standard_active_question(self): + """All fields populated, output matches InferFetchFrame schema.""" + q = make_infer_api_question() + row = InferSource._transform_question(q, self.CURRENT_TIME) + + assert row["id"] == "9999" + assert row["question"] == q["name"] + assert row["probability"] == 0.65 + assert row["nullify_question"] is False + assert row["resolved"] is False + assert row["market_info_resolution_datetime"] == "N/A" + assert row["fetch_datetime"] == self.CURRENT_TIME + # Verify schema compliance + df = pd.DataFrame([row]) + InferFetchFrame.validate(df) + + def test_resolved_question(self): + """Resolved question has resolution datetime set.""" + q = make_infer_api_question( + **{ + "resolved?": True, + "resolved_at": "2026-01-10T12:00:00.000Z", + "scoring_end_time": "2026-02-01T00:00:00.000-05:00", + } + ) + row = InferSource._transform_question(q, self.CURRENT_TIME) + + assert bool(row["resolved"]) is True + assert row["market_info_resolution_datetime"] != "N/A" + assert "2026-01-10" in 
row["market_info_resolution_datetime"] + + def test_non_binary_question_nullified(self): + """Non-YesNo question types get nullified.""" + q = make_infer_api_question(type="Forecast::MultipleChoiceQuestion") + row = InferSource._transform_question(q, self.CURRENT_TIME) + + assert row["nullify_question"] is True + assert row["probability"] == "N/A" + + def test_missing_datetime_fields(self): + """None datetimes produce N/A strings.""" + q = make_infer_api_question( + scoring_start_time=None, + scoring_end_time=None, + ends_at=None, + resolved_at=None, + ) + row = InferSource._transform_question(q, self.CURRENT_TIME) + + assert row["market_info_open_datetime"] == "N/A" + assert row["market_info_close_datetime"] == "N/A" + + def test_close_datetime_picks_earlier(self): + """Close datetime is min(scoring_end_time, ends_at).""" + q = make_infer_api_question( + scoring_end_time="2026-03-01T00:00:00.000-05:00", + ends_at="2026-06-01T04:00:00.000Z", + ) + row = InferSource._transform_question(q, self.CURRENT_TIME) + assert "2026-03" in row["market_info_close_datetime"] + + # Reverse: ends_at is earlier + q2 = make_infer_api_question( + scoring_end_time="2026-09-01T00:00:00.000-05:00", + ends_at="2026-06-01T04:00:00.000Z", + ) + row2 = InferSource._transform_question(q2, self.CURRENT_TIME) + assert "2026-06" in row2["market_info_close_datetime"] + + def test_resolution_datetime_picks_earlier(self): + """Resolution datetime is min(resolved_at, close_datetime).""" + q = make_infer_api_question( + **{ + "resolved?": True, + "resolved_at": "2026-02-01T00:00:00.000Z", + "scoring_end_time": "2026-06-01T00:00:00.000-05:00", + } + ) + row = InferSource._transform_question(q, self.CURRENT_TIME) + assert "2026-02-01" in row["market_info_resolution_datetime"] + + def test_answers_swapped_order(self): + """Extracts Yes probability even when No is first.""" + q = make_infer_api_question( + answers=[ + {"name": "No", "probability": 0.3, "predictions_count": 10}, + {"name": "Yes",
"probability": 0.7, "predictions_count": 10}, + ] + ) + row = InferSource._transform_question(q, self.CURRENT_TIME) + assert row["probability"] == 0.7 + + def test_single_answer(self): + """Single-answer question fails the binary check, so probability is N/A.""" + q = make_infer_api_question( + answers=[{"name": "Yes", "probability": 0.8, "predictions_count": 5}] + ) + # Single answer → len != 2, so probability is N/A (binary check fails) + row = InferSource._transform_question(q, self.CURRENT_TIME) + assert row["probability"] == "N/A" + + def test_clarifications_joined(self): + """Multiple clarifications are joined into one string.""" + q = make_infer_api_question( + clarifications=[ + {"content": "Clarification 1."}, + {"content": "Clarification 2."}, + ] + ) + row = InferSource._transform_question(q, self.CURRENT_TIME) + assert "Clarification 1." in row["market_info_resolution_criteria"] + assert "Clarification 2." in row["market_info_resolution_criteria"] + + +# --------------------------------------------------------------------------- +# _finalize_resolution_df (pure, no mocking) +# --------------------------------------------------------------------------- + + +class TestFinalizeResolutionDf: + """Tests for InferSource._finalize_resolution_df static method.""" + + def test_filters_before_benchmark_start(self): + """Rows before BENCHMARK_START_DATE are dropped.""" + df = pd.DataFrame( + { + "id": ["A", "A", "A"], + "date": pd.to_datetime(["2020-01-01", "2024-06-01", "2024-07-01"]), + "value": [0.1, 0.2, 0.3], + } + ) + result = InferSource._finalize_resolution_df(df) + assert len(result) == 2 + assert result["value"].tolist() == [0.2, 0.3] + + def test_validates_schema(self): + """Output is a valid ResolutionFrame.""" + df = pd.DataFrame( + { + "id": ["A"], + "date": pd.to_datetime(["2024-06-01"]), + "value": [0.5], + } + ) + result = InferSource._finalize_resolution_df(df) + ResolutionFrame.validate(result) + + def test_only_keeps_id_date_value(self): + """Extra columns are
stripped.""" + df = pd.DataFrame( + { + "id": ["A"], + "date": pd.to_datetime(["2024-06-01"]), + "value": [0.5], + "extra": ["junk"], + } + ) + result = InferSource._finalize_resolution_df(df) + assert list(result.columns) == ["id", "date", "value"] + + +# --------------------------------------------------------------------------- +# _build_resolution_file (mock _get_historical_forecasts) +# --------------------------------------------------------------------------- + + +class TestBuildResolutionFile: + """Tests for InferSource._build_resolution_file.""" + + def _question(self, **overrides): + base = { + "id": "200", + "nullify_question": False, + "market_info_resolution_datetime": "N/A", + "probability": 0.6, + } + base.update(overrides) + return base + + @patch.object(InferSource, "_get_historical_forecasts") + def test_nullified_no_existing(self, mock_hist, infer_source, freeze_today): + """Nullified question with no existing data returns single NaN row.""" + freeze_today(date(2026, 1, 15)) + q = self._question(nullify_question=True) + df = infer_source._build_resolution_file(q, resolved=False, existing_df=None) + + assert len(df) == 1 + assert np.isnan(df["value"].iloc[0]) + mock_hist.assert_not_called() + + @patch.object(InferSource, "_get_historical_forecasts") + def test_nullified_with_existing(self, mock_hist, infer_source, freeze_today): + """Nullified question with existing data sets all values to NaN.""" + freeze_today(date(2026, 1, 15)) + existing = make_resolution_df( + [ + {"id": "200", "date": "2024-06-01", "value": 0.5}, + {"id": "200", "date": "2024-06-02", "value": 0.6}, + ] + ) + q = self._question(nullify_question=True) + df = infer_source._build_resolution_file(q, resolved=False, existing_df=existing) + + assert df["value"].isna().all() + mock_hist.assert_not_called() + + @patch.object(InferSource, "_get_historical_forecasts") + def test_already_up_to_date(self, mock_hist, infer_source, freeze_today): + """Skips API call if existing data covers 
through yesterday.""" + freeze_today(date(2026, 1, 15)) + existing = make_resolution_df( + [ + {"id": "200", "date": "2024-06-01", "value": 0.5}, + {"id": "200", "date": "2026-01-14", "value": 0.6}, + ] + ) + q = self._question() + df = infer_source._build_resolution_file(q, resolved=False, existing_df=existing) + + assert df.equals(existing) + mock_hist.assert_not_called() + + @patch.object(InferSource, "_get_historical_forecasts") + def test_fetches_when_stale(self, mock_hist, infer_source, freeze_today): + """Calls _get_historical_forecasts when existing data is stale.""" + freeze_today(date(2026, 1, 15)) + mock_hist.return_value = make_resolution_df( + [ + {"id": "200", "date": "2024-06-01", "value": 0.5}, + {"id": "200", "date": "2026-01-14", "value": 0.65}, + ] + ) + existing = make_resolution_df([{"id": "200", "date": "2024-06-01", "value": 0.5}]) + q = self._question() + df = infer_source._build_resolution_file(q, resolved=False, existing_df=existing) + + assert not df.empty + mock_hist.assert_called_once() + + @patch.object(InferSource, "_get_historical_forecasts") + def test_resolved_truncates_and_appends(self, mock_hist, infer_source, freeze_today): + """Resolved question truncates at resolution date and appends final row.""" + freeze_today(date(2026, 1, 15)) + mock_hist.return_value = make_resolution_df( + [ + {"id": "200", "date": "2024-06-01", "value": 0.4}, + {"id": "200", "date": "2026-01-10", "value": 0.6}, + {"id": "200", "date": "2026-01-12", "value": 0.7}, + ] + ) + q = self._question( + market_info_resolution_datetime="2026-01-11T00:00:00+00:00", + probability=1.0, + ) + df = infer_source._build_resolution_file(q, resolved=True, existing_df=None) + + # Should have rows up to resolution date + assert not df.empty + # Last row should be the resolution value + assert float(df.iloc[-1]["value"]) == 1.0 + + +# --------------------------------------------------------------------------- +# fetch() (mock _fetch_questions_from_api) +# 
--------------------------------------------------------------------------- + + +class TestFetch: + """Tests for InferSource.fetch.""" + + @patch.object(InferSource, "_fetch_questions_from_api") + def test_basic_fetch(self, mock_api, infer_source): + """Returns InferFetchFrame with correct rows.""" + mock_api.return_value = [ + make_infer_api_question(id=200), + make_infer_api_question(id=201), + ] + dff = infer_source.fetch() + + assert len(dff) == 2 + InferFetchFrame.validate(dff) + + @patch.object(InferSource, "_fetch_questions_from_api") + def test_active_filter(self, mock_api, infer_source): + """Only active binary questions with predictions pass the filter.""" + mock_api.return_value = [ + make_infer_api_question(id=1, state="active"), + make_infer_api_question(id=2, state="closed"), # filtered out + make_infer_api_question(id=3, type="Forecast::MultipleChoiceQuestion"), # filtered out + make_infer_api_question( + id=4, + answers=[ + {"name": "Yes", "probability": 0.5, "predictions_count": 0}, + {"name": "No", "probability": 0.5, "predictions_count": 0}, + ], + ), # filtered out (no predictions) + ] + dff = infer_source.fetch() + assert len(dff) == 1 + assert dff.iloc[0]["id"] == "1" + + @patch.object(InferSource, "_fetch_questions_from_api") + def test_deduplication_active_wins(self, mock_api, infer_source): + """When same ID appears in both active and existing, active version wins.""" + mock_api.side_effect = [ + [make_infer_api_question(id=100, state="closed")], # existing re-fetch + [make_infer_api_question(id=100, state="active")], # active fetch + ] + dfq = make_question_df([{"id": "100", "resolved": False}]) + dff = infer_source.fetch(dfq=dfq, files_in_storage=[]) + + assert len(dff) == 1 + + @patch.object(InferSource, "_fetch_questions_from_api") + def test_resolved_without_files_refetched(self, mock_api, infer_source): + """Resolved questions missing resolution files are re-fetched.""" + mock_api.side_effect = [ + [make_infer_api_question(id=100, 
state="resolved", **{"resolved?": True})], + [], # no active + ] + dfq = make_question_df([{"id": "100", "resolved": True}]) + # No resolution file in storage → should re-fetch + dff = infer_source.fetch(dfq=dfq, files_in_storage=[]) + + assert len(dff) == 1 + mock_api.assert_any_call(status="all", question_ids=["100"]) + + @patch.object(InferSource, "_fetch_questions_from_api") + def test_empty_dfq(self, mock_api, infer_source): + """Works with no existing questions.""" + mock_api.side_effect = [ + [make_infer_api_question(id=300)], + ] + dff = infer_source.fetch(dfq=None, files_in_storage=[]) + assert len(dff) == 1 + + def test_api_key_required(self): + """Raises RuntimeError if api_key not set.""" + src = InferSource() # no api_key + with pytest.raises(RuntimeError, match="api_key must be set"): + src.fetch() + + +# --------------------------------------------------------------------------- +# update() (mock _build_resolution_file) +# --------------------------------------------------------------------------- + + +class TestUpdate: + """Tests for InferSource.update.""" + + @patch.object(InferSource, "_build_resolution_file") + def test_basic_update(self, mock_build, infer_source): + """Returns UpdateResult with valid dfq and resolution files.""" + mock_build.return_value = make_resolution_df( + [{"id": "200", "date": "2024-06-01", "value": 0.65}] + ) + dfq = make_question_df([{"id": "100"}]) + dff = make_infer_fetch_df([{"id": "200"}]) + + result = infer_source.update(dfq, dff) + + assert "200" in result.dfq["id"].values + assert "200" in result.resolution_files + QuestionFrame.validate(result.dfq) + + @patch.object(InferSource, "_build_resolution_file") + def test_new_question_inserted(self, mock_build, infer_source): + """Question not in dfq gets appended.""" + mock_build.return_value = make_resolution_df( + [{"id": "300", "date": "2024-06-01", "value": 0.5}] + ) + dfq = make_question_df([{"id": "100"}]) + dff = make_infer_fetch_df([{"id": "300"}]) + + result 
= infer_source.update(dfq, dff) + assert len(result.dfq) == 2 + assert set(result.dfq["id"].tolist()) == {"100", "300"} + + @patch.object(InferSource, "_build_resolution_file") + def test_existing_question_updated(self, mock_build, infer_source): + """Existing question fields are updated in place.""" + mock_build.return_value = make_resolution_df( + [{"id": "100", "date": "2024-06-01", "value": 0.5}] + ) + dfq = make_question_df([{"id": "100", "question": "Old text"}]) + dff = make_infer_fetch_df([{"id": "100", "question": "New text"}]) + + result = infer_source.update(dfq, dff) + assert len(result.dfq) == 1 + assert result.dfq.iloc[0]["question"] == "New text" + + @patch.object(InferSource, "_build_resolution_file") + def test_nullified_marked_resolved(self, mock_build, infer_source): + """Nullified questions are marked as resolved in dfq.""" + mock_build.return_value = make_resolution_df( + [{"id": "200", "date": "2024-06-01", "value": np.nan}] + ) + dfq = make_question_df([{"id": "100"}]) + dff = make_infer_fetch_df([{"id": "200", "nullify_question": True}]) + + result = infer_source.update(dfq, dff) + row = result.dfq[result.dfq["id"] == "200"].iloc[0] + assert bool(row["resolved"]) is True + + @patch.object(InferSource, "_build_resolution_file") + def test_transient_fields_stripped(self, mock_build, infer_source): + """fetch_datetime, probability, nullify_question not in output dfq.""" + mock_build.return_value = make_resolution_df( + [{"id": "200", "date": "2024-06-01", "value": 0.5}] + ) + dfq = make_question_df([{"id": "placeholder"}]).iloc[:0] + dff = make_infer_fetch_df([{"id": "200"}]) + + result = infer_source.update(dfq, dff) + for col in ["fetch_datetime", "probability", "nullify_question"]: + assert col not in result.dfq.columns + + def test_api_key_required(self): + """Raises RuntimeError if api_key not set.""" + src = InferSource() + dfq = make_question_df([{"id": "100"}]) + dff = make_infer_fetch_df([{"id": "200"}]) + with 
pytest.raises(RuntimeError, match="api_key must be set"): + src.update(dfq, dff) + + +# --------------------------------------------------------------------------- +# _get_historical_forecasts (mock requests.get) +# --------------------------------------------------------------------------- + + +class TestGetHistoricalForecasts: + """Tests for InferSource._get_historical_forecasts.""" + + def _mock_response(self, prediction_sets): + resp = Mock() + resp.ok = True + resp.json.return_value = {"prediction_sets": prediction_sets} + resp.raise_for_status = Mock() + return resp + + @patch("sources.infer.requests.get") + def test_basic_fetch_no_existing(self, mock_get, infer_source, freeze_today): + """Builds time series from scratch.""" + freeze_today(date(2026, 1, 15)) + mock_get.side_effect = [ + self._mock_response( + [ + make_infer_prediction_set("2026-01-10T12:00:00.000Z", 0.4), + make_infer_prediction_set("2026-01-12T14:00:00.000Z", 0.6), + ] + ), + self._mock_response([]), # empty page stops pagination + ] + + df = infer_source._get_historical_forecasts(None, "200") + + assert not df.empty + assert list(df.columns) == ["id", "date", "value"] + assert (df["id"] == "200").all() + # Should have forward-filled dates between 10th and 14th + assert len(df) >= 4 + + @patch("sources.infer.requests.get") + def test_incremental_with_existing(self, mock_get, infer_source, freeze_today): + """Only fetches newer predictions when existing data provided.""" + freeze_today(date(2026, 1, 15)) + existing = make_resolution_df( + [ + {"id": "200", "date": "2026-01-10", "value": 0.4}, + {"id": "200", "date": "2026-01-11", "value": 0.4}, + ] + ) + mock_get.side_effect = [ + self._mock_response([make_infer_prediction_set("2026-01-13T12:00:00.000Z", 0.7)]), + self._mock_response([]), + ] + + df = infer_source._get_historical_forecasts(existing, "200") + + assert not df.empty + # Should contain both old and new data, forward-filled + assert len(df) >= 4 + + 
@patch("sources.infer.requests.get") + def test_forward_fill_gaps(self, mock_get, infer_source, freeze_today): + """Missing dates between predictions are forward-filled.""" + freeze_today(date(2026, 1, 15)) + mock_get.side_effect = [ + self._mock_response( + [ + make_infer_prediction_set("2026-01-10T12:00:00.000Z", 0.3), + make_infer_prediction_set("2026-01-13T12:00:00.000Z", 0.8), + ] + ), + self._mock_response([]), + ] + + df = infer_source._get_historical_forecasts(None, "200") + + # Dates 10, 11, 12, 13, 14 should exist (14 = today-1) + dates_in_df = pd.to_datetime(df["date"]).dt.date.tolist() + assert date(2026, 1, 11) in dates_in_df # forward-filled + assert date(2026, 1, 12) in dates_in_df # forward-filled + + @patch("sources.infer.requests.get") + @patch("sources.infer.time.sleep") + def test_rate_limit_retry(self, mock_sleep, mock_get, infer_source, freeze_today): + """429 response triggers retry after sleep.""" + freeze_today(date(2026, 1, 15)) + + rate_limit_resp = Mock() + rate_limit_resp.raise_for_status.side_effect = __import__("requests").exceptions.HTTPError( + response=Mock(status_code=429) + ) + + ok_resp = self._mock_response([make_infer_prediction_set("2026-01-10T12:00:00.000Z", 0.5)]) + empty_resp = self._mock_response([]) + + mock_get.side_effect = [rate_limit_resp, ok_resp, empty_resp] + + df = infer_source._get_historical_forecasts(None, "200") + + assert not df.empty + mock_sleep.assert_called_once_with(10) diff --git a/src/tests/test_integration.py b/src/tests/test_integration.py index e6dd2278..fd48a858 100644 --- a/src/tests/test_integration.py +++ b/src/tests/test_integration.py @@ -10,7 +10,7 @@ from resolve._prepare import check_and_prepare_forecast_file, set_resolution_dates from resolve.explode_question_set import explode_question_set from resolve.resolve_all import resolve_all -from sources import SOURCES +from sources.registry import SOURCES from tests.conftest import make_question_df, make_question_set_df, make_resolution_df # 
--------------------------------------------------------------------------- diff --git a/src/tests/test_types_and_schemas.py b/src/tests/test_types_and_schemas.py index 56007125..44ac6178 100644 --- a/src/tests/test_types_and_schemas.py +++ b/src/tests/test_types_and_schemas.py @@ -6,7 +6,7 @@ import pytest from _fb_types import NullifiedQuestion, SourceQuestionBank, SourceType -from sources import SOURCES +from sources.registry import SOURCES # --------------------------------------------------------------------------- # _fb_types.py @@ -61,15 +61,15 @@ def test_with_hash_mapping(self): _EXPECTED_SOURCES = { - "acled": ("ACLED", SourceType.DATASET), - "dbnomics": ("DBnomics", SourceType.DATASET), - "fred": ("FRED", SourceType.DATASET), - "infer": ("INFER", SourceType.MARKET), - "manifold": ("Manifold", SourceType.MARKET), - "metaculus": ("Metaculus", SourceType.MARKET), - "polymarket": ("Polymarket", SourceType.MARKET), - "wikipedia": ("Wikipedia", SourceType.DATASET), - "yfinance": ("Yahoo Finance", SourceType.DATASET), + "acled": SourceType.DATASET, + "dbnomics": SourceType.DATASET, + "fred": SourceType.DATASET, + "infer": SourceType.MARKET, + "manifold": SourceType.MARKET, + "metaculus": SourceType.MARKET, + "polymarket": SourceType.MARKET, + "wikipedia": SourceType.DATASET, + "yfinance": SourceType.DATASET, } @@ -79,10 +79,8 @@ class TestConcreteSourceClassVars: @pytest.mark.parametrize("name", sorted(_EXPECTED_SOURCES.keys())) def test_source_name_and_type(self, name): source = SOURCES[name] - expected_display, expected_type = _EXPECTED_SOURCES[name] assert source.name == name - assert source.display_name == expected_display - assert source.source_type == expected_type + assert source.source_type == _EXPECTED_SOURCES[name] def test_all_sources_registered(self): assert set(SOURCES.keys()) == set(_EXPECTED_SOURCES.keys()) diff --git a/src/tests/test_yfinance.py b/src/tests/test_yfinance.py index ca665196..77789471 100644 --- a/src/tests/test_yfinance.py +++ 
b/src/tests/test_yfinance.py @@ -12,9 +12,13 @@ finalize_resolution_file, update_questions, ) -from sources.yfinance import DELISTED_STOCKS, TICKER_RENAMES, YfinanceSource +from sources._metadata import SOURCE_METADATA +from sources.yfinance import YfinanceSource from tests.conftest import make_forecast_df, make_question_df +DELISTED_STOCKS = SOURCE_METADATA["yfinance"]["nullified_questions"] +TICKER_RENAMES = SOURCE_METADATA["yfinance"]["ticker_renames"] + class TestTickerRenamesDefinition: """Test the TICKER_RENAMES list is correctly defined."""