diff --git a/README.md b/README.md index 9a3ff6bc..4d45ea9b 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ ## ๐Ÿ“‹ Requirements -- **Python**: 3.9+ +- **Python**: 3.9+ - **Conda/Miniconda**: For environment management - **System Requirements**: - macOS: Xcode Command Line Tools @@ -25,45 +25,58 @@ ## ๐Ÿš€ Quick Start ### 1. Clone the Repository + ```shell git clone https://github.com/georgia-tech-db/TokenSmith.git cd tokensmith ``` ### One-command setup: creates conda env, builds llama.cpp, installs dependencies + ```shell make build ``` + This will: + - Create a conda environment named `tokensmith` - Install all Python dependencies - Detect or build llama.cpp with platform-specific optimizations - Install TokenSmith in development mode ### 3. Activate the Environment + ```shell conda activate tokensmith ``` ### 4. Prepare Your Documents + Place your PDF files in the data directory + ```shell mkdir -p data/chapters cp your-documents.pdf data/chapters/ ``` ### 5. Index Your Documents + Index with default settings + ```shell make run-index ``` + Or with custom parameters, eg. + ```shell make run-index ARGS="--pdf_range 1-10 --chunk_mode chars --visualize" ``` ### 6. Start Chatting + Activate environment first (required for interactive mode) + ```shell conda activate tokensmith python -m src.main chat @@ -72,6 +85,7 @@ python -m src.main chat > You might have to download `qwen2.5-0.5b-instruct-q5_k_m.gguf` into your `llama.cpp/models` if you get an error about a missing model. ### 7. Deactivate the Environment + ```shell conda deactivate ``` @@ -85,6 +99,7 @@ TokenSmith uses YAML configuration files with the following priority order: 3. Default config (`config/config.yaml`) ### Sample Configuration + ```yaml # config/config.yaml @@ -104,55 +119,64 @@ chunk_size_char: 20000 ## ๐ŸŽฎ Usage ### Basic indexing + ```shell make run-index ``` ### Index specific PDF range + ```shell make run-index ARGS="--pdf_range - --chunk_mode " ``` ### Index with visualization and table preservation + ```shell make run-index ARGS="--keep_tables --visualize --chunk_tokens " ``` ### Custom paths and settings + ```shell make run-index ARGS="--pdf_dir --index_prefix book_index --config " ``` ### Chat with custom settings + ```shell python -m src.main chat --config --model_path ``` ### Build with existing llama.cpp installation + ```shell export LLAMA_CPP_BINARY=/usr/local/bin/llama-cli make build ``` ### Update environment with new dependencies + ```shell make update-env ``` ### Export environment for sharing + ```shell make export-env ``` ### Show installed packages + ```shell make show-deps ``` - ## ๐Ÿ“Š Command Line Arguments ### Core Arguments + - `mode`: Operation mode (`index` or `chat`) - `--config`: Configuration file path - `--pdf_dir`: Directory containing PDF files @@ -160,6 +184,7 @@ make show-deps - `--model_path`: Path to GGUF model file ### Indexing Arguments + - `--pdf_range`: Process specific page range (e.g., "1-10") - `--chunk_mode`: Chunking strategy (`tokens` or `chars`) - `--chunk_tokens`: Tokens per chunk (default: 500) @@ -170,6 +195,7 @@ make show-deps ## ๐Ÿ”จ Development ### Available Make Targets + ```shell make help # Show all available commands make env # Create conda environment @@ -184,13 +210,15 @@ make export-env # Export environment with exact versions ``` ### Adding Dependencies + ```shell # Add new conda package conda activate tokensmith conda install new-package ``` + Add to environment.yml for persistence. Edit environment.yml, then: + ```shell make update-env ``` - diff --git a/src/config.py b/src/config.py index 5ab38f9e..79ca8c15 100644 --- a/src/config.py +++ b/src/config.py @@ -2,11 +2,17 @@ import os from dataclasses import dataclass -from typing import Dict, Callable, Any +from typing import Dict, Callable, Any, Optional import yaml -from src.chunking import ChunkStrategy, make_chunk_strategy, CharChunkConfig, TokenChunkConfig, SlidingTokenConfig, \ +import sys +from pathlib import Path + +# Add src to path for imports +sys.path.insert(0, str(Path(__file__).parent)) + +from chunking import ChunkStrategy, make_chunk_strategy, CharChunkConfig, TokenChunkConfig, SlidingTokenConfig, \ SectionChunkConfig, ChunkConfig @@ -32,6 +38,12 @@ class QueryPlanConfig: model_path: os.PathLike + # citation settings + enable_citations: bool + + # planner hints (optional, populated by planners) + location_hint: Optional[Dict[str, Any]] = None + # ---------- chunking strategy + artifact name helpers ---------- def make_strategy(self) -> ChunkStrategy: return make_chunk_strategy(config=self.chunk_config) @@ -71,7 +83,9 @@ def pick(key, default=None): max_gen_tokens = pick("max_gen_tokens", 400), halo_mode = pick("halo_mode", "none"), seg_filter = pick("seg_filter", None), - model_path = pick("model_path", None) + model_path = pick("model_path", None), + enable_citations = pick("enable_citations", True), + location_hint = None, ) cfg._validate() return cfg @@ -117,6 +131,8 @@ def to_dict(self) -> Dict[str, Any]: "ranker_weights": self.ranker_weights, "halo_mode": self.halo_mode, "max_gen_tokens": self.max_gen_tokens, - "model_path": self.model_path + "model_path": self.model_path, + "enable_citations": self.enable_citations, + "location_hint": self.location_hint, } diff --git a/src/instrumentation/logging.py b/src/instrumentation/logging.py index edbd5aec..542a5b7f 100644 --- a/src/instrumentation/logging.py +++ b/src/instrumentation/logging.py @@ -142,13 +142,15 @@ def log_ensemble_result(self, final_ranking: List[int], ensemble_method: str, self.current_query_data["ensemble"] = ensemble_data def log_chunks_used(self, chunk_indices: List[int], chunks: List[str], - sources: List[str], chunk_tags: Optional[List[List[str]]] = None): + sources: List[str], chunk_tags: Optional[List[List[str]]] = None, + metadata: Optional[List[Dict[str, Any]]] = None): """Log details about chunks selected for generation.""" if not self.current_query_data: return chunks_data = [] for i, idx in enumerate(chunk_indices): + m = metadata[idx] if (metadata and idx < len(metadata)) else {} chunk_info = { "rank": i + 1, "global_index": idx, @@ -158,7 +160,9 @@ def log_chunks_used(self, chunk_indices: List[int], chunks: List[str], "has_table": "" in chunks[idx].lower() if idx < len(chunks) else False, "preview": (chunks[idx][:200] + "...") if idx < len(chunks) and len(chunks[idx]) > 200 else chunks[ idx] if idx < len(chunks) else "", - "tags": chunk_tags[idx][:10] if chunk_tags and idx < len(chunk_tags) else [] + "tags": chunk_tags[idx][:10] if chunk_tags and idx < len(chunk_tags) else [], + "section": m.get("section"), + "filename": m.get("filename"), } chunks_data.append(chunk_info) diff --git a/src/location_handler.py b/src/location_handler.py new file mode 100644 index 00000000..5ad03646 --- /dev/null +++ b/src/location_handler.py @@ -0,0 +1,95 @@ +""" +Location query detection and response handling. +""" +import re +from typing import List, Dict, Tuple + + +def is_location_query(text: str) -> bool: + """ + Detect if a query is asking for location information. + + Args: + text: The user's query text + + Returns: + True if this is a location query, False otherwise + """ + t = text.lower().strip() + # Multiple patterns to catch various "where" question formats + patterns = [ + r"^where\s+is\s+", # "where is X" + r"^where\s+can\s+i\s+find", # "where can I find X" + r"^where\s+do\s+i\s+find", # "where do I find X" + r"^where\s+is\s+.*\s+(located|found|discussed|covered|explained|described)", # "where is X located/found/etc" + r"^where\s+can\s+.*\s+(find|locate|get)", # "where can I find X" + r"^where\s+does\s+.*\s+(appear|occur|show)", # "where does X appear" + r"^in\s+which\s+(section|chapter|part)", # "in which section is X" + r"^what\s+(section|chapter|part).*", # "what section covers X" + ] + return any(re.search(pattern, t) for pattern in patterns) + + +def format_location_response(topk_idxs: List[int], metadata: List[Dict], max_locations: int = 5) -> str: + """ + Format a location response from the top retrieved chunks. + + Args: + topk_idxs: List of chunk indices that were selected + metadata: List of metadata dictionaries for each chunk + max_locations: Maximum number of locations to return + + Returns: + Formatted string with numbered location list + """ + seen = set() + locations = [] + + for i in topk_idxs: + sec = str(metadata[i].get("section", "")).strip() + if sec.startswith("## "): + sec = sec[3:].strip() + if sec and sec not in seen: + seen.add(sec) + locations.append(sec) + if len(locations) >= max_locations: + break + + if locations: + return "\n".join(f"{rank}. {s}" for rank, s in enumerate(locations, 1)) + else: + return "(no matching sections found)" + + +def format_citations(topk_idxs: List[int], metadata: List[Dict], max_citations: int = 3) -> str: + """ + Format inline citations from the top retrieved chunks. + + Args: + topk_idxs: List of chunk indices that were selected + metadata: List of metadata dictionaries for each chunk + max_citations: Maximum number of citations to return + + Returns: + Formatted citations string + """ + seen = set() + sections = [] + + for i in topk_idxs: + sec = str(metadata[i].get("section", "")).strip() + if not sec: + continue + # remove markdown heading markers if present + if sec.startswith("## "): + sec = sec[3:].strip() + if sec not in seen: + seen.add(sec) + sections.append(sec) + if len(sections) >= max_citations: + break + + if sections: + return "; ".join(f"[{s}]" for s in sections) + else: + return "" diff --git a/src/main.py b/src/main.py index 6ea5fa8b..ab584976 100644 --- a/src/main.py +++ b/src/main.py @@ -4,9 +4,10 @@ from src.config import QueryPlanConfig from src.instrumentation.logging import init_logger, get_logger from src.planning.heuristics import HeuristicQueryPlanner +from src.planning.difficulty_planner import QueryDifficultyPlanner from src.preprocess import build_index from src.ranking.ensemble import EnsembleRanker -from src.ranking.rankers import FaissSimilarityRanker, BM25Ranker, TfIDFRanker +from src.ranking.rankers import FaissSimilarityRanker, BM25Ranker, TfIDFRanker, LocationRanker from src.retriever import get_candidates, apply_seg_filter from src.ranker import rerank from src.generator import answer @@ -28,6 +29,8 @@ def parse_args(): p.add_argument("--pdf_range", type=str, default=None, help="e.g., 27-33") p.add_argument("--keep_tables", action="store_true") p.add_argument("--visualize", action="store_true") + p.add_argument("--planner", choices=["heuristic", "difficulty"], default="heuristic", + help="Choose query planner: heuristic (default) or difficulty-based") return p.parse_args() @@ -66,7 +69,14 @@ def main(): init_logger(cfg) logger = get_logger() - planner = HeuristicQueryPlanner(cfg) + + # Choose planner based on argument + if args.planner == "difficulty": + planner = QueryDifficultyPlanner(cfg) + print("Using difficulty-based planner (easy/hard pipelines)") + else: + planner = HeuristicQueryPlanner(cfg) + print("Using heuristic-based planner") if args.mode == "index": # Optional range filtering @@ -96,7 +106,7 @@ def main(): break logger.log_query_start(q) cfg = planner.plan(q) - index, chunks, sources, vectorizer, chunk_tags = load_artifacts( + index, chunks, sources, vectorizer, chunk_tags, metadata = load_artifacts( cfg.index_prefix, cfg ) @@ -115,6 +125,8 @@ def main(): "faiss_distances": faiss_dists, # for FaissSimilarityRanker "vectorizer": vectorizer, # for TfIDFRanker "chunk_tags": chunk_tags, # for TfIDFRanker + "metadata": metadata, # for LocationRanker + "location_hint": cfg.location_hint, # for LocationRanker } # 3) build rankers + ensemble (using weights from config) @@ -122,6 +134,7 @@ def main(): FaissSimilarityRanker(), BM25Ranker(), TfIDFRanker(), + LocationRanker(), ] weights = cfg.ranker_weights method = cfg.ensemble_method @@ -132,24 +145,40 @@ def main(): query=q, chunks=chunks, cand_idxs=cand_idxs, context=context ) - topk_idxs = apply_seg_filter(cfg, chunks, ordered) - logger.log_chunks_used(topk_idxs, chunks, sources, chunk_tags) - - # 4) materialize indices into text and continue - ranked_chunks = [chunks[i] for i in topk_idxs] - - # HALO Stub (NO OP for now) - ranked_chunks = rerank(q, ranked_chunks, mode=cfg.halo_mode) + # Location-aware boosting is now handled by LocationRanker in the ensemble - ans = answer( - q, - ranked_chunks, - args.model_path or cfg.model_path, - max_tokens=cfg.max_gen_tokens, - ) - print("\n=== ANSWER =========================================\n") - print(ans if ans.strip() else "(no output)") - print("\n====================================================\n") + topk_idxs = apply_seg_filter(cfg, chunks, ordered) + logger.log_chunks_used(topk_idxs, chunks, sources, chunk_tags, metadata) + + # 4) Handle location queries vs regular queries + from src.location_handler import is_location_query, format_location_response, format_citations + + ans = "" + if is_location_query(q): + # Location query: return section headings + ans = format_location_response(topk_idxs, metadata) + print(ans) + else: + # Regular query: generate answer with citations + ranked_chunks = [chunks[i] for i in topk_idxs] + ranked_chunks = rerank(q, ranked_chunks, mode=cfg.halo_mode) + + ans = answer( + q, + ranked_chunks, + args.model_path or cfg.model_path, + max_tokens=cfg.max_gen_tokens, + ) + + # Add inline citations if enabled + if cfg.enable_citations: + citations = format_citations(topk_idxs, metadata) + if citations: + ans = f"{ans}\n\nReferences: {citations}" + + print("\n=== ANSWER =========================================\n") + print(ans if ans.strip() else "(no output)") + print("\n====================================================\n") logger.log_generation( ans, {"max_tokens": cfg.max_gen_tokens, "model_path": args.model_path} ) diff --git a/src/planning/__init__.py b/src/planning/__init__.py index e69de29b..8d44ca0e 100644 --- a/src/planning/__init__.py +++ b/src/planning/__init__.py @@ -0,0 +1,14 @@ +# Planning module exports +from .planner import QueryPlanner +from .heuristics import HeuristicQueryPlanner +from .difficulty_planner import QueryDifficultyPlanner +from .comparison_planner import ComparisonPlanner, run_comparison_test + +__all__ = [ + 'QueryPlanner', + 'HeuristicQueryPlanner', + 'QueryDifficultyPlanner', + 'ComparisonPlanner', + 'run_comparison_test' +] + diff --git a/src/planning/comparison_planner.py b/src/planning/comparison_planner.py new file mode 100644 index 00000000..1dbb07bc --- /dev/null +++ b/src/planning/comparison_planner.py @@ -0,0 +1,344 @@ +import sys +import os +from pathlib import Path + +# Add src to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from config import QueryPlanConfig +from chunking import CharChunkConfig, TokenChunkConfig, SlidingTokenConfig, SectionChunkConfig +from copy import deepcopy +import re +import json +import time +from typing import Dict, Any, List, Tuple +from datetime import datetime + +from planning.planner import QueryPlanner +from planning.heuristics import HeuristicQueryPlanner +from planning.difficulty_planner import QueryDifficultyPlanner +from generator import run_llama_cpp + + +class ComparisonPlanner(QueryPlanner): + """ + Comparison Query Planner + ------------------------ + Compares regex-based vs Qwen model-based difficulty classification. + + This planner runs both classification methods and logs the results + for analysis and comparison. + """ + + @property + def name(self) -> str: + return "ComparisonPlanner" + + def __init__(self, base_cfg: QueryPlanConfig): + super().__init__(base_cfg) + self.base_cfg = deepcopy(base_cfg) + self.difficulty_planner = QueryDifficultyPlanner(base_cfg) + self.comparison_results = [] + self.model_path = base_cfg.model_path + + def classify_difficulty_regex(self, query: str) -> str: + """ + Use the existing regex-based classification from QueryDifficultyPlanner. + """ + return self.difficulty_planner.classify_difficulty(query) + + def classify_difficulty_qwen(self, query: str) -> str: + """ + Use Qwen model to classify question difficulty as easy or hard. + + Returns 'easy' or 'hard' based on the model's binary classification. + """ + prompt = self._create_difficulty_prompt(query) + + try: + response = run_llama_cpp( + prompt=prompt, + model_path=self.model_path, + max_tokens=50, # Short response expected + temperature=0.1, # Low temperature for consistent classification + threads=4 + ) + + # Parse the response to extract difficulty classification + difficulty_classification = self._parse_difficulty_response(response) + + # Return the classification directly (should be "easy" or "hard") + return difficulty_classification + + except Exception as e: + print(f"Error in Qwen classification: {e}") + # Fallback to regex classification + return self.classify_difficulty_regex(query) + + def _create_difficulty_prompt(self, query: str) -> str: + """ + Create a prompt for Qwen to classify question difficulty. + """ + return f"""<|im_start|>system +You are an expert at analyzing question difficulty. Classify the following question as either EASY or HARD. + +EASY questions are: +- Simple definitions or fact recall (e.g., "What is a database?") +- Basic concept explanations (e.g., "What is a primary key?") +- Straightforward comparisons (e.g., "What is the difference between X and Y?") +- Questions that can be answered with a single concept or definition +- Questions that require minimal reasoning or explanation + +HARD questions are: +- Complex system design questions (e.g., "Design a distributed database system...") +- Questions requiring multi-step reasoning and problem-solving +- Questions involving multiple constraints, trade-offs, or conflicting requirements +- Questions about implementing complex algorithms or architectures +- Questions that require deep technical knowledge and analysis +- Questions with phrases like "design", "implement", "architect", "optimize", "handle", "maintain" combined with complex requirements +- Questions mentioning distributed systems, scalability, fault tolerance, consensus algorithms, Byzantine failures, network partitions, or similar advanced concepts + +Look for keywords and complexity indicators: +- HARD indicators: "design", "implement", "architect", "distributed", "scalable", "fault-tolerant", "consensus", "Byzantine", "network partitions", "multi-step", "optimize", "handle", "maintain", "guarantee", "ensure" +- EASY indicators: "what is", "what does", "explain", "define", "difference between", simple comparisons + +Respond with ONLY "EASY" or "HARD" followed by a brief explanation. +<|im_end|> +<|im_start|>user +Question: {query} +<|im_end|> +<|im_start|>assistant +""" + + def _parse_difficulty_response(self, response: str) -> str: + """ + Parse the Qwen response to extract the difficulty classification. + """ + # Extract the assistant's response (after "assistant" marker) + # Look for the pattern "assistant\n" or "assistant " followed by the response + assistant_match = re.search(r'assistant\s*\n\s*(.*?)(?:\[end of text\]|$)', response, re.DOTALL | re.IGNORECASE) + if assistant_match: + assistant_response = assistant_match.group(1).strip() + else: + # If we can't find the assistant marker, use the whole response + assistant_response = response + + # Look for "HARD" first (more specific), then "EASY" + if re.search(r'\bHARD\b', assistant_response, re.IGNORECASE): + return "hard" + elif re.search(r'\bEASY\b', assistant_response, re.IGNORECASE): + return "easy" + + # Fallback: look for the pattern "X [end of text]" for old format + match = re.search(r'(\d)\s*\[end of text\]', response) + if match: + level = int(match.group(1)) + if 1 <= level <= 3: + return "easy" + elif 4 <= level <= 5: + return "hard" + + # Fallback: look for any number 1-5 in the entire response + match = re.search(r'\b([1-5])\b', response) + if match: + level = int(match.group(1)) + if 1 <= level <= 3: + return "easy" + elif 4 <= level <= 5: + return "hard" + + # If no valid classification found, return "easy" as default + print(f"Warning: Could not parse difficulty from response. Assistant response: {assistant_response[:200]}...") + return "easy" + + def compare_classifications(self, query: str) -> Dict[str, Any]: + """ + Compare regex vs Qwen classifications for a single query. + """ + start_time = time.time() + + # Get regex classification + regex_result = self.classify_difficulty_regex(query) + regex_time = time.time() - start_time + + # Get Qwen classification + qwen_start = time.time() + qwen_result = self.classify_difficulty_qwen(query) + qwen_time = time.time() - qwen_start + + # Determine agreement + agreement = regex_result == qwen_result + + result = { + "query": query, + "regex_classification": regex_result, + "qwen_classification": qwen_result, + "agreement": agreement, + "regex_time_ms": round(regex_time * 1000, 2), + "qwen_time_ms": round(qwen_time * 1000, 2), + "timestamp": datetime.now().isoformat(), + "query_length": len(query), + "word_count": len(query.split()) + } + + self.comparison_results.append(result) + return result + + def plan(self, query: str) -> QueryPlanConfig: + """ + Main planning method that compares both approaches and uses regex result for actual planning. + """ + # Compare classifications + comparison = self.compare_classifications(query) + + # Use regex result for actual planning (as it's faster and more reliable) + difficulty = comparison["regex_classification"] + + if difficulty == "easy": + cfg = self.difficulty_planner.get_easy_pipeline_config(query, self.base_cfg) + cfg._pipeline_type = "easy" + else: + cfg = self.difficulty_planner.get_hard_pipeline_config(query, self.base_cfg) + cfg._pipeline_type = "hard" + + # Log the decision with comparison data + self._log_decision(cfg, extra_info={ + "difficulty": difficulty, + "pipeline": cfg._pipeline_type, + "comparison": comparison + }) + + return cfg + + def get_comparison_summary(self) -> Dict[str, Any]: + """ + Generate a summary of all comparison results. + """ + if not self.comparison_results: + return {"message": "No comparisons performed yet"} + + total_queries = len(self.comparison_results) + agreements = sum(1 for r in self.comparison_results if r["agreement"]) + agreement_rate = agreements / total_queries if total_queries > 0 else 0 + + # Count classifications + regex_counts = {"easy": 0, "medium": 0, "hard": 0} + qwen_counts = {"easy": 0, "medium": 0, "hard": 0} + + for result in self.comparison_results: + regex_counts[result["regex_classification"]] += 1 + qwen_counts[result["qwen_classification"]] += 1 + + # Calculate average times + avg_regex_time = sum(r["regex_time_ms"] for r in self.comparison_results) / total_queries + avg_qwen_time = sum(r["qwen_time_ms"] for r in self.comparison_results) / total_queries + + # Find disagreements + disagreements = [r for r in self.comparison_results if not r["agreement"]] + + return { + "total_queries": total_queries, + "agreement_rate": round(agreement_rate, 3), + "agreements": agreements, + "disagreements": len(disagreements), + "regex_classifications": regex_counts, + "qwen_classifications": qwen_counts, + "avg_regex_time_ms": round(avg_regex_time, 2), + "avg_qwen_time_ms": round(avg_qwen_time, 2), + "speed_ratio": round(avg_qwen_time / avg_regex_time, 2) if avg_regex_time > 0 else 0, + "disagreement_examples": disagreements[:5] # Show first 5 disagreements + } + + def save_comparison_results(self, filename: str = None) -> str: + """ + Save comparison results to a JSON file. + """ + if filename is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"comparison_results_{timestamp}.json" + + data = { + "summary": self.get_comparison_summary(), + "detailed_results": self.comparison_results, + "timestamp": datetime.now().isoformat() + } + + with open(filename, 'w') as f: + json.dump(data, f, indent=2) + + return filename + + def _log_decision(self, cfg: QueryPlanConfig, extra_info: Dict[str, Any] = None): + """Log the planning decision with comparison information.""" + info = { + "planner": self.name, + "pipeline_type": getattr(cfg, '_pipeline_type', 'unknown'), + "chunk_mode": cfg.chunk_mode, + "ranker_weights": cfg.ranker_weights, + "pool_size": cfg.pool_size, + "top_k": cfg.top_k, + } + + if extra_info: + info.update(extra_info) + + if cfg.location_hint: + info["location_hint"] = cfg.location_hint + + print(f"[COMPARISON_PLANNER] {info}") + + +def run_comparison_test(queries: List[str], config_path: str = "config/config.yaml") -> ComparisonPlanner: + """ + Run a comparison test with a list of queries. + + Args: + queries: List of query strings to test + config_path: Path to configuration file + + Returns: + ComparisonPlanner instance with results + """ + from config import QueryPlanConfig + + # Load configuration + cfg = QueryPlanConfig.from_yaml(config_path) + + # Create comparison planner + planner = ComparisonPlanner(cfg) + + print(f"Running comparison test with {len(queries)} queries...") + print("=" * 60) + + for i, query in enumerate(queries, 1): + print(f"\nQuery {i}/{len(queries)}: {query}") + result = planner.compare_classifications(query) + print(f" Regex: {result['regex_classification']} ({result['regex_time_ms']}ms)") + print(f" Qwen: {result['qwen_classification']} ({result['qwen_time_ms']}ms)") + print(f" Agreement: {'โœ“' if result['agreement'] else 'โœ—'}") + + # Print summary + print("\n" + "=" * 60) + print("COMPARISON SUMMARY") + print("=" * 60) + summary = planner.get_comparison_summary() + print(f"Total queries: {summary['total_queries']}") + print(f"Agreement rate: {summary['agreement_rate']:.1%}") + print(f"Average regex time: {summary['avg_regex_time_ms']}ms") + print(f"Average Qwen time: {summary['avg_qwen_time_ms']}ms") + print(f"Speed ratio (Qwen/Regex): {summary['speed_ratio']:.1f}x") + + print(f"\nRegex classifications: {summary['regex_classifications']}") + print(f"Qwen classifications: {summary['qwen_classifications']}") + + if summary['disagreements'] > 0: + print(f"\nDisagreements ({summary['disagreements']}):") + for i, d in enumerate(summary['disagreement_examples'], 1): + print(f" {i}. \"{d['query'][:60]}...\"") + print(f" Regex: {d['regex_classification']}, Qwen: {d['qwen_classification']}") + + # Save results + filename = planner.save_comparison_results() + print(f"\nResults saved to: {filename}") + + return planner diff --git a/src/planning/difficulty_planner.py b/src/planning/difficulty_planner.py new file mode 100644 index 00000000..8b31eb2b --- /dev/null +++ b/src/planning/difficulty_planner.py @@ -0,0 +1,304 @@ +import sys +from pathlib import Path + +# Add src to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from config import QueryPlanConfig +from chunking import CharChunkConfig, TokenChunkConfig, SlidingTokenConfig, SectionChunkConfig +from copy import deepcopy +import re +from typing import Dict, Any + +from planning.planner import QueryPlanner +from planning.heuristics import HeuristicQueryPlanner + + +class QueryDifficultyPlanner(QueryPlanner): + """ + Two-Pipeline Query Planner + -------------------------- + Routes queries to either "easy" or "hard" pipelines based on complexity indicators. + + Easy Pipeline: Simple, direct queries that benefit from fast, keyword-based retrieval + Hard Pipeline: Complex queries that need comprehensive semantic understanding + """ + + @property + def name(self) -> str: + return "DifficultyBasedPlanner" + + def __init__(self, base_cfg: QueryPlanConfig): + super().__init__(base_cfg) + self.base_cfg = deepcopy(base_cfg) + self.heuristic_planner = HeuristicQueryPlanner(base_cfg) + + def classify_difficulty(self, query: str) -> str: + """ + Classify query as 'easy' or 'hard' based on comprehensive complexity indicators. + + Easy queries (Level 1-3): + - Simple definitions and fact recall + - Basic concept explanations + - Questions requiring some reasoning + - Combining multiple concepts + - Moderate complexity explanations + + Hard queries (Level 4-5): + - Multi-step reasoning and problem-solving + - Complex design and architecture questions + - Advanced analysis and trade-offs + - Open-ended questions requiring deep understanding + """ + q = query.lower().strip() + + # Length-based indicators + word_count = len(q.split()) + char_count = len(q) + + # VERY EASY indicators (Level 1) + very_easy_indicators = [ + r'\b(?:what is|what does|what are)\b.*\?$', # Simple definition questions + r'\b(?:is|are|was|were|do|does|did|can|could|will|would)\b.*\?$', # Yes/no questions + r'\b(?:define|definition|meaning|stands for)\b', # Direct definition requests + r'\b(?:capital|name|who|when|where)\b.*\?$', # Simple fact recall + ] + + # EASY indicators (Level 2) + easy_indicators = [ + r'\b(?:how do you|how to|how can you)\b', # Simple procedural questions + r'\b(?:create|make|build|write|declare)\b', # Simple action requests + r'\b(?:example|examples|instance)\b', # Example requests + r'\b(?:basic|simple|fundamental)\b', # Basic concept indicators + r'\b(?:chapter|section|ch\.|sec\.)\b', # Location-based queries + ] + + # MEDIUM indicators (Level 3) + medium_indicators = [ + r'\b(?:advantages|disadvantages|pros|cons|benefits|drawbacks)\b', # Trade-off analysis + r'\b(?:implement|implementation|algorithm|method|approach)\b', # Implementation questions + r'\b(?:difference|differences|distinguish|distinction)\b', # Comparison questions + r'\b(?:explain|explanation|describe|description)\b', # Explanatory questions + r'\b(?:why|because|reason|reasons|cause|causes)\b', # Reasoning questions + r'\b(?:memory|performance|optimization|efficiency)\b', # Technical analysis + r'\b(?:debug|debugging|troubleshoot|error|exception)\b', # Problem-solving + r'\b(?:binary search|sorting|recursion|recursive)\b', # Specific algorithms + r'\b(?:data structure|data structures)\b', # Data structure questions + r'\b(?:__init__|constructor|initialization)\b', # Object-oriented concepts + ] + + # HARD indicators (Level 4-5) + hard_indicators = [ + # Complex design and architecture + r'\b(?:design|architecture|system|framework|platform)\b', + r'\b(?:scalable|distributed|microservices|monolithic)\b', + r'\b(?:concurrent|parallel|threading|asynchronous)\b', + r'\b(?:security|authentication|authorization|encryption)\b', + + # Multi-step reasoning + r'\b(?:compare and contrast|analyze|evaluate|assess)\b', + r'\b(?:trade-offs|tradeoffs|trade offs)\b', + r'\b(?:migrate|migration|refactor|refactoring)\b', + r'\b(?:best practices|guidelines|standards|patterns)\b', + + # Complex procedural + r'\b(?:process|procedure|workflow|pipeline)\b', + r'\b(?:steps|stages|phases|iterations)\b', + r'\b(?:integration|interaction|combination)\b', + + # Advanced concepts + r'\b(?:paradigm|paradigms|methodology|methodologies)\b', + r'\b(?:constraints|limitations|requirements|specifications)\b', + r'\b(?:deployment|production|environment|infrastructure)\b', + ] + + # VERY HARD indicators (Level 5) + very_hard_indicators = [ + r'\b(?:research|research-level|advanced|cutting-edge)\b', + r'\b(?:open-ended|open ended|exploratory|investigative)\b', + r'\b(?:consensus|distributed consensus|consensus algorithm)\b', + r'\b(?:fault-tolerant|fault tolerant|resilient|robust)\b', + r'\b(?:real-time|real time|streaming|event-driven)\b', + r'\b(?:machine learning|ml|ai|artificial intelligence)\b', + r'\b(?:quantum|blockchain|cryptocurrency|decentralized)\b', + ] + + # Count indicators for each difficulty level + very_easy_score = sum(1 for pattern in very_easy_indicators if re.search(pattern, q)) + easy_score = sum(1 for pattern in easy_indicators if re.search(pattern, q)) + medium_score = sum(1 for pattern in medium_indicators if re.search(pattern, q)) + hard_score = sum(1 for pattern in hard_indicators if re.search(pattern, q)) + very_hard_score = sum(1 for pattern in very_hard_indicators if re.search(pattern, q)) + + # Length-based scoring + if word_count <= 5 and char_count <= 30: + very_easy_score += 1 + elif word_count <= 10 and char_count <= 60: + easy_score += 1 + elif word_count <= 20 and char_count <= 120: + medium_score += 1 + elif word_count > 20 or char_count > 120: + hard_score += 1 + + # Complexity indicators + if q.count('?') > 1 or ';' in q or ':' in q: + medium_score += 1 + + # Multiple clauses + clause_indicators = ['what', 'how', 'why', 'when', 'where', 'which'] + clause_count = sum(1 for indicator in clause_indicators if indicator in q) + if clause_count > 1: + medium_score += 1 + + # Multiple concepts (indicated by multiple technical terms) + technical_terms = ['algorithm', 'data structure', 'function', 'class', 'object', 'method', 'variable', 'loop', 'condition', 'exception', 'database', 'api', 'framework', 'library'] + tech_term_count = sum(1 for term in technical_terms if term in q) + if tech_term_count > 2: + hard_score += 1 + elif tech_term_count > 1: + medium_score += 1 + + # Decision logic with weighted scoring - binary classification + total_easy = very_easy_score + easy_score + medium_score # Combine levels 1-3 into easy + total_hard = hard_score + very_hard_score # Combine levels 4-5 into hard + + # Determine difficulty based on highest score + scores = { + 'easy': total_easy, + 'hard': total_hard + } + + # Special handling for specific patterns - binary classification + if re.search(r'\b(?:advantages|disadvantages)\b', q): + scores['easy'] += 1 # Trade-off analysis is easy + if re.search(r'\b(?:binary search|recursion|algorithm)\b', q): + scores['easy'] += 1 # Algorithm questions are easy + if re.search(r'\b(?:difference|differences)\b', q): + scores['easy'] += 1 # Comparison questions are easy + if re.search(r'\b(?:design|architecture|scalable|microservices)\b', q): + scores['hard'] += 2 # Strong hard indicator + if re.search(r'\b(?:compare and contrast)\b', q): + scores['hard'] += 2 # Strong hard indicator + if re.search(r'\b(?:implement.*algorithm|algorithm.*implement)\b', q): + scores['easy'] += 1 # Implementation of algorithms is easy + if re.search(r'\b(?:difference.*between|between.*difference)\b', q): + scores['easy'] += 1 # "difference between" is easy + if re.search(r'\b(?:rest|graphql|api|apis)\b', q): + scores['easy'] += 1 # API questions are easy + + # If no clear indicators, use length as tiebreaker + if max(scores.values()) == 0: + if word_count <= 15: + return "easy" + else: + return "hard" + + # Return the difficulty level with the highest score + return max(scores, key=scores.get) + + def get_easy_pipeline_config(self, query: str, base_cfg: QueryPlanConfig) -> QueryPlanConfig: + """ + Easy Pipeline Configuration: + - Optimized for speed and direct keyword matching + - Smaller chunk sizes for precise answers + - BM25-heavy ranking for exact matches + - Reduced pool size for efficiency + """ + cfg = deepcopy(base_cfg) + + # Extract location hints (same as heuristic planner) + ch_match = re.search(r"\b(?:chapter|ch\.)\s*(\d{1,3})\b", query, flags=re.IGNORECASE) + sec_match = re.search(r"\b(?:section|sec\.|ยง)\s*(\d{1,3}(?:\.\d{1,3})+)\b", query, flags=re.IGNORECASE) + if ch_match or sec_match: + cfg.location_hint = { + "chapter": int(ch_match.group(1)) if ch_match else None, + "section": sec_match.group(1) if sec_match else None, + "raw": query, + } + + # Easy pipeline optimizations + cfg.chunk_mode = "tokens" + cfg.chunk_config = TokenChunkConfig(max_tokens=200) # Use existing 200-token index + cfg.pool_size = min(cfg.pool_size, 30) # Reduced pool for speed + cfg.top_k = min(cfg.top_k, 3) # Fewer results needed + + # BM25-heavy ranking for keyword matching + cfg.ranker_weights = {"faiss": 0.2, "bm25": 0.7, "tf-idf": 0.1, "location": 0.0} + + # Boost location if hints present + if cfg.location_hint: + cfg.ranker_weights = {"faiss": 0.1, "bm25": 0.6, "tf-idf": 0.1, "location": 0.2} + + return cfg + + def get_hard_pipeline_config(self, query: str, base_cfg: QueryPlanConfig) -> QueryPlanConfig: + """ + Hard Pipeline Configuration: + - Optimized for comprehensive understanding + - Larger chunks for context + - FAISS-heavy ranking for semantic similarity + - Larger pool size for thorough search + """ + cfg = deepcopy(base_cfg) + + # Extract location hints + ch_match = re.search(r"\b(?:chapter|ch\.)\s*(\d{1,3})\b", query, flags=re.IGNORECASE) + sec_match = re.search(r"\b(?:section|sec\.|ยง)\s*(\d{1,3}(?:\.\d{1,3})+)\b", query, flags=re.IGNORECASE) + if ch_match or sec_match: + cfg.location_hint = { + "chapter": int(ch_match.group(1)) if ch_match else None, + "section": sec_match.group(1) if sec_match else None, + "raw": query, + } + + # Hard pipeline optimizations + cfg.chunk_mode = "sections" # Larger chunks for context + cfg.chunk_config = SectionChunkConfig() + cfg.pool_size = max(cfg.pool_size, 80) # Larger pool for comprehensive search + cfg.top_k = max(cfg.top_k, 7) # More results for complex queries + + # FAISS-heavy ranking for semantic understanding + cfg.ranker_weights = {"faiss": 0.6, "bm25": 0.2, "tf-idf": 0.2, "location": 0.0} + + # Boost location if hints present + if cfg.location_hint: + cfg.ranker_weights = {"faiss": 0.5, "bm25": 0.2, "tf-idf": 0.2, "location": 0.1} + + return cfg + + def plan(self, query: str) -> QueryPlanConfig: + """ + Main planning method that routes to appropriate pipeline. + """ + difficulty = self.classify_difficulty(query) + + if difficulty == "easy": + cfg = self.get_easy_pipeline_config(query, self.base_cfg) + cfg._pipeline_type = "easy" + else: + cfg = self.get_hard_pipeline_config(query, self.base_cfg) + cfg._pipeline_type = "hard" + + # Log the decision + self._log_decision(cfg, extra_info={"difficulty": difficulty, "pipeline": cfg._pipeline_type}) + + return cfg + + def _log_decision(self, cfg: QueryPlanConfig, extra_info: Dict[str, Any] = None): + """Log the planning decision with pipeline information.""" + info = { + "planner": self.name, + "pipeline_type": getattr(cfg, '_pipeline_type', 'unknown'), + "chunk_mode": cfg.chunk_mode, + "ranker_weights": cfg.ranker_weights, + "pool_size": cfg.pool_size, + "top_k": cfg.top_k, + } + + if extra_info: + info.update(extra_info) + + if cfg.location_hint: + info["location_hint"] = cfg.location_hint + + print(f"[PLANNER] {info}") diff --git a/src/planning/heuristics.py b/src/planning/heuristics.py index e7011914..b03fdbce 100644 --- a/src/planning/heuristics.py +++ b/src/planning/heuristics.py @@ -1,8 +1,15 @@ -from src.config import QueryPlanConfig -from src.chunking import CharChunkConfig, TokenChunkConfig, SlidingTokenConfig, SectionChunkConfig +import sys +from pathlib import Path + +# Add src to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from config import QueryPlanConfig +from chunking import CharChunkConfig, TokenChunkConfig, SlidingTokenConfig, SectionChunkConfig from copy import deepcopy +import re -from src.planning.planner import QueryPlanner +from planning.planner import QueryPlanner """ Heuristic Query Planner @@ -27,6 +34,9 @@ def __init__(self, base_cfg: QueryPlanConfig): def classify(self, query: str) -> str: q = query.lower() + # Check for location queries first + if any(x in q for x in ["where is", "where can", "where do", "where does", "in which", "what section", "what chapter"]): + return "location" if any(x in q for x in ["what is", "define", "definition"]): return "definition" if any(x in q for x in ["why", "explain", "because"]): @@ -39,15 +49,34 @@ def plan(self, query: str) -> QueryPlanConfig: kind = self.classify(query) cfg = deepcopy(self.base_cfg) - if kind == "definition": + # --- extract optional location hints --- + # chapter: "chapter 19" or "ch. 19" + # section: "section 19.3" or "ยง19.3" + ch_match = re.search(r"\b(?:chapter|ch\.)\s*(\d{1,3})\b", query, flags=re.IGNORECASE) + sec_match = re.search(r"\b(?:section|sec\.|ยง)\s*(\d{1,3}(?:\.\d{1,3})+)\b", query, flags=re.IGNORECASE) + location_hint = None + if ch_match or sec_match: + location_hint = { + "chapter": int(ch_match.group(1)) if ch_match else None, + "section": sec_match.group(1) if sec_match else None, + "raw": query, + } + cfg.location_hint = location_hint + + if kind == "location": + cfg.chunk_mode = "sections" + cfg.chunk_config = SectionChunkConfig() + cfg.ranker_weights = {"faiss": 0.6, "bm25": 0.2, "tf-idf": 0.1, "location": 0.1} + + elif kind == "definition": cfg.chunk_mode = "tokens" cfg.chunk_config = TokenChunkConfig(max_tokens=200) - cfg.ranker_weights = {"faiss": 0.3, "bm25": 0.6, "tf-idf": 0.1} + cfg.ranker_weights = {"faiss": 0.3, "bm25": 0.6, "tf-idf": 0.1, "location": 0.0} elif kind == "explanatory": cfg.chunk_mode = "sections" cfg.chunk_config = SectionChunkConfig() - cfg.ranker_weights = {"faiss": 0.7, "bm25": 0.2, "tf-idf": 0.1} + cfg.ranker_weights = {"faiss": 0.7, "bm25": 0.2, "tf-idf": 0.1, "location": 0.0} elif kind == "procedural": cfg.chunk_mode = "sliding-tokens" @@ -57,13 +86,24 @@ def plan(self, query: str) -> QueryPlanConfig: tokenizer_name=cfg.embed_model, ) cfg.pool_size = max(cfg.pool_size, cfg.top_k * 5) - cfg.ranker_weights = {"faiss": 0.5, "bm25": 0.2, "tf-idf": 0.3} + cfg.ranker_weights = {"faiss": 0.5, "bm25": 0.2, "tf-idf": 0.3, "location": 0.0} else: print("Unknown query type. Defaulting to explanatory.") cfg.chunk_mode = "sections" cfg.chunk_config = SectionChunkConfig() - cfg.ranker_weights = {"faiss": 0.7, "bm25": 0.2, "tf-idf": 0.1} + cfg.ranker_weights = {"faiss": 0.7, "bm25": 0.2, "tf-idf": 0.1, "location": 0.0} + + # If location hints are present, boost location ranker weight + if cfg.location_hint: + # Reduce other weights proportionally to make room for location + total_other = sum(v for k, v in cfg.ranker_weights.items() if k != "location") + if total_other > 0: + scale_factor = 0.9 # Reserve 10% for location (more conservative) + for k in cfg.ranker_weights: + if k != "location": + cfg.ranker_weights[k] *= scale_factor + cfg.ranker_weights["location"] = 0.1 self._log_decision(cfg) return cfg diff --git a/src/planning/planner.py b/src/planning/planner.py index d0e7f3fa..32d3be60 100644 --- a/src/planning/planner.py +++ b/src/planning/planner.py @@ -3,8 +3,14 @@ from typing import Any, Dict from copy import deepcopy -from src.config import QueryPlanConfig -from src.instrumentation.logging import get_logger +import sys +from pathlib import Path + +# Add src to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from config import QueryPlanConfig +from instrumentation.logging import get_logger class QueryPlanner(ABC): diff --git a/src/ranking/rankers.py b/src/ranking/rankers.py index 6c511a4a..e8a51bae 100644 --- a/src/ranking/rankers.py +++ b/src/ranking/rankers.py @@ -74,3 +74,73 @@ def score(self, *, query, chunks, cand_idxs, context): return out +class LocationRanker(Ranker): + name = "location" + + def prepare(self, *, query, chunks, cand_idxs, context): + """Extract location hints from query and prepare for scoring.""" + # Location hints are passed via context from the query planner + context["_location_hint"] = context.get("location_hint", None) + + def score(self, *, query, chunks, cand_idxs, context): + """Score chunks based on location hints (chapter/section matching).""" + hint = context.get("_location_hint") + if not hint: + return {i: 0.0 for i in cand_idxs} + + want_ch = hint.get("chapter") if isinstance(hint, dict) else None + want_sec = hint.get("section") if isinstance(hint, dict) else None + metadata = context.get("metadata", []) + + out = {} + for i in cand_idxs: + score = 0.0 + try: + section_heading = metadata[i].get("section") if i < len(metadata) else None + except Exception: + section_heading = None + + if section_heading: + ch, sec = _extract_numbering_from_heading(section_heading) + + # Section match gets higher score (more specific) + if want_sec and sec and str(sec).startswith(str(want_sec)): + score += 1.0 + + # Chapter match gets lower score (less specific) + if want_ch is not None and ch is not None and int(ch) == int(want_ch): + score += 0.5 + + out[i] = score + + return out + + +def _extract_numbering_from_heading(heading: str) -> tuple[Optional[int], Optional[str]]: + """ + Extract chapter and section numbers from a heading string. + Returns (chapter_num, section_str) or (None, None) if not found. + """ + if not heading: + return None, None + + import re + + # Try to match section patterns like "19.3", "19.3.1", etc. + section_match = re.search(r'(\d+(?:\.\d+)+)', heading) + if section_match: + section_str = section_match.group(1) + # Extract chapter from section (first number) + chapter_match = re.search(r'^(\d+)', section_str) + chapter_num = int(chapter_match.group(1)) if chapter_match else None + return chapter_num, section_str + + # Try to match chapter patterns like "Chapter 19", "Ch. 19", etc. + chapter_match = re.search(r'(?:chapter|ch\.?)\s*(\d+)', heading, re.IGNORECASE) + if chapter_match: + chapter_num = int(chapter_match.group(1)) + return chapter_num, None + + return None, None + + diff --git a/src/retriever.py b/src/retriever.py index 5f3dbfcf..b38654bf 100644 --- a/src/retriever.py +++ b/src/retriever.py @@ -14,6 +14,7 @@ from __future__ import annotations import pickle from typing import List, Tuple, Optional, Dict +import re import faiss from sentence_transformers import SentenceTransformer @@ -32,7 +33,7 @@ def _get_embedder(model_name: str) -> SentenceTransformer: # -------------------------- Artifacts I/O ------------------------------- -def load_artifacts(index_prefix: str, cfg: QueryPlanConfig) -> Tuple[faiss.Index, List[str], List[str], object, Optional[List[List[str]]]]: +def load_artifacts(index_prefix: str, cfg: QueryPlanConfig) -> Tuple[faiss.Index, List[str], List[str], object, Optional[List[List[str]]], List[Dict]]: """ Loads: - FAISS index: {index_prefix}.faiss @@ -48,6 +49,10 @@ def load_artifacts(index_prefix: str, cfg: QueryPlanConfig) -> Tuple[faiss.Index index = faiss.read_index(f"{faiss_prefix}.faiss") chunks = pickle.load(open(f"{faiss_prefix}_chunks.pkl", "rb")) sources = pickle.load(open(f"{faiss_prefix}_sources.pkl", "rb")) + try: + metadata = pickle.load(open(f"{faiss_prefix}_meta.pkl", "rb")) + except Exception: + metadata = [{} for _ in range(len(chunks))] try: vectorizer = pickle.load(open(f"{meta_prefix}_tfidf.pkl", "rb")) @@ -55,7 +60,7 @@ def load_artifacts(index_prefix: str, cfg: QueryPlanConfig) -> Tuple[faiss.Index except Exception: vectorizer, chunk_tags = None, None - return index, chunks, sources, vectorizer, chunk_tags + return index, chunks, sources, vectorizer, chunk_tags, metadata # -------------------------- Pretty previews ----------------------------- @@ -115,3 +120,37 @@ def apply_seg_filter(cfg: QueryPlanConfig, chunks, ordered): else: topk_idxs = ordered[:cfg.top_k] return topk_idxs + + +# --------------------- Location-aware boosting --------------------------- +def boost_by_location(ordered: List[int], metadata: List[Dict], cfg: QueryPlanConfig) -> List[int]: + """ + DEPRECATED: This function is being replaced by LocationRanker integration. + Reorder indices in 'ordered' by boosting those whose metadata section matches + cfg.location_hint (chapter or section). Stable within equal bonus. + """ + from src.ranking.rankers import _extract_numbering_from_heading + + hint = getattr(cfg, "location_hint", None) + if not hint: + return ordered + + want_ch = hint.get("chapter") if isinstance(hint, dict) else None + want_sec = hint.get("section") if isinstance(hint, dict) else None + + def bonus_for_idx(idx: int) -> float: + try: + section_heading = metadata[idx].get("section") + except Exception: + section_heading = None + ch, sec = _extract_numbering_from_heading(section_heading or "") + bonus = 0.0 + if want_sec and sec and str(sec).startswith(str(want_sec)): + bonus += 2.0 + if want_ch is not None and ch is not None and int(ch) == int(want_ch): + bonus += 1.0 + return bonus + + with_bonus = [(idx, bonus_for_idx(idx), pos) for pos, idx in enumerate(ordered)] + with_bonus.sort(key=lambda t: (-t[1], t[2])) + return [idx for idx, _, _ in with_bonus]