diff --git a/.env.example b/.env.example
index 86dd02e..4099c91 100644
--- a/.env.example
+++ b/.env.example
@@ -1,12 +1,10 @@
 NCBI_API_KEY=your_key_here
 HF_TOKEN=your_key_here
-OLLAMA_SERVER_IP=192.168.x.x
-OLLAMA_SERVER_PORT=11434
-OLLAMA_MODEL=llama3.2
-OLLAMA_MAX_TOKENS=1024
-OLLAMA_TEMPERATURE=0.1
+GEMINI_API_KEY=your_key_here
 PINECONE_API_KEY=your_key_here
 PINECONE_INDEX_NAME=your_index_name
+REDIS_HOST=localhost
+REDIS_PORT=6379
 MLFLOW_TRACKING_URI=http://localhost:5000
 MLFLOW_ARTIFACT_LOCATION=logs/mlflow
 FASTAPI_HOST=0.0.0.0
diff --git a/requirements.txt b/requirements.txt
index 85dd180..d7aa018 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,12 +23,13 @@ sentence-transformers>=2.2.0
 
 # Agent framework
 langchain>=0.2.0
-langchain-openai>=0.1.0
+langchain-core>=0.3.83
+langchain-google-genai>=1.0.0
 langchain-community>=0.2.0
 langgraph>=0.1.0
 
 # Redis caching
 redis>=5.0.0
 
-# OpenAI LLM
-openai>=1.0.0
\ No newline at end of file
+# General
+numpy>=1.26.0
\ No newline at end of file
diff --git a/src/agent/__init__.py b/src/agent/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/agent/graph.py b/src/agent/graph.py
new file mode 100644
index 0000000..54febf8
--- /dev/null
+++ b/src/agent/graph.py
@@ -0,0 +1,36 @@
+from langgraph.graph import StateGraph
+from src.agent.nodes import check_cache, pubmed_retrieval, llm_generation, parse_claims, nli_scoring, \
+    confidence_scoring, assembly, route_after_cache
+from src.agent.state import AgentState
+
+graph = StateGraph(AgentState)
+
+# Add nodes
+graph.add_node("check_cache", check_cache)
+graph.add_node("pubmed_retrieval", pubmed_retrieval)
+graph.add_node("llm_generation", llm_generation)
+graph.add_node("parse_claims", parse_claims)
+graph.add_node("nli_scoring", nli_scoring)
+graph.add_node("confidence_scoring", confidence_scoring)
+graph.add_node("assembly", assembly)
+
+# Conditional edge
+graph.add_conditional_edges(
+    "check_cache",
+    route_after_cache,
+    {"pubmed_retrieval": "pubmed_retrieval", "llm_generation": "llm_generation"}
+)
+
+# Add edges
+graph.add_edge("pubmed_retrieval", "llm_generation")
+graph.add_edge("llm_generation", "parse_claims")
+graph.add_edge("parse_claims", "nli_scoring")
+graph.add_edge("nli_scoring", "confidence_scoring")
+graph.add_edge("confidence_scoring", "assembly")
+
+# Set entry and finish
+graph.set_entry_point("check_cache")
+graph.set_finish_point("assembly")
+
+# Compile
+app = graph.compile()
\ No newline at end of file
diff --git a/src/agent/nodes.py b/src/agent/nodes.py
new file mode 100644
index 0000000..4a8ddb6
--- /dev/null
+++ b/src/agent/nodes.py
@@ -0,0 +1,102 @@
+import numpy as np
+from src.agent.state import AgentState
+from src.retrieval.cache import get_cache, set_cache
+from src.retrieval.vector_store import add_abstracts, query_abstracts
+from src.retrieval.pubmed import search_pubmed
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_core.output_parsers import JsonOutputParser
+from sentence_transformers import CrossEncoder
+from src.core.config import settings
+
+_llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite", google_api_key=settings.GEMINI_API_KEY)
+_nli_model = CrossEncoder("cross-encoder/nli-MiniLM2-L6-H768")
+
+def check_cache(state: AgentState):
+    cached_result = get_cache(state["query"])
+    if cached_result:
+        return {"cache_hit": True, "abstracts": cached_result}
+    else:
+        return {"cache_hit": False}
+
+def route_after_cache(state: AgentState) -> str:
+    if state["cache_hit"]:
+        return "llm_generation"
+    return "pubmed_retrieval"
+
+def pubmed_retrieval(state: AgentState):
+    if not state["cache_hit"]:  # routing guarantees a miss here; guard kept for direct callers
+        results = search_pubmed(state["query"])
+        add_abstracts(results)
+    abstracts = query_abstracts(state["query"])
+    set_cache(state["query"], abstracts)
+    return {"abstracts": abstracts}
+
+def llm_generation(state: AgentState):
+    context = "\n\n".join([f"Title: {a['title']}\nAbstract: {a['abstract']}"
+                           for a in state["abstracts"]])
+
+    prompt = f"""You are a clinical assistant in charge of extracting insights from medical literature. Use the following documentation to answer the query.
+
+    Literature:
+    {context}
+
+    Query: {state["query"]}
+
+    Provide a detailed clinical response based solely on the provided literature."""
+
+    response = _llm.invoke(prompt)
+    return {"llm_response": response.content}
+
+def parse_claims(state: AgentState):
+    parser = JsonOutputParser()
+
+    prompt = f"""Extract all discrete factual claims from the following clinical response.
+    Return ONLY a JSON array of strings, no other text.
+    Each claim should be a single verifiable factual statement.
+
+    Response:
+    {state["llm_response"]}
+
+    Return format: ["claim 1", "claim 2", "claim 3"]"""
+
+    response = _llm.invoke(prompt)
+    claims = parser.parse(response.content)
+    return {"claims": claims}
+
+def nli_scoring(state: AgentState):
+    scored_claims = []
+    labels = ["Contradicted", "Supported", "Unverifiable"]
+    for claim in state["claims"]:
+        best_score = float("-inf")  # NLI outputs are logits and may be < -1
+        best_result = {"claim": claim, "label": "Unverifiable", "score": 0.0, "evidence": None}  # fallback when no abstracts
+
+        for abstract in state["abstracts"]:
+            scores = _nli_model.predict([(abstract["abstract"], claim)])[0]
+            label_idx = int(np.argmax(scores))
+            if scores[label_idx] > best_score:
+                best_score = scores[label_idx]
+                best_result = {
+                    "claim": claim,
+                    "label": labels[label_idx],
+                    "score": float(best_score),
+                    "evidence": abstract["abstract"]
+                }
+
+        scored_claims.append(best_result)
+
+    return {"scored_claims": scored_claims}
+
+def confidence_scoring(state: AgentState):
+    weights = {"Supported": 1.0, "Unverifiable": 0.5, "Contradicted": 0.0}
+    score = float(np.mean([weights[c["label"]] for c in state["scored_claims"]])) if state["scored_claims"] else 0.0  # avoid nan on empty; float() keeps it JSON-safe
+    return {"confidence_score": score}
+
+def assembly(state: AgentState):
+    return {"final_response": {
+        "query": state["query"],
+        "response": state["llm_response"],
+        "confidence_score": state["confidence_score"],
+        "scored_claims": state["scored_claims"],
+        "abstracts": state["abstracts"]
+        }
+    }
diff --git a/src/agent/state.py b/src/agent/state.py
new file mode 100644
index 0000000..993084d
--- /dev/null
+++ b/src/agent/state.py
@@ -0,0 +1,11 @@
+from typing import TypedDict, Optional
+
+class AgentState(TypedDict):
+    query: str
+    cache_hit: bool
+    abstracts: list[dict]
+    llm_response: Optional[str]
+    claims: Optional[list[str]]
+    scored_claims: Optional[list[dict]]
+    confidence_score: Optional[float]
+    final_response: Optional[dict]
\ No newline at end of file
diff --git a/src/core/config.py b/src/core/config.py
index 5c210a8..d6fc066 100644
--- a/src/core/config.py
+++ b/src/core/config.py
@@ -8,17 +8,17 @@ class Settings(BaseSettings):
     # Hugging Face
     HF_TOKEN: str
 
-    # Ollama
-    OLLAMA_SERVER_IP: str
-    OLLAMA_SERVER_PORT: int = 11434
-    OLLAMA_MODEL: str = "llama3.2"
-    OLLAMA_MAX_TOKENS: int = 1024
-    OLLAMA_TEMPERATURE: float = 0.1
+    # Gemini
+    GEMINI_API_KEY: str
 
     # Pinecone
     PINECONE_API_KEY: str
     PINECONE_INDEX_NAME: str
 
+    # Redis
+    REDIS_HOST: str
+    REDIS_PORT: int = 6379
+
     # MLflow
     MLFLOW_TRACKING_URI: str = "http://localhost:5000"
     MLFLOW_ARTIFACT_LOCATION: str = "logs/mlflow"
diff --git a/src/retrieval/__init__.py b/src/retrieval/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/retrieval/cache.py b/src/retrieval/cache.py
new file mode 100644
index 0000000..e86e85a
--- /dev/null
+++ b/src/retrieval/cache.py
@@ -0,0 +1,20 @@
+import redis
+import json
+from src.core.config import settings
+
+_client = None
+
+def get_client():
+    global _client
+    if _client is None:
+        _client = redis.Redis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, decode_responses=True)
+    return _client
+
+def get_cache(key: str):
+    value = get_client().get(key)
+    if value is None:
+        return None
+    return json.loads(value)
+
+def set_cache(key: str, value: list[dict]):
+    get_client().setex(key, 86400, json.dumps(value))  # 86400 = 24 hours
diff --git a/src/retrieval/vector_store.py b/src/retrieval/vector_store.py
index a2ab6f8..6d7b570 100644
--- a/src/retrieval/vector_store.py
+++ b/src/retrieval/vector_store.py
@@ -2,8 +2,6 @@
 from sentence_transformers import SentenceTransformer
 from src.core.config import settings
 import os
-from src.retrieval.pubmed import search_pubmed
-import sys
 
 os.environ["HUGGING_FACE_HUB_TOKEN"] = settings.HF_TOKEN
 model = SentenceTransformer("pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb")
@@ -46,13 +44,4 @@ def query_abstracts(query: str, n_results: int = 5) -> list[dict]:
         top_k=n_results,
         include_metadata=True
     )
-    return results
-
-def main():
-    search_results = search_pubmed("myocardial infarction", max_results=5)
-    add_abstracts(search_results)
-    results = query_abstracts("chest pain treatment")
-    print(results)
-
-if __name__ == "__main__":
-    sys.exit(main())
+    return [match["metadata"] for match in results["matches"]]
diff --git a/tests/pinecone_check.py b/tests/pinecone_check.py
new file mode 100644
index 0000000..38b981c
--- /dev/null
+++ b/tests/pinecone_check.py
@@ -0,0 +1,12 @@
+from src.retrieval.pubmed import search_pubmed
+from src.retrieval.vector_store import add_abstracts, query_abstracts
+import sys
+
+def main():
+    search_results = search_pubmed("pulmonary embolism", max_results=5)
+    add_abstracts(search_results)
+    results = query_abstracts("chest pain treatment")
+    print(results)
+
+if __name__ == "__main__":
+    sys.exit(main())
\ No newline at end of file
diff --git a/tests/pipeline_runner.py b/tests/pipeline_runner.py
new file mode 100644
index 0000000..8642a98
--- /dev/null
+++ b/tests/pipeline_runner.py
@@ -0,0 +1,19 @@
+from src.agent.graph import app
+import sys
+
+def main():
+    result = app.invoke({
+        "query": "What are the first line treatments for atrial fibrillation?",
+        "cache_hit": False,
+        "abstracts": [],
+        "llm_response": None,
+        "claims": None,
+        "scored_claims": None,
+        "confidence_score": None,
+        "final_response": None
+    })
+
+    print(result["final_response"])
+
+if __name__ == "__main__":
+    sys.exit(main())
\ No newline at end of file