Rag complete contribute #222

@salvino080-coder

Description

#!/usr/bin/env python3
"""
RAG Toolkit — Complete single-file toolkit with embedded guide (ENGLISH)

This file contains:
- A comprehensive guide (WHAT IS MISSING, HOW TO FIX, HOW TO TEST) embedded below.
- A modular Python toolkit that implements:
  - Embeddings (sentence-transformers)
  - FAISS indexing (CPU/GPU-aware)
  - Simple RAG pipeline (retrieve + generate)
  - Hugging Face model loading with optional 8-bit (bitsandbytes)
  - Benchmark harness (latency, throughput, memory, optional GPU power)
  - ONNX export helper and Triton model-repo builder helper
  - BentoML placeholder exporter
  - MLflow + HF Hub helper hooks
  - DeepSpeed config generator helper
  - CLI for common tasks

USAGE (short):
  python rag_toolkit_complete.py build-index --corpus data/corpus.txt --index-out data/my.index
  python rag_toolkit_complete.py benchmark --model gpt2 --prompts data/prompts.txt --index data/my.index --corpus data/corpus.txt --iters 50
  python rag_toolkit_complete.py export-onnx --model gpt2 --out models/gpt2.onnx

EMBEDDED GUIDE (ENGLISH)
-------------------------
This guide explains remaining gaps in the code, how to resolve them, and how to test everything end-to-end.
Paste this file in a repository and follow the sections below.

1) Overview of remaining gaps (why they exist)
- Native dependencies and CUDA toolchain: some performance features (faiss-gpu, bitsandbytes, TensorRT) require compatible CUDA, drivers, and compiled wheels.
- Production-grade exporters: the ONNX export in this script is a best-effort single-pass export. For Triton/TensorRT, use Hugging Face Optimum, Torch-TensorRT, or vendor tools, and tune per model.
- Distributed training: DeepSpeed/FSDP require cluster setup and resource planning (multiple GPUs, node networking, NVMe/SSD for offload).
- Serving at scale: this script provides helpers and placeholders (BentoML/Triton) but not a full operator, autoscaler, or canary deployment pipeline. Kubernetes manifests/Helm charts must be produced externally.
- Safety & governance: toxicity / bias checks and license enforcement require policy definitions and test datasets.

2) How to resolve each gap (step-by-step, practical commands)

A. Ensure correct GPU drivers & PyTorch
 - Check GPU drivers:
     nvidia-smi
 - Install PyTorch matching CUDA version (example for CUDA 11.8):
     pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
 - Verify:
     python -c "import torch; print(torch.cuda.is_available(), torch.version.cuda)"

B. Install bitsandbytes (8-bit) safely
 - Recommended: use pip when CUDA and gcc toolchain match. Example:
     pip install --upgrade pip
     pip install bitsandbytes
 - If pip wheel fails, check bitsandbytes docs for building from source.
 - Test:
     python -c "import bitsandbytes as bnb; print('bnb OK')"

C. Install faiss-gpu
 - Best via conda:
     conda install -c pytorch faiss-gpu cudatoolkit=11.8
 - Or use faiss-cpu for CPU-only:
     pip install faiss-cpu
 - Validate:
     python -c "import faiss; print('faiss OK')"

D. DeepSpeed and distributed training
 - Install:
     pip install deepspeed
 - Generate a ds_config with this script:
     python rag_toolkit_complete.py gen-deepspeed --zero-stage 3 --out ds_config.json
 - Launch (example single-node 4 GPUs):
     deepspeed --num_gpus=4 train.py --deepspeed_config ds_config.json
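 - For reference, the gen-deepspeed helper in this file writes a ds_config.json of roughly this shape for stage-3 ZeRO with CPU offload (add --fp16 to the command above to enable fp16):

```json
{
  "train_batch_size": 1,
  "gradient_accumulation_steps": 1,
  "fp16": {"enabled": true},
  "zero_optimization": {
    "stage": 3,
    "offload_param": {"device": "cpu"},
    "offload_optimizer": {"device": "cpu"}
  }
}
```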

E. Optimum / ONNX / TensorRT / Triton (export & serving)
 - Install optimum and onnxruntime:
     pip install optimum onnxruntime onnx
 - Use optimum for better conversion of Transformers to ONNX/TensorRT (check Optimum docs).
 - Build Triton model repo using this script helper:
     python rag_toolkit_complete.py build-triton --onnx models/gpt2.onnx --model-name gpt2
 - Deploy Triton server (follow NVIDIA Triton docs) and mount model repo.

F. BentoML (serving bundle)
 - Install BentoML:
     pip install bentoml
 - Use save-bento command:
     python rag_toolkit_complete.py save-bento --model gpt2 --out-dir bento_bundle
 - Implement a BentoML Service class that wraps tokenizer+model and exposes REST/gRPC.

G. Vector DB (Milvus) for persistent large indices
 - Install Milvus server or use hosted service.
 - Install client:
     pip install pymilvus
 - Use Milvus to store embeddings and build distributed indexes for >10M vectors.

H. Observability and CI/CD
 - MLflow:
     pip install mlflow
 - Prometheus / Grafana: instrument the service; expose /metrics via prometheus_client.
 - CI: build Docker image in CI, run unit tests and lightweight integration tests, publish image to registry.
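 - prometheus_client is the standard way to expose /metrics; as a dependency-free illustration of what that endpoint serves, here is a minimal stdlib sketch (metric names and port are illustrative, not part of this toolkit):

```python
from http.server import BaseHTTPRequestHandler, HTTPServer

REQUEST_COUNT = 0  # a real service would increment this on each request

def render_metrics() -> str:
    # Prometheus text exposition format: HELP/TYPE lines plus samples.
    return (
        "# HELP rag_requests_total Total generation requests served.\n"
        "# TYPE rag_requests_total counter\n"
        f"rag_requests_total {REQUEST_COUNT}\n"
    )

class MetricsHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        if self.path != "/metrics":
            self.send_error(404)
            return
        body = render_metrics().encode()
        self.send_response(200)
        self.send_header("Content-Type", "text/plain; version=0.0.4")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *args):  # keep scrape logging quiet
        pass

# To serve: HTTPServer(("0.0.0.0", 9100), MetricsHandler).serve_forever()
```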

3) How to test locally (quickstart)
- Create venv and install minimal deps:
    python -m venv .venv && source .venv/bin/activate
    pip install -U pip
    pip install transformers sentence-transformers torch faiss-cpu psutil
- Prepare small corpus:
    mkdir -p data
    printf "Document one\nDocument two\nDocument three\n" > data/corpus.txt
    printf "Write a summary of Document one\nWhat is Document two about?\n" > data/prompts.txt
- Build index:
    python rag_toolkit_complete.py build-index --corpus data/corpus.txt --index-out data/my.index
- Benchmark:
    python rag_toolkit_complete.py benchmark --model gpt2 --prompts data/prompts.txt --index data/my.index --corpus data/corpus.txt --iters 5
- Export ONNX (CPU test):
    python rag_toolkit_complete.py export-onnx --model gpt2 --out models/gpt2.onnx

4) Validation checklist (what to verify)
- Model loaded and responds to prompts (sanity check).
- Index built and retrieval returns plausible documents.
- Benchmark produces reasonable latencies and writes output JSON.
- ONNX file exists and can be inspected with onnx.checker.
- Optional: MLflow/Hub artifacts logged or uploaded successfully.
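A small, dependency-free check of the benchmark output can automate part of this checklist (field names follow the BenchmarkResult dataclass defined later in this file):

```python
REQUIRED_KEYS = {"count", "mean", "median", "p50", "p90", "p99",
                 "throughput_qps", "mem_samples", "power_avg"}

def validate_benchmark(result: dict) -> list:
    # Return a list of problems found in a benchmark_result.json payload.
    problems = [f"missing key: {k}" for k in sorted(REQUIRED_KEYS - result.keys())]
    if not problems:
        if result["count"] <= 0:
            problems.append("count must be positive")
        if result["mean"] <= 0:
            problems.append("mean latency must be positive")
        if result["p90"] < result["p50"]:
            problems.append("p90 should be >= p50")
    return problems

# Example: validate_benchmark(json.loads(Path("benchmark_result.json").read_text()))
```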

5) Production recommendations (short)
- Split services: embedding service, retriever, generator (each containerized).
- Use a vector DB (Milvus/Weaviate) for persistence and scaling.
- Use Triton or BentoML for high-throughput serving; put generator behind an API Gateway and rate limiting.
- Integrate monitoring: Prometheus metrics, Grafana dashboards, traces via OpenTelemetry.
- Implement safety pipelines: automatic toxicity checks, model cards, license scanning.
- Use DeepSpeed/ZeRO and quantization (bitsandbytes) to reduce cost for large models.
- Implement canary/blue-green deploys and autoscaling policies for GPU-backed pods.

6) Cost measurement & optimization
- Measure cost per 1M requests using benchmark QPS and cloud instance pricing.
- Try quantized models (8-bit/4-bit), lower precision (fp16), or smaller distilled checkpoints for cheaper inference.
- Use batching, caching, and shard models for throughput improvements.
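As a sketch of the cost arithmetic (the instance price and QPS below are made-up numbers, not measurements):

```python
def cost_per_million_requests(instance_usd_per_hour: float, sustained_qps: float) -> float:
    # USD to serve 1M requests at a sustained throughput on one instance.
    requests_per_hour = sustained_qps * 3600
    return instance_usd_per_hour * (1_000_000 / requests_per_hour)

# e.g. a hypothetical $2.50/h GPU instance sustaining 20 QPS:
# cost_per_million_requests(2.50, 20) -> ~34.72 USD per 1M requests
```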

7) Troubleshooting common issues
- Out-of-memory when loading model: try load_in_8bit, device_map="auto", or move to a machine with more GPU memory.
- ONNX export fails: try smaller model or use Optimum exporter; ensure matching torch/onnx/onnxruntime versions.
- faiss-gpu import errors: ensure CUDA and conda package version compatibility.
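A back-of-the-envelope estimate of weight memory helps decide between fp16, 8-bit, and a bigger GPU (this ignores activations, optimizer state, and the KV cache, so treat it as a lower bound):

```python
def weight_memory_gb(n_params: float, bytes_per_param: float) -> float:
    # Approximate GPU memory consumed by model weights alone.
    return n_params * bytes_per_param / 1024 ** 3

# A 7B-parameter model, roughly:
#   fp16 -> weight_memory_gb(7e9, 2), about 13 GB
#   int8 -> weight_memory_gb(7e9, 1), about 6.5 GB
```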

End of embedded guide.
-------------------------

The code implementation follows below.
"""

from dataclasses import dataclass, asdict
from pathlib import Path
import argparse
import json
import logging
import os
import subprocess
import sys
import time
from typing import List, Optional

import numpy as np
import psutil
import torch

# Optional components detection
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    TRANSFORMERS_AVAILABLE = True
except Exception:
    TRANSFORMERS_AVAILABLE = False

try:
    from sentence_transformers import SentenceTransformer
    SENTENCE_TRANSFORMERS_AVAILABLE = True
except Exception:
    SentenceTransformer = None
    SENTENCE_TRANSFORMERS_AVAILABLE = False

try:
    import faiss
    FAISS_AVAILABLE = True
except Exception:
    faiss = None
    FAISS_AVAILABLE = False

try:
    import bitsandbytes as bnb  # noqa: F401
    BNB_AVAILABLE = True
except Exception:
    BNB_AVAILABLE = False

try:
    import mlflow
    MLFLOW_AVAILABLE = True
except Exception:
    mlflow = None
    MLFLOW_AVAILABLE = False

try:
    import bentoml
    BENTOML_AVAILABLE = True
except Exception:
    bentoml = None
    BENTOML_AVAILABLE = False

try:
    from huggingface_hub import HfApi
    HF_HUB_AVAILABLE = True
except Exception:
    HfApi = None
    HF_HUB_AVAILABLE = False

logger = logging.getLogger("rag_toolkit_complete")
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)
logger.setLevel(logging.INFO)


@dataclass
class BenchmarkResult:
    count: int
    mean: float
    median: float
    p50: float
    p90: float
    p99: float
    throughput_qps: Optional[float]
    mem_samples: List[dict]
    power_avg: Optional[float]


# -----------------------
# System utilities
# -----------------------
def is_cuda_available() -> bool:
    return torch.cuda.is_available()


def measure_system_memory() -> dict:
    vm = psutil.virtual_memory()
    return {
        "total_gb": vm.total / (1024 ** 3),
        "available_gb": vm.available / (1024 ** 3),
        "used_gb": vm.used / (1024 ** 3),
        "percent": vm.percent,
    }


def measure_gpu_memory(device: int = 0) -> dict:
    if not is_cuda_available():
        return {"cuda_available": False}
    try:
        allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
        reserved = torch.cuda.memory_reserved(device) / (1024 ** 3)
        return {"cuda_available": True, "allocated_gb": allocated, "reserved_gb": reserved}
    except Exception as e:
        return {"cuda_available": True, "error": str(e)}


def get_nvidia_power_draw() -> Optional[float]:
    try:
        out = subprocess.check_output(["nvidia-smi", "--query-gpu=power.draw", "--format=csv,noheader,nounits"], stderr=subprocess.DEVNULL)
        vals = [float(x.strip()) for x in out.decode().splitlines() if x.strip()]
        return float(sum(vals) / len(vals)) if vals else None
    except Exception:
        return None


# -----------------------
# Embeddings & FAISS
# -----------------------
def load_sentence_transformer(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    if not SENTENCE_TRANSFORMERS_AVAILABLE:
        raise RuntimeError("sentence-transformers not installed. Install it to use embeddings.")
    logger.info(f"Loading SentenceTransformer: {model_name}")
    return SentenceTransformer(model_name)


def encode_texts(emb_model, texts: List[str], batch_size: int = 128) -> np.ndarray:
    logger.info(f"Encoding {len(texts)} texts (batch_size={batch_size})")
    embeddings = emb_model.encode(texts, batch_size=batch_size, convert_to_numpy=True, normalize_embeddings=True)
    return np.asarray(embeddings, dtype=np.float32)


def build_faiss_index(embeddings: np.ndarray, index_path: Optional[str] = None, n_list: int = 128):
    if not FAISS_AVAILABLE:
        raise RuntimeError("faiss is required. Install faiss-cpu or faiss-gpu.")
    n, d = embeddings.shape
    if n < n_list:
        # IVF needs at least n_list training vectors; fall back to a flat index for small corpora.
        logger.info(f"Corpus too small for IVF (n={n} < nlist={n_list}); using IndexFlatL2")
        index = faiss.IndexFlatL2(d)
        index.add(embeddings)
    else:
        logger.info(f"Building FAISS IVF index (d={d}, nlist={n_list})")
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFFlat(quantizer, d, n_list, faiss.METRIC_L2)
        index.train(embeddings)
        index.add(embeddings)
    if index_path:
        faiss.write_index(index, index_path)
        logger.info(f"Saved FAISS index to {index_path}")
    return index


def load_faiss_index(index_path: str):
    if not FAISS_AVAILABLE:
        raise RuntimeError("faiss is required.")
    if not os.path.exists(index_path):
        raise FileNotFoundError(index_path)
    idx = faiss.read_index(index_path)
    logger.info(f"Loaded FAISS index from {index_path}")
    return idx


# -----------------------
# Model loading & generation
# -----------------------
def load_generation_model(model_name: str, device: str = "cuda", load_in_8bit: bool = False, trust_remote_code: bool = False):
    if not TRANSFORMERS_AVAILABLE:
        raise RuntimeError("transformers not installed.")
    logger.info(f"Loading generation model {model_name} (device={device}, 8bit={load_in_8bit})")
    load_kwargs = {}
    if load_in_8bit and BNB_AVAILABLE:
        load_kwargs["load_in_8bit"] = True
        load_kwargs["device_map"] = "auto"
    else:
        if is_cuda_available() and device.startswith("cuda"):
            load_kwargs["torch_dtype"] = torch.float16
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=trust_remote_code)
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs, trust_remote_code=trust_remote_code)
    if "device_map" not in load_kwargs:
        dev = torch.device(device if is_cuda_available() else "cpu")
        model.to(dev)
    model.eval()
    logger.info("Model loaded.")
    return tokenizer, model


def generate_with_context(tokenizer, model, prompt: str, context_docs: Optional[List[str]] = None, max_new_tokens: int = 128, device: str = "cuda", **gen_kwargs):
    context = ("\n\n".join(context_docs) + "\n\n") if context_docs else ""
    full = context + prompt if context else prompt
    inputs = tokenizer(full, return_tensors="pt", truncation=True)
    if is_cuda_available() and device.startswith("cuda"):
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id, **gen_kwargs)
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    if text.startswith(full):
        return text[len(full):].strip()
    return text


# -----------------------
# ONNX & Triton helpers
# -----------------------
def export_model_to_onnx(tokenizer, model, output_path: str, opset: int = 13):
    logger.info(f"Exporting model to ONNX at {output_path} (best-effort)")
    model_cpu = model.to("cpu")
    model_cpu.eval()
    dummy = "This is a dummy input"
    inputs = tokenizer(dummy, return_tensors="pt")
    input_ids = inputs["input_ids"]
    try:
        torch.onnx.export(
            model_cpu,
            (input_ids,),
            output_path,
            input_names=["input_ids"],
            output_names=["logits"],
            opset_version=opset,
            dynamic_axes={
                "input_ids": {0: "batch", 1: "sequence"},
                "logits": {0: "batch", 1: "sequence"},
            },
        )
        logger.info("ONNX export succeeded.")
    except Exception:
        logger.exception("ONNX export failed. Consider using Optimum or vendor exporters.")
        raise
    finally:
        if is_cuda_available():
            try:
                model.to("cuda")
            except Exception:
                pass
    return output_path


def build_triton_model_repo(onnx_path: str, model_name: str, model_version: int = 1, repo_dir: str = "triton_model_repo"):
    logger.info("Creating Triton model repo (ONNX) ...")
    model_dir = Path(repo_dir) / model_name / str(model_version)
    model_dir.mkdir(parents=True, exist_ok=True)
    dest = model_dir / Path(onnx_path).name
    dest.write_bytes(Path(onnx_path).read_bytes())
    cfg = f"""name: "{model_name}"
platform: "onnxruntime_onnx"
max_batch_size: 8
input [
  {{
    name: "input_ids"
    data_type: TYPE_INT64
    dims: [-1]
  }}
]
output [
  {{
    name: "logits"
    data_type: TYPE_FP32
    dims: [-1, -1]
  }}
]
"""
    cfgp = Path(repo_dir) / model_name / "config.pbtxt"
    cfgp.write_text(cfg)
    logger.info(f"Triton model repo created at {Path(repo_dir)/model_name}")
    return str(Path(repo_dir) / model_name)


# -----------------------
# BentoML placeholder
# -----------------------
def save_bentoml_bundle(tokenizer, model, save_dir: str = "bento_bundle"):
    if not BENTOML_AVAILABLE:
        raise RuntimeError("BentoML not installed.")
    out = Path(save_dir)
    out.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), out / "model_state.pt")
    tokenizer.save_pretrained(out / "tokenizer")
    logger.info(f"BentoML bundle saved at {out}")
    return str(out)


# -----------------------
# MLflow & HF Hub helpers
# -----------------------
def mlflow_log_artifact_local(path: str):
    if not MLFLOW_AVAILABLE:
        logger.warning("MLflow not installed; skip logging.")
        return
    mlflow.log_artifact(path)
    logger.info(f"Logged {path} to MLflow.")


def push_file_to_hf(file_path: str, repo_id: str, path_in_repo: Optional[str] = None, token: Optional[str] = None):
    if not HF_HUB_AVAILABLE:
        raise RuntimeError("huggingface_hub not available.")
    api = HfApi()
    token = token or os.environ.get("HF_TOKEN")
    dest = path_in_repo or Path(file_path).name
    logger.info(f"Pushing {file_path} to HF repo {repo_id}/{dest}")
    api.upload_file(path_or_fileobj=file_path, path_in_repo=dest, repo_id=repo_id, token=token)
    logger.info("Upload succeeded.")


# -----------------------
# DeepSpeed helpers
# -----------------------
def generate_deepspeed_config(zero_stage: int = 3, offload_type: str = "cpu", fp16_enabled: bool = True, train_batch_size: int = 1, gradient_accumulation_steps: int = 1):
    cfg = {
        "train_batch_size": train_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "fp16": {"enabled": fp16_enabled},
        "zero_optimization": {
            "stage": zero_stage,
            "offload_param": {"device": offload_type} if offload_type else {},
            "offload_optimizer": {"device": offload_type} if offload_type else {},
        },
    }
    return cfg


def write_deepspeed_config(cfg: dict, out_path: str = "ds_config.json"):
    Path(out_path).write_text(json.dumps(cfg, indent=2))
    logger.info(f"Wrote DeepSpeed config to {out_path}")
    return out_path


# -----------------------
# RAG pipeline & benchmark
# -----------------------
class RAGPipeline:
    def __init__(self, embed_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        if not SENTENCE_TRANSFORMERS_AVAILABLE:
            raise RuntimeError("sentence-transformers required.")
        self.embed_model = load_sentence_transformer(embed_model_name)
        self.faiss_index = None
        self.corpus: List[str] = []

    def build_index(self, corpus: List[str], index_out: Optional[str] = None, n_list: int = 128):
        self.corpus = corpus
        emb = encode_texts(self.embed_model, corpus)
        self.faiss_index = build_faiss_index(emb, index_path=index_out, n_list=n_list)
        logger.info("FAISS index built.")

    def load_index(self, index_path: str, corpus_path: Optional[str] = None):
        self.faiss_index = load_faiss_index(index_path)
        if corpus_path:
            self.corpus = [l.strip() for l in Path(corpus_path).read_text(encoding="utf-8").splitlines() if l.strip()]
        logger.info("Index and corpus loaded.")

    def retrieve(self, query: str, top_k: int = 5):
        if self.faiss_index is None:
            raise RuntimeError("No index loaded. Call build_index() or load_index() first.")
        q_emb = self.embed_model.encode([query], normalize_embeddings=True).astype(np.float32)
        dists, idxs = self.faiss_index.search(q_emb, top_k)
        results = []
        for idx, dist in zip(idxs[0], dists[0]):
            if idx < 0:
                continue
            # Guard against an index loaded without its corpus file.
            text = self.corpus[int(idx)] if int(idx) < len(self.corpus) else ""
            results.append({"id": int(idx), "score": float(dist), "text": text})
        return results

    def generate(self, tokenizer, model, query: str, top_k: int = 3, max_new_tokens: int = 128, device: str = "cuda", **gen_kwargs):
        docs = self.retrieve(query, top_k=top_k)
        ctx = [d["text"] for d in docs]
        return generate_with_context(tokenizer, model, query, context_docs=ctx, max_new_tokens=max_new_tokens, device=device, **gen_kwargs)


def run_benchmark(tokenizer, model, prompts: List[str], rag: Optional[RAGPipeline], iters: int = 50, warmup: int = 3, top_k: int = 3, device: str = "cuda") -> BenchmarkResult:
    logger.info("Starting benchmark (warmup then runs)...")
    for _ in range(warmup):
        q = prompts[0]
        if rag:
            _ = rag.generate(tokenizer, model, q, top_k=top_k, device=device)
        else:
            _ = generate_with_context(tokenizer, model, q, None, device=device)
    times = []
    mem_samples = []
    power_samples = []
    for i in range(iters):
        q = prompts[i % len(prompts)]
        start = time.time()
        if rag:
            _ = rag.generate(tokenizer, model, q, top_k=top_k, device=device)
        else:
            _ = generate_with_context(tokenizer, model, q, None, device=device)
        end = time.time()
        lat = end - start
        times.append(lat)
        mem = measure_system_memory()
        gpu = measure_gpu_memory()
        mem_samples.append({"iter": i, "latency_s": lat, "cpu": mem, "gpu": gpu})
        pw = get_nvidia_power_draw()
        if pw is not None:
            power_samples.append(pw)
        if (i + 1) % max(1, iters // 5) == 0:
            logger.info(f"Iteration {i+1}/{iters} latency {lat:.3f}s")
    arr = np.array(times)
    res = BenchmarkResult(
        count=len(arr),
        mean=float(arr.mean()),
        median=float(np.median(arr)),
        p50=float(np.percentile(arr, 50)),
        p90=float(np.percentile(arr, 90)),
        p99=float(np.percentile(arr, 99)),
        throughput_qps=float(1.0 / arr.mean()) if arr.mean() > 0 else None,
        mem_samples=mem_samples,
        power_avg=float(np.mean(power_samples)) if power_samples else None,
    )
    if MLFLOW_AVAILABLE:
        metrics = {"mean_latency": res.mean, "p90": res.p90}
        if res.throughput_qps is not None:
            metrics["throughput_qps"] = res.throughput_qps
        mlflow.log_metrics(metrics)
    return res


# -----------------------
# CLI
# -----------------------
def parse_args():
    p = argparse.ArgumentParser(description="RAG Toolkit Complete CLI")
    sub = p.add_subparsers(dest="cmd")

    bi = sub.add_parser("build-index", help="Build FAISS index from corpus")
    bi.add_argument("--corpus", required=True)
    bi.add_argument("--index-out", required=True)
    bi.add_argument("--embed-model", default="sentence-transformers/all-MiniLM-L6-v2")
    bi.add_argument("--n-list", type=int, default=128)

    bm = sub.add_parser("benchmark", help="Benchmark a model (optionally with RAG)")
    bm.add_argument("--model", required=True)
    bm.add_argument("--prompts", required=True)
    bm.add_argument("--corpus", help="corpus file path for retrieval")
    bm.add_argument("--index", help="faiss index path")
    bm.add_argument("--embed-model", default="sentence-transformers/all-MiniLM-L6-v2")
    bm.add_argument("--use-8bit", action="store_true")
    bm.add_argument("--iters", type=int, default=50)
    bm.add_argument("--device", default="cuda" if is_cuda_available() else "cpu")
    bm.add_argument("--output", default="benchmark_result.json")

    ex = sub.add_parser("export-onnx", help="Export model to ONNX")
    ex.add_argument("--model", required=True)
    ex.add_argument("--out", required=True)

    tr = sub.add_parser("build-triton", help="Build Triton model repo from ONNX")
    tr.add_argument("--onnx", required=True)
    tr.add_argument("--model-name", required=True)
    tr.add_argument("--repo-dir", default="triton_model_repo")

    hf = sub.add_parser("push-hf", help="Push artifact to Hugging Face Hub")
    hf.add_argument("--artifact", required=True)
    hf.add_argument("--repo-id", required=True)
    hf.add_argument("--path-in-repo", help="path inside repo")
    hf.add_argument("--token", help="HF token")

    ds = sub.add_parser("gen-deepspeed", help="Generate DeepSpeed config")
    ds.add_argument("--zero-stage", type=int, default=3)
    ds.add_argument("--offload-type", default="cpu")
    ds.add_argument("--fp16", action="store_true")
    ds.add_argument("--out", default="ds_config.json")

    sb = sub.add_parser("save-bento", help="Save BentoML bundle (placeholder)")
    sb.add_argument("--model", required=True)
    sb.add_argument("--out-dir", default="bento_bundle")

    return p.parse_args()


def main():
    args = parse_args()
    if args.cmd == "build-index":
        texts = [l.strip() for l in Path(args.corpus).read_text(encoding="utf-8").splitlines() if l.strip()]
        rag = RAGPipeline(embed_model_name=args.embed_model)
        rag.build_index(texts, index_out=args.index_out, n_list=args.n_list)
        # Keep a normalized copy of the corpus alongside the index for later reuse.
        Path(args.corpus + ".saved").write_text("\n".join(texts), encoding="utf-8")
        logger.info("Index built.")
    elif args.cmd == "benchmark":
        raw = Path(args.prompts).read_text(encoding="utf-8").strip()
        if raw.startswith("["):
            prompts = json.loads(raw)
        else:
            prompts = [l.strip() for l in raw.splitlines() if l.strip()]
        rag_pipeline = None
        if args.corpus and args.index:
            corpus = [l.strip() for l in Path(args.corpus).read_text(encoding="utf-8").splitlines() if l.strip()]
            rag_pipeline = RAGPipeline(embed_model_name=args.embed_model)
            rag_pipeline.load_index(args.index, corpus_path=args.corpus)
        tokenizer, model = load_generation_model(args.model, device=args.device, load_in_8bit=args.use_8bit)
        res = run_benchmark(tokenizer, model, prompts, rag_pipeline, iters=args.iters, device=args.device)
        Path(args.output).write_text(json.dumps(asdict(res), indent=2), encoding="utf-8")
        logger.info(f"Saved benchmark to {args.output}")
        if MLFLOW_AVAILABLE:
            mlflow.log_artifact(args.output)
    elif args.cmd == "export-onnx":
        tokenizer, model = load_generation_model(args.model, device="cpu", load_in_8bit=False)
        out = export_model_to_onnx(tokenizer, model, args.out)
        logger.info(f"ONNX exported to {out}")
    elif args.cmd == "build-triton":
        repo = build_triton_model_repo(args.onnx, args.model_name, repo_dir=args.repo_dir)
        logger.info(f"Triton repo ready at {repo}")
    elif args.cmd == "push-hf":
        push_file_to_hf(args.artifact, args.repo_id, path_in_repo=args.path_in_repo, token=args.token)
    elif args.cmd == "gen-deepspeed":
        cfg = generate_deepspeed_config(zero_stage=args.zero_stage, offload_type=args.offload_type, fp16_enabled=args.fp16)
        write_deepspeed_config(cfg, out_path=args.out)
    elif args.cmd == "save-bento":
        tokenizer, model = load_generation_model(args.model, device="cpu", load_in_8bit=False)
        save_bentoml_bundle(tokenizer, model, save_dir=args.out_dir)
    else:
        print("No command specified. Use --help.")


if __name__ == "__main__":
    main()
