From 2ae92e96f278260106abfee42caefdc19d21da12 Mon Sep 17 00:00:00 2001 From: AndrewVFranco <129307231+AndrewVFranco@users.noreply.github.com> Date: Wed, 8 Apr 2026 03:40:01 -0700 Subject: [PATCH 1/6] Add redis caching --- .env.example | 2 ++ src/core/config.py | 4 ++++ src/retrieval/cache.py | 20 ++++++++++++++++++++ 3 files changed, 26 insertions(+) create mode 100644 src/retrieval/cache.py diff --git a/.env.example b/.env.example index 86dd02e..b994bf1 100644 --- a/.env.example +++ b/.env.example @@ -7,6 +7,8 @@ OLLAMA_MAX_TOKENS=1024 OLLAMA_TEMPERATURE=0.1 PINECONE_API_KEY=your_key_here PINECONE_INDEX_NAME=your_index_name +REDIS_HOST=localhost +REDIS_PORT=6379 MLFLOW_TRACKING_URI=http://localhost:5000 MLFLOW_ARTIFACT_LOCATION=logs/mlflow FASTAPI_HOST=0.0.0.0 diff --git a/src/core/config.py b/src/core/config.py index 5c210a8..3fff649 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -19,6 +19,10 @@ class Settings(BaseSettings): PINECONE_API_KEY: str PINECONE_INDEX_NAME: str + # Redis + REDIS_HOST: str + REDIS_PORT: int = 6379 + # MLflow MLFLOW_TRACKING_URI: str = "http://localhost:5000" MLFLOW_ARTIFACT_LOCATION: str = "logs/mlflow" diff --git a/src/retrieval/cache.py b/src/retrieval/cache.py new file mode 100644 index 0000000..e86e85a --- /dev/null +++ b/src/retrieval/cache.py @@ -0,0 +1,20 @@ +import redis +import json +from src.core.config import settings + +_client = None + +def get_client(): + global _client + if _client is None: + _client = redis.Redis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, decode_responses=True) + return _client + +def get_cache(key: str): + value = get_client().get(key) + if value is None: + return None + return json.loads(value) + +def set_cache(key: str, value: list[dict]): + get_client().setex(key, 86400, json.dumps(value)) # 86400 = 24 hours From b35f0cf0f94f4e1e6b7a93dd36f44eff6f99cd90 Mon Sep 17 00:00:00 2001 From: AndrewVFranco <129307231+AndrewVFranco@users.noreply.github.com> Date: Wed, 8 Apr 2026 05:56:18 -0700
Subject: [PATCH 2/6] Update project to use Gemini API --- .env.example | 6 +----- requirements.txt | 7 ++++--- src/core/config.py | 8 ++------ 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/.env.example b/.env.example index b994bf1..4099c91 100644 --- a/.env.example +++ b/.env.example @@ -1,10 +1,6 @@ NCBI_API_KEY=your_key_here HF_TOKEN=your_key_here -OLLAMA_SERVER_IP=192.168.x.x -OLLAMA_SERVER_PORT=11434 -OLLAMA_MODEL=llama3.2 -OLLAMA_MAX_TOKENS=1024 -OLLAMA_TEMPERATURE=0.1 +GEMINI_API_KEY=your_key_here PINECONE_API_KEY=your_key_here PINECONE_INDEX_NAME=your_index_name REDIS_HOST=localhost diff --git a/requirements.txt b/requirements.txt index 85dd180..d7aa018 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,12 +23,13 @@ sentence-transformers>=2.2.0 # Agent framework langchain>=0.2.0 -langchain-openai>=0.1.0 +langchain-core>=0.3.83 +langchain-google-genai>=1.0.0 langchain-community>=0.2.0 langgraph>=0.1.0 # Redis caching redis>=5.0.0 -# OpenAI LLM -openai>=1.0.0 \ No newline at end of file +# General +numpy>=2.4.0 \ No newline at end of file diff --git a/src/core/config.py b/src/core/config.py index 3fff649..d6fc066 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -8,12 +8,8 @@ class Settings(BaseSettings): # Hugging Face HF_TOKEN: str - # Ollama - OLLAMA_SERVER_IP: str - OLLAMA_SERVER_PORT: int = 11434 - OLLAMA_MODEL: str = "llama3.2" - OLLAMA_MAX_TOKENS: int = 1024 - OLLAMA_TEMPERATURE: float = 0.1 + # Gemini + GEMINI_API_KEY: str # Pinecone PINECONE_API_KEY: str From 03d9f7c42dbc510096c86a7ea23cce67c1deb865 Mon Sep 17 00:00:00 2001 From: AndrewVFranco <129307231+AndrewVFranco@users.noreply.github.com> Date: Wed, 8 Apr 2026 05:58:26 -0700 Subject: [PATCH 3/6] Update query_abstracts to return only the metadata dictionary --- src/retrieval/__init__.py | 0 src/retrieval/vector_store.py | 13 +------------ 2 files changed, 1 insertion(+), 12 deletions(-) create mode 100644 src/retrieval/__init__.py diff --git 
a/src/retrieval/__init__.py b/src/retrieval/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/retrieval/vector_store.py b/src/retrieval/vector_store.py index a2ab6f8..6d7b570 100644 --- a/src/retrieval/vector_store.py +++ b/src/retrieval/vector_store.py @@ -2,8 +2,6 @@ from sentence_transformers import SentenceTransformer from src.core.config import settings import os -from src.retrieval.pubmed import search_pubmed -import sys os.environ["HUGGING_FACE_HUB_TOKEN"] = settings.HF_TOKEN model = SentenceTransformer("pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb") @@ -46,13 +44,4 @@ def query_abstracts(query: str, n_results: int = 5) -> list[dict]: top_k=n_results, include_metadata=True ) - return results - -def main(): - search_results = search_pubmed("myocardial infarction", max_results=5) - add_abstracts(search_results) - results = query_abstracts("chest pain treatment") - print(results) - -if __name__ == "__main__": - sys.exit(main()) + return [match["metadata"] for match in results["matches"]] From 039466f7a099abf7e065da4bdbb2d03a6d5e24d0 Mon Sep 17 00:00:00 2001 From: AndrewVFranco <129307231+AndrewVFranco@users.noreply.github.com> Date: Wed, 8 Apr 2026 06:01:21 -0700 Subject: [PATCH 4/6] Add agent framework including nodes, AgentState class, and graph node interaction --- src/agent/__init__.py | 0 src/agent/graph.py | 36 +++++++++++++++ src/agent/nodes.py | 102 ++++++++++++++++++++++++++++++++++++++++++ src/agent/state.py | 11 +++++ 4 files changed, 149 insertions(+) create mode 100644 src/agent/__init__.py create mode 100644 src/agent/graph.py create mode 100644 src/agent/nodes.py create mode 100644 src/agent/state.py diff --git a/src/agent/__init__.py b/src/agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/agent/graph.py b/src/agent/graph.py new file mode 100644 index 0000000..54febf8 --- /dev/null +++ b/src/agent/graph.py @@ -0,0 +1,36 @@ +from langgraph.graph import StateGraph +from src.agent.nodes 
import check_cache, pubmed_retrieval, llm_generation, parse_claims, nli_scoring, \ + confidence_scoring, assembly, route_after_cache +from src.agent.state import AgentState + +graph = StateGraph(AgentState) + +# Add nodes +graph.add_node("check_cache", check_cache) +graph.add_node("pubmed_retrieval", pubmed_retrieval) +graph.add_node("llm_generation", llm_generation) +graph.add_node("parse_claims", parse_claims) +graph.add_node("nli_scoring", nli_scoring) +graph.add_node("confidence_scoring", confidence_scoring) +graph.add_node("assembly", assembly) + +# Conditional edge +graph.add_conditional_edges( + "check_cache", + route_after_cache, + {"pubmed_retrieval": "pubmed_retrieval", "llm_generation": "llm_generation"} +) + +# Add edges +graph.add_edge("pubmed_retrieval", "llm_generation") +graph.add_edge("llm_generation", "parse_claims") +graph.add_edge("parse_claims", "nli_scoring") +graph.add_edge("nli_scoring", "confidence_scoring") +graph.add_edge("confidence_scoring", "assembly") + +# Set entry and finish +graph.set_entry_point("check_cache") +graph.set_finish_point("assembly") + +# Compile +app = graph.compile() \ No newline at end of file diff --git a/src/agent/nodes.py b/src/agent/nodes.py new file mode 100644 index 0000000..4a8ddb6 --- /dev/null +++ b/src/agent/nodes.py @@ -0,0 +1,102 @@ +import numpy as np +from src.agent.state import AgentState +from src.retrieval.cache import get_cache, set_cache +from src.retrieval.vector_store import add_abstracts, query_abstracts +from src.retrieval.pubmed import search_pubmed +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_core.output_parsers import JsonOutputParser +from sentence_transformers import CrossEncoder +from src.core.config import settings + +_llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite", google_api_key=settings.GEMINI_API_KEY) +_nli_model = CrossEncoder("cross-encoder/nli-MiniLM2-L6-H768") + +def check_cache(state: AgentState): + cached_result = 
get_cache(state["query"]) + if cached_result: + return {"cache_hit": True, "abstracts": cached_result} + else: + return {"cache_hit": False} + +def route_after_cache(state: AgentState) -> str: + if state["cache_hit"]: + return "llm_generation" + return "pubmed_retrieval" + +def pubmed_retrieval(state: AgentState): + if state["cache_hit"]: + return {"abstracts": state["abstracts"]} + add_abstracts(search_pubmed(state["query"])) + abstracts = query_abstracts(state["query"]) + set_cache(state["query"], abstracts) + return {"abstracts": abstracts} + +def llm_generation(state: AgentState): + context = "\n\n".join([f"Title: {a['title']}\nAbstract: {a['abstract']}" + for a in state["abstracts"]]) + + prompt = f"""You are a clinical assistant in charge of extracting insights from medical literature. Use the following documentation to answer the query. + + Literature: + {context} + + Query: {state["query"]} + + Provide a detailed clinical response based solely on the provided literature.""" + + response = _llm.invoke(prompt) + return {"llm_response": response.content} + +def parse_claims(state: AgentState): + parser = JsonOutputParser() + + prompt = f"""Extract all discrete factual claims from the following clinical response. + Return ONLY a JSON array of strings, no other text. + Each claim should be a single verifiable factual statement. 
+ + Response: + {state["llm_response"]} + + Return format: ["claim 1", "claim 2", "claim 3"]""" + + response = _llm.invoke(prompt) + claims = parser.parse(response.content) + return {"claims": claims} + +def nli_scoring(state: AgentState): + scored_claims = [] + labels = ["Contradicted", "Supported", "Unverifiable"] + for claim in state["claims"]: + best_score = float("-inf") + best_result = None + + for abstract in state["abstracts"]: + scores = _nli_model.predict([(abstract["abstract"], claim)])[0] + label_idx = int(np.argmax(scores)) + if scores[label_idx] > best_score: + best_score = scores[label_idx] + best_result = { + "claim": claim, + "label": labels[label_idx], + "score": float(best_score), + "evidence": abstract["abstract"] + } + + scored_claims.append(best_result) + + return {"scored_claims": scored_claims} + +def confidence_scoring(state: AgentState): + weights = {"Supported": 1.0, "Unverifiable": 0.5, "Contradicted": 0.0} + score = np.mean([weights[claim["label"]] for claim in state["scored_claims"]]) + return {"confidence_score": score} + +def assembly(state: AgentState): + return {"final_response": { + "query": state["query"], + "response": state["llm_response"], + "confidence_score": state["confidence_score"], + "scored_claims": state["scored_claims"], + "abstracts": state["abstracts"] + } + } diff --git a/src/agent/state.py b/src/agent/state.py new file mode 100644 index 0000000..993084d --- /dev/null +++ b/src/agent/state.py @@ -0,0 +1,11 @@ +from typing import TypedDict, Optional + +class AgentState(TypedDict): + query: str + cache_hit: bool + abstracts: list[dict] + llm_response: Optional[str] + claims: Optional[list[str]] + scored_claims: Optional[list[dict]] + confidence_score: Optional[float] + final_response: Optional[dict] \ No newline at end of file From 8eb89ccaf7b0d52937aea8118cfeb053c3838c92 Mon Sep 17 00:00:00 2001 From: AndrewVFranco <129307231+AndrewVFranco@users.noreply.github.com> Date: Wed, 8 Apr 2026 06:03:02 -0700 Subject: [PATCH 5/6] Add 
test for pinecone database functionality --- tests/pinecone_check.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 tests/pinecone_check.py diff --git a/tests/pinecone_check.py b/tests/pinecone_check.py new file mode 100644 index 0000000..38b981c --- /dev/null +++ b/tests/pinecone_check.py @@ -0,0 +1,12 @@ +from src.retrieval.pubmed import search_pubmed +from src.retrieval.vector_store import add_abstracts, query_abstracts +import sys + +def main(): + search_results = search_pubmed("pulmonary embolism", max_results=5) + add_abstracts(search_results) + results = query_abstracts("chest pain treatment") + print(results) + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file From 286809f4bdb0947281ec41f1958add6f8d95a036 Mon Sep 17 00:00:00 2001 From: AndrewVFranco <129307231+AndrewVFranco@users.noreply.github.com> Date: Wed, 8 Apr 2026 06:03:45 -0700 Subject: [PATCH 6/6] Add test for agent functionality --- tests/pipeline_runner.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 tests/pipeline_runner.py diff --git a/tests/pipeline_runner.py b/tests/pipeline_runner.py new file mode 100644 index 0000000..8642a98 --- /dev/null +++ b/tests/pipeline_runner.py @@ -0,0 +1,19 @@ +from src.agent.graph import app +import sys + +def main(): + result = app.invoke({ + "query": "What are the first line treatments for atrial fibrillation?", + "cache_hit": False, + "abstracts": [], + "llm_response": None, + "claims": None, + "scored_claims": None, + "confidence_score": None, + "final_response": None + }) + + print(result["final_response"]) + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file