diff --git a/requirements.txt b/requirements.txt index d7aa018..10f8a0a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,4 +32,5 @@ langgraph>=0.1.0 redis>=5.0.0 # General -numpy>=2.4.0 \ No newline at end of file +numpy>=2.4.0 +torch>=2.11.0 \ No newline at end of file diff --git a/src/agent/graph.py b/src/agent/graph.py index 54febf8..b42e758 100644 --- a/src/agent/graph.py +++ b/src/agent/graph.py @@ -1,12 +1,13 @@ from langgraph.graph import StateGraph from src.agent.nodes import check_cache, pubmed_retrieval, llm_generation, parse_claims, nli_scoring, \ - confidence_scoring, assembly, route_after_cache + confidence_scoring, assembly, route_after_cache, preprocess_query from src.agent.state import AgentState graph = StateGraph(AgentState) # Add nodes graph.add_node("check_cache", check_cache) +graph.add_node("preprocess_query", preprocess_query) graph.add_node("pubmed_retrieval", pubmed_retrieval) graph.add_node("llm_generation", llm_generation) graph.add_node("parse_claims", parse_claims) @@ -22,6 +23,7 @@ ) # Add edges +graph.add_edge("preprocess_query", "check_cache") graph.add_edge("pubmed_retrieval", "llm_generation") graph.add_edge("llm_generation", "parse_claims") graph.add_edge("parse_claims", "nli_scoring") @@ -29,7 +31,7 @@ graph.add_edge("confidence_scoring", "assembly") # Set entry and finish -graph.set_entry_point("check_cache") +graph.set_entry_point("preprocess_query") graph.set_finish_point("assembly") # Compile diff --git a/src/agent/nodes.py b/src/agent/nodes.py index 4a8ddb6..08f590d 100644 --- a/src/agent/nodes.py +++ b/src/agent/nodes.py @@ -1,4 +1,5 @@ import numpy as np +import torch from src.agent.state import AgentState from src.retrieval.cache import get_cache, set_cache from src.retrieval.vector_store import add_abstracts, query_abstracts @@ -8,7 +9,7 @@ from sentence_transformers import CrossEncoder from src.core.config import settings -_llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite", 
google_api_key=settings.GEMINI_API_KEY) +_llm = ChatGoogleGenerativeAI(model="gemma-3-27b-it", google_api_key=settings.GEMINI_API_KEY) _nli_model = CrossEncoder("cross-encoder/nli-MiniLM2-L6-H768") def check_cache(state: AgentState): @@ -23,9 +24,21 @@ def route_after_cache(state: AgentState) -> str: return "llm_generation" return "pubmed_retrieval" +def preprocess_query(state: AgentState): + prompt = f"""Extract a concise PubMed search query (3-6 words) from this clinical question. + Return ONLY the search terms, nothing else. + + Question: {state["query"]} + + Search terms:""" + + response = _llm.invoke(prompt) + search_query = response.content.strip() + return {"search_query": search_query} + def pubmed_retrieval(state: AgentState): if not state["cache_hit"]: - results = search_pubmed(state["query"]) + results = search_pubmed(state["search_query"]) add_abstracts(results) abstracts = query_abstracts(state["query"]) set_cache(state["query"], abstracts) @@ -72,15 +85,27 @@ def nli_scoring(state: AgentState): for abstract in state["abstracts"]: scores = _nli_model.predict([(abstract["abstract"], claim)])[0] + scores = torch.softmax(torch.tensor(scores), dim=0).numpy() label_idx = int(np.argmax(scores)) - if scores[label_idx] > best_score: - best_score = scores[label_idx] - best_result = { - "claim": claim, - "label": labels[label_idx], - "score": float(best_score), - "evidence": abstract["abstract"] - } + + if label_idx != 2: + non_neutral_score = max(scores[0], scores[1]) + if non_neutral_score > 0.7 and non_neutral_score > best_score: + best_score = non_neutral_score + best_result = { + "claim": claim, + "label": labels[label_idx], + "score": float(non_neutral_score), + "evidence": abstract["abstract"] + } + + if best_result is None: + best_result = { + "claim": claim, + "label": "Unverifiable", + "score": 0.0, + "evidence": None + } scored_claims.append(best_result) diff --git a/src/agent/state.py b/src/agent/state.py index 993084d..fbaa25b 100644 --- 
a/src/agent/state.py
+++ b/src/agent/state.py
@@ -3,6 +3,7 @@ class AgentState(TypedDict):
     query: str
     cache_hit: bool
+    search_query: Optional[str]
     abstracts: list[dict]
     llm_response: Optional[str]
     claims: Optional[list[str]]
diff --git a/src/retrieval/vector_store.py b/src/retrieval/vector_store.py
index 6d7b570..67ca07c 100644
--- a/src/retrieval/vector_store.py
+++ b/src/retrieval/vector_store.py
@@ -23,19 +23,20 @@ def get_collection():
 def add_abstracts(abstracts: list[dict]):
     data_list = []
     for item in abstracts:
-        data = {'id': item["pmid"],
-            'values': _embed_text(item["abstract"]),
-            'metadata': {
-                "title": item["title"],
-                "abstract": item["abstract"],
-                "pmid": item["pmid"]
-            }
-        }
-        data_list.append(data)
-
-    get_collection().upsert(vectors=data_list)
+        if item["abstract"] and item["pmid"]:
+            data_list.append({
+                "id": item["pmid"],
+                "values": _embed_text(item["abstract"]),
+                "metadata": {
+                    "title": item["title"],
+                    "abstract": item["abstract"],
+                    "pmid": item["pmid"]
+                }
+            })
+    if data_list:
+        get_collection().upsert(vectors=data_list)
 
 def query_abstracts(query: str, n_results: int = 5) -> list[dict]:
     embedding = _embed_text(query)