Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
NCBI_API_KEY=your_key_here
HF_TOKEN=your_key_here
OLLAMA_SERVER_IP=192.168.x.x
OLLAMA_SERVER_PORT=11434
OLLAMA_MODEL=llama3.2
OLLAMA_MAX_TOKENS=1024
OLLAMA_TEMPERATURE=0.1
GEMINI_API_KEY=your_key_here
PINECONE_API_KEY=your_key_here
PINECONE_INDEX_NAME=your_index_name
REDIS_HOST=localhost
REDIS_PORT=6379
MLFLOW_TRACKING_URI=http://localhost:5000
MLFLOW_ARTIFACT_LOCATION=logs/mlflow
FASTAPI_HOST=0.0.0.0
Expand Down
7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ sentence-transformers>=2.2.0

# Agent framework
langchain>=0.2.0
langchain-openai>=0.1.0
langchain-core>=0.3.83
langchain-google-genai>=1.0.0
langchain-community>=0.2.0
langgraph>=0.1.0

# Redis caching
redis>=5.0.0

# OpenAI LLM
openai>=1.0.0
# General
numpy>=2.4.0
Empty file added src/agent/__init__.py
Empty file.
36 changes: 36 additions & 0 deletions src/agent/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from langgraph.graph import StateGraph
from src.agent.nodes import check_cache, pubmed_retrieval, llm_generation, parse_claims, nli_scoring, \
confidence_scoring, assembly, route_after_cache
from src.agent.state import AgentState

# Assemble the fact-checking agent as a LangGraph state machine.
graph = StateGraph(AgentState)

# Pipeline stages, registered as graph nodes (name -> callable).
_stages = {
    "check_cache": check_cache,
    "pubmed_retrieval": pubmed_retrieval,
    "llm_generation": llm_generation,
    "parse_claims": parse_claims,
    "nli_scoring": nli_scoring,
    "confidence_scoring": confidence_scoring,
    "assembly": assembly,
}
for _name, _fn in _stages.items():
    graph.add_node(_name, _fn)

# After the cache lookup, either skip straight to generation (hit)
# or go fetch fresh abstracts from PubMed (miss).
graph.add_conditional_edges(
    "check_cache",
    route_after_cache,
    {"pubmed_retrieval": "pubmed_retrieval", "llm_generation": "llm_generation"},
)

# Linear backbone of the pipeline.
_chain = [
    ("pubmed_retrieval", "llm_generation"),
    ("llm_generation", "parse_claims"),
    ("parse_claims", "nli_scoring"),
    ("nli_scoring", "confidence_scoring"),
    ("confidence_scoring", "assembly"),
]
for _src, _dst in _chain:
    graph.add_edge(_src, _dst)

# The graph starts at the cache check and terminates after assembly.
graph.set_entry_point("check_cache")
graph.set_finish_point("assembly")

# Compiled, invokable application exported to callers.
app = graph.compile()
102 changes: 102 additions & 0 deletions src/agent/nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import numpy as np
from src.agent.state import AgentState
from src.retrieval.cache import get_cache, set_cache
from src.retrieval.vector_store import add_abstracts, query_abstracts
from src.retrieval.pubmed import search_pubmed
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.output_parsers import JsonOutputParser
from sentence_transformers import CrossEncoder
from src.core.config import settings

# Module-level singletons, constructed once at import time and shared by
# every node below.
# LLM used both for answer generation and for claim extraction.
_llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite", google_api_key=settings.GEMINI_API_KEY)
# NLI cross-encoder used to score (abstract, claim) pairs in nli_scoring.
_nli_model = CrossEncoder("cross-encoder/nli-MiniLM2-L6-H768")

def check_cache(state: AgentState):
    """Look the query up in the Redis cache.

    Returns a partial state update: on a hit the cached abstracts are
    reused and ``cache_hit`` is set; on a miss only the flag is set.
    """
    hit = get_cache(state["query"])
    if not hit:
        return {"cache_hit": False}
    return {"cache_hit": True, "abstracts": hit}

def route_after_cache(state: AgentState) -> str:
    """Pick the next node after the cache lookup: skip retrieval on a hit."""
    return "llm_generation" if state["cache_hit"] else "pubmed_retrieval"

def pubmed_retrieval(state: "AgentState"):
    """Fetch abstracts for the query from PubMed and cache them.

    Searches PubMed, indexes the results in the vector store, pulls back the
    top matches, and writes them to the Redis cache before returning them as
    a state update.

    Bug fixed: the original returned a local ``abstracts`` that was only
    assigned inside ``if not state["cache_hit"]``, so calling this node on a
    cache hit raised UnboundLocalError. Routing normally skips this node on a
    hit, but we now pass the cached abstracts through instead of crashing.
    """
    if state["cache_hit"]:
        # Defensive: routing should have skipped us; reuse what the cache gave.
        return {"abstracts": state["abstracts"]}
    results = search_pubmed(state["query"])
    add_abstracts(results)
    abstracts = query_abstracts(state["query"])
    set_cache(state["query"], abstracts)
    return {"abstracts": abstracts}

def llm_generation(state: AgentState):
    """Generate a clinical answer grounded in the retrieved abstracts.

    Builds a literature context from the abstracts, prompts the LLM with it,
    and returns the response text as a state update.
    """
    context = "\n\n".join(
        f"Title: {doc['title']}\nAbstract: {doc['abstract']}"
        for doc in state["abstracts"]
    )

    prompt = f"""You are a clinical assistant in charge of extracting insights from medical literature. Use the following documentation to answer the query.

Literature:
{context}

Query: {state["query"]}

Provide a detailed clinical response based solely on the provided literature."""

    answer = _llm.invoke(prompt)
    return {"llm_response": answer.content}

def parse_claims(state: AgentState):
    """Extract discrete factual claims from the LLM response.

    Asks the LLM to emit a JSON array of claim strings and parses it into a
    Python list, returned as a state update.
    """
    prompt = f"""Extract all discrete factual claims from the following clinical response.
Return ONLY a JSON array of strings, no other text.
Each claim should be a single verifiable factual statement.

Response:
{state["llm_response"]}

Return format: ["claim 1", "claim 2", "claim 3"]"""

    raw = _llm.invoke(prompt)
    parser = JsonOutputParser()
    return {"claims": parser.parse(raw.content)}

def nli_scoring(state: "AgentState"):
    """Score each extracted claim against the retrieved abstracts with NLI.

    For every claim, runs the cross-encoder over all (abstract, claim) pairs
    in ONE batched predict call and keeps the abstract whose best label score
    is highest. Each scored claim is a dict with ``claim``, ``label``
    (Contradicted / Supported / Unverifiable), ``score`` and ``evidence``.

    Fixes over the original:
    - best score starts at -inf, not -1: NLI logits can be below -1, which
      previously left ``best_result`` as None and crashed downstream.
    - with no abstracts, the claim is marked Unverifiable instead of
      appending None.
    - one batched model call per claim instead of one call per pair.
    """
    labels = ["Contradicted", "Supported", "Unverifiable"]
    # Premise texts are the same for every claim; build them once.
    corpus = [abstract["abstract"] for abstract in state["abstracts"]]

    scored_claims = []
    for claim in state["claims"]:
        if not corpus:
            # Nothing to check against: the claim cannot be verified.
            scored_claims.append(
                {"claim": claim, "label": "Unverifiable", "score": 0.0, "evidence": None}
            )
            continue

        # One batched predict over all (premise, claim) pairs for this claim.
        all_scores = _nli_model.predict([(premise, claim) for premise in corpus])

        best_premise, best_label, best_score = 0, 0, float("-inf")
        for i, row in enumerate(all_scores):
            j = int(np.argmax(row))
            if row[j] > best_score:
                best_premise, best_label, best_score = i, j, float(row[j])

        scored_claims.append({
            "claim": claim,
            "label": labels[best_label],
            "score": best_score,
            "evidence": corpus[best_premise],
        })

    return {"scored_claims": scored_claims}

def confidence_scoring(state: "AgentState"):
    """Aggregate per-claim NLI verdicts into one confidence score in [0, 1].

    Supported claims count 1.0, Unverifiable 0.5, Contradicted 0.0; the
    score is their mean.

    Fixes over the original: an empty claim list previously produced
    ``np.mean([])`` -> NaN (plus a RuntimeWarning); it now yields 0.0. The
    result is also converted to a plain float instead of np.float64 so it
    serializes cleanly.
    """
    weights = {"Supported": 1.0, "Unverifiable": 0.5, "Contradicted": 0.0}
    scored = state["scored_claims"]
    if not scored:
        # No claims means nothing was verified; report zero confidence.
        return {"confidence_score": 0.0}
    score = float(np.mean([weights[claim["label"]] for claim in scored]))
    return {"confidence_score": score}

def assembly(state: AgentState):
    """Bundle the pipeline outputs into the final response payload."""
    payload = {
        "query": state["query"],
        "response": state["llm_response"],
        "confidence_score": state["confidence_score"],
        "scored_claims": state["scored_claims"],
        "abstracts": state["abstracts"],
    }
    return {"final_response": payload}
11 changes: 11 additions & 0 deletions src/agent/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import TypedDict, Optional

class AgentState(TypedDict):
    """Shared state threaded through the LangGraph pipeline.

    Fields are filled in progressively as nodes run; the Optional fields
    stay None until their producing node has executed.
    """

    query: str  # the user's question
    cache_hit: bool  # whether Redis already held abstracts for this query
    abstracts: list[dict]  # retrieved abstracts; nodes read 'title'/'abstract' keys
    llm_response: Optional[str]  # raw LLM answer grounded in the abstracts
    claims: Optional[list[str]]  # factual claims extracted from llm_response
    scored_claims: Optional[list[dict]]  # claims with NLI label/score/evidence
    confidence_score: Optional[float]  # mean claim weight, expected in [0, 1]
    final_response: Optional[dict]  # assembled payload returned to the caller
12 changes: 6 additions & 6 deletions src/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ class Settings(BaseSettings):
# Hugging Face
HF_TOKEN: str

# Gemini
GEMINI_API_KEY: str

# Pinecone
PINECONE_API_KEY: str
PINECONE_INDEX_NAME: str

# Redis
REDIS_HOST: str
# Changed from str: the port is numeric everywhere it is consumed
# (redis.Redis(port=...)); pydantic coerces the env string to int.
REDIS_PORT: int = 6379

# MLflow
MLFLOW_TRACKING_URI: str = "http://localhost:5000"
MLFLOW_ARTIFACT_LOCATION: str = "logs/mlflow"
Expand Down
Empty file added src/retrieval/__init__.py
Empty file.
20 changes: 20 additions & 0 deletions src/retrieval/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import redis
import json
from src.core.config import settings

_client = None  # lazily-created, module-wide Redis connection

def get_client():
    """Return the shared Redis client, creating it on first use."""
    global _client
    if _client is not None:
        return _client
    _client = redis.Redis(
        host=settings.REDIS_HOST,
        port=settings.REDIS_PORT,
        decode_responses=True,
    )
    return _client

def get_cache(key: str):
    """Fetch and JSON-decode a cached value, or return None on a miss."""
    raw = get_client().get(key)
    return None if raw is None else json.loads(raw)

def set_cache(key: str, value: list[dict]):
    """Store *value* as JSON under *key* with a 24-hour expiry."""
    ttl_seconds = 24 * 60 * 60  # 86400
    get_client().setex(key, ttl_seconds, json.dumps(value))
13 changes: 1 addition & 12 deletions src/retrieval/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from sentence_transformers import SentenceTransformer
from src.core.config import settings
import os
from src.retrieval.pubmed import search_pubmed
import sys

os.environ["HUGGING_FACE_HUB_TOKEN"] = settings.HF_TOKEN
model = SentenceTransformer("pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb")
Expand Down Expand Up @@ -46,13 +44,4 @@ def query_abstracts(query: str, n_results: int = 5) -> list[dict]:
top_k=n_results,
include_metadata=True
)
return results

def main():
search_results = search_pubmed("myocardial infarction", max_results=5)
add_abstracts(search_results)
results = query_abstracts("chest pain treatment")
print(results)

if __name__ == "__main__":
sys.exit(main())
return [match["metadata"] for match in results["matches"]]
12 changes: 12 additions & 0 deletions tests/pinecone_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from src.retrieval.pubmed import search_pubmed
from src.retrieval.vector_store import add_abstracts, query_abstracts
import sys

def main():
    """Smoke-test the retrieval stack: ingest PubMed results, then query."""
    fetched = search_pubmed("pulmonary embolism", max_results=5)
    add_abstracts(fetched)
    matches = query_abstracts("chest pain treatment")
    print(matches)

if __name__ == "__main__":
    sys.exit(main())
19 changes: 19 additions & 0 deletions tests/pipeline_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from src.agent.graph import app
import sys

def main():
    """Run the full agent graph once against a fresh initial state."""
    initial_state = {
        "query": "What are the first line treatments for atrial fibrillation?",
        "cache_hit": False,
        "abstracts": [],
    }
    # Fields produced later in the pipeline start out as None.
    for field in ("llm_response", "claims", "scored_claims",
                  "confidence_score", "final_response"):
        initial_state[field] = None

    outcome = app.invoke(initial_state)
    print(outcome["final_response"])

if __name__ == "__main__":
    sys.exit(main())
Loading