Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ num_candidates: 50
ensemble_method: "rrf"
ranker_weights: {"faiss":1,"bm25":0,"index_keywords":0}
rrf_k: 60
max_gen_tokens: 400
max_gen_tokens: 1024
chunk_mode : "recursive_sections"
gen_model: "models/qwen2.5-3b-instruct-q8_0.gguf"
chunk_size: 2000
Expand All @@ -17,3 +17,4 @@ rerank_top_k: 5
use_double_prompt: false
enable_history: true
max_history_turns: 3
use_query_expansion: false
2 changes: 2 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class RAGConfig:
use_hyde: bool = False
hyde_max_tokens: int = 300
use_double_prompt: bool = False
use_query_expansion: bool = False
query_expansion_max_tokens: int = 64

# conversational memory
enable_history: bool = True
Expand Down
2 changes: 1 addition & 1 deletion src/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def format_prompt(chunks, query, max_chunk_chars=400, system_prompt_mode="tutor"

_LLM_CACHE = {}

def get_llama_model(model_path: str, n_ctx: int = 4096):
def get_llama_model(model_path: str, n_ctx: int = 8192):
if model_path not in _LLM_CACHE:
try:
_LLM_CACHE[model_path] = Llama(model_path=model_path,
Expand Down
80 changes: 58 additions & 22 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from src.instrumentation.logging import get_logger
from src.ranking.ranker import EnsembleRanker
from src.preprocessing.chunking import DocumentChunker
from src.query_enhancement import generate_hypothetical_document, contextualize_query
from src.query_enhancement import generate_hypothetical_document, contextualize_query, expand_query_with_keywords, expand_query_with_perspectives
from src.retriever import (
filter_retrieved_chunks,
BM25Retriever,
Expand Down Expand Up @@ -47,6 +47,11 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="enable double prompting for higher quality answers"
)
parser.add_argument(
"--use_query_expansion",
action="store_true",
help="enable multi-query retrieval (query expansion) for higher accuracy"
)

return parser.parse_args()

Expand Down Expand Up @@ -131,38 +136,69 @@ def get_answer(
elif cfg.use_indexed_chunks:
ranked_chunks, topk_idxs = use_indexed_chunks(question, chunks)
else:
retrieval_query = question
# print(f"Retrieval query: {retrieval_query}")
if cfg.use_hyde:
retrieval_query = generate_hypothetical_document(question, cfg.gen_model, max_tokens=cfg.hyde_max_tokens)
use_expansion = getattr(args, "use_query_expansion", False) or cfg.use_query_expansion

if use_expansion:
queries_to_run = expand_query_with_perspectives(question, cfg.gen_model, max_tokens=cfg.query_expansion_max_tokens)
elif cfg.use_hyde:
queries_to_run = [generate_hypothetical_document(question, cfg.gen_model, max_tokens=cfg.hyde_max_tokens)]
else:
queries_to_run = [question]

pool_n = max(cfg.num_candidates, cfg.top_k + 10)
raw_scores: Dict[str, Dict[int, float]] = {}
for retriever in retrievers:
# print(f"Getting scores from retriever: {retriever.name}...")
raw_scores[retriever.name] = retriever.get_scores(retrieval_query, pool_n, chunks)
# TODO: Fix retrieval logging.

# print("Raw scores from retrievers:")
# for retriever_name, score_dict in raw_scores.items():
# print(f" {retriever_name}: {list(score_dict.values())}")

# Store original weights so we can adjust dynamically for RRF if multiple queries are used
original_weights = ranker.weights.copy()

for q_idx, q_variant in enumerate(queries_to_run):
for retriever in retrievers:
scores = retriever.get_scores(q_variant, pool_n, chunks)

if len(queries_to_run) == 1:
raw_scores[retriever.name] = scores
else:
r_key = f"{retriever.name}_{q_idx}"
raw_scores[r_key] = scores
if r_key not in ranker.weights:
base_weight = original_weights.get(retriever.name, 0.0)
if q_idx == 0:
ranker.weights[r_key] = base_weight * 0.5
else:
ranker.weights[r_key] = (base_weight * 0.5) / (len(queries_to_run) - 1)

if len(queries_to_run) > 1:
for retriever in retrievers:
ranker.weights[retriever.name] = 0.0

# Step 2: Ranking
ordered, scores = ranker.rank(raw_scores=raw_scores)
# print(f"Ordered candidate indices after ranking: {ordered[:cfg.top_k]}")
# print(f"Corresponding scores: {scores[:cfg.top_k]}")
topk_idxs = filter_retrieved_chunks(cfg, chunks, ordered)
ranked_chunks = [chunks[i] for i in topk_idxs]
# print(f"Top-{cfg.top_k} chunk indices after filtering: {topk_idxs}")
# print("Len Ranked chunks:", len(ranked_chunks))
# print("Example ranked chunk content:", ranked_chunks[0] if ranked_chunks else "No chunks retrieved")

# Restore old ranker weights to keep things clean
ranker.weights = original_weights

# Capture chunk info if in test mode
if is_test_mode:
# Compute individual ranker ranks
faiss_scores = raw_scores.get("faiss", {})
bm25_scores = raw_scores.get("bm25", {})
index_scores = raw_scores.get("index_keywords", {})
from collections import defaultdict
if len(queries_to_run) > 1:
agg_faiss = defaultdict(float)
agg_bm25 = defaultdict(float)
agg_index = defaultdict(float)
for rv, s_dict in raw_scores.items():
base_rv = rv.rsplit('_', 1)[0]
for cand, sc in s_dict.items():
if base_rv == "faiss": agg_faiss[cand] += sc
elif base_rv == "bm25": agg_bm25[cand] += sc
elif base_rv == "index_keywords": agg_index[cand] += sc
faiss_scores = dict(agg_faiss)
bm25_scores = dict(agg_bm25)
index_scores = dict(agg_index)
else:
faiss_scores = raw_scores.get("faiss", {})
bm25_scores = raw_scores.get("bm25", {})
index_scores = raw_scores.get("index_keywords", {})

faiss_ranked = sorted(faiss_scores.keys(), key=lambda i: faiss_scores[i], reverse=True)
bm25_ranked = sorted(bm25_scores.keys(), key=lambda i: bm25_scores[i], reverse=True)
Expand Down
44 changes: 43 additions & 1 deletion src/query_enhancement.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ def generate_hypothetical_document(
prompt,
model_path,
max_tokens=max_tokens,
temperature=0.7,
**llm_kwargs
)

return hypothetical.strip()
return hypothetical["choices"][0]["text"].strip()

def correct_query_grammar(
query: str,
Expand Down Expand Up @@ -126,6 +127,47 @@ def expand_query_with_keywords(
return query_lines


def expand_query_with_perspectives(
query: str,
model_path: str,
max_tokens: int = 64,
**llm_kwargs
) -> str:
"""
Query Expansion: Generates queries from different technical perspectives.
Provides broader context for retrieval to achieve better performance.
"""
prompt = textwrap.dedent(f"""\
<|im_start|>system
You are a search optimization expert.
Generate 3 alternative versions of the user's query that approach the topic from different technical perspectives or use different specialized terminology.
This provides broader context for retrieval. Output the alternative queries separated by newlines. Do not provide explanations or numberings.
<|im_end|>
<|im_start|>user
Query: {query}
<|im_end|>
<|im_start|>assistant
""")

prompt = text_cleaning(prompt)
expansion = run_llama_cpp(
prompt,
model_path,
max_tokens=max_tokens,
temperature=0.5,
**llm_kwargs
)

# Combine original query with expansion
query_lines = [query]
query_lines.extend([line.strip() for line in expansion["choices"][0]["text"].split('\n') if line.strip()])

# Remove numbering if present
query_lines = [line.split('.', 1)[-1].strip() if '.' in line[:3] else line for line in query_lines]

return query_lines


def decompose_complex_query(
query: str,
model_path: str,
Expand Down
2 changes: 1 addition & 1 deletion src/ranking/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def rerank_with_cross_encoder(query: str, chunks: List[str], top_n: int) -> List
chunk_with_scores = list(zip(chunks, scores))
chunk_with_scores.sort(key=lambda x: x[1], reverse=True)

return chunk_with_scores[:top_n]
return [chunk for chunk, score in chunk_with_scores[:top_n]]


# -------------------------- Reranking Router -----------------------------
Expand Down
23 changes: 17 additions & 6 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,16 +534,12 @@ def test_generate_hypothetical_document(self, mock_llm):
"""generate_hypothetical_document returns string."""
from src.query_enhancement import generate_hypothetical_document

# Note: The function calls .strip() on run_llama_cpp result.
# run_llama_cpp returns a dict, so there's an inconsistency in the source.
# For API testing, mock to return a string (what the function expects).
mock_llm.return_value = "A hypothetical answer about databases."
mock_llm.return_value = {"choices": [{"text": "A hypothetical answer about databases."}]}

result = generate_hypothetical_document(
"What is a transaction?",
model_path="mock_model",
max_tokens=100,
temperature=0.5
max_tokens=100
)

assert isinstance(result, str)
Expand Down Expand Up @@ -592,6 +588,21 @@ def test_decompose_complex_query(self, mock_llm):

assert isinstance(result, list)

@patch('src.query_enhancement.run_llama_cpp')
def test_expand_query_with_perspectives(self, mock_llm):
"""expand_query_with_perspectives returns list of specific perspectives."""
from src.query_enhancement import expand_query_with_perspectives

mock_llm.return_value = {"choices": [{"text": "1. Database systems from an architecture perspective\n2. Data storage from a hardware perspective"}]}

result = expand_query_with_perspectives(
"database architecture",
model_path="mock_model"
)

assert isinstance(result, list)
assert len(result) > 0


# ====================== Load Artifacts Tests ======================

Expand Down