#!/usr/bin/env python3
"""
RAG Toolkit — Complete single-file toolkit with embedded guide (ENGLISH)
This file contains:
- A comprehensive guide (WHAT IS MISSING, HOW TO FIX, HOW TO TEST) embedded below.
- A modular Python toolkit that implements:
- Embeddings (sentence-transformers)
- FAISS indexing (CPU/GPU-aware)
- Simple RAG pipeline (retrieve + generate)
- Hugging Face model loading with optional 8-bit (bitsandbytes)
- Benchmark harness (latency, throughput, memory, optional GPU power)
- ONNX export helper and Triton model-repo builder helper
- BentoML placeholder exporter
- MLflow + HF Hub helper hooks
- DeepSpeed config generator helper
- CLI for common tasks
USAGE (short):
python rag_toolkit_complete.py build-index --corpus data/corpus.txt --index-out data/my.index
python rag_toolkit_complete.py benchmark --model gpt2 --prompts data/prompts.txt --index data/my.index --corpus data/corpus.txt --iters 50
python rag_toolkit_complete.py export-onnx --model gpt2 --out models/gpt2.onnx
EMBEDDED GUIDE (ENGLISH)
-------------------------
This guide explains remaining gaps in the code, how to resolve them, and how to test everything end-to-end.
Paste this file in a repository and follow the sections below.
1) Overview of remaining gaps (why they exist)
- Native dependencies and CUDA toolchain: some performance features (faiss-gpu, bitsandbytes, TensorRT) require compatible CUDA, drivers, and compiled wheels.
- Production-grade exporters: ONNX export in this script is a best-effort single-pass export. For Triton/TensorRT, use Hugging Face Optimum, Torch-TensorRT, or vendor tools, and tune the export per model.
- Distributed training: DeepSpeed/FSDP require cluster setup and resource planning (multiple GPUs, node networking, NVMe/SSD for offload).
- Serving at scale: This script provides helpers and placeholders (BentoML/Triton) but not a full operator, autoscaler, or canary deployment pipeline. Kubernetes manifests/Helm charts are produced externally.
- Safety & governance: toxicity / bias checks and license enforcement require policy definitions and test datasets.
2) How to resolve each gap (step-by-step, practical commands)
A. Ensure correct GPU drivers & PyTorch
- Check GPU drivers:
nvidia-smi
- Install PyTorch matching CUDA version (example for CUDA 11.8):
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
- Verify:
python -c "import torch; print(torch.cuda.is_available(), torch.version.cuda)"
B. Install bitsandbytes (8-bit) safely
- Recommended: use pip when CUDA and gcc toolchain match. Example:
pip install --upgrade pip
pip install bitsandbytes
- If pip wheel fails, check bitsandbytes docs for building from source.
- Test:
python -c "import bitsandbytes as bnb; print('bnb OK')"
C. Install faiss-gpu
- Best via conda:
conda install -c pytorch faiss-gpu cudatoolkit=11.8
- Or use faiss-cpu for CPU-only:
pip install faiss-cpu
- Validate:
python -c "import faiss; print('faiss OK')"
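If faiss cannot be installed at all, the same retrieval idea can be sketched with a brute-force NumPy search. This is a hypothetical fallback, not part of the toolkit below; it assumes embeddings are L2-normalized (as encode_texts produces), so the inner product equals cosine similarity:

```python
import numpy as np

def brute_force_search(corpus_emb: np.ndarray, query_emb: np.ndarray, top_k: int = 5):
    # With L2-normalized rows, the inner product is cosine similarity.
    scores = corpus_emb @ query_emb
    top = np.argsort(-scores)[:top_k]
    return top, scores[top]

corpus = np.eye(3, dtype=np.float32)  # three orthogonal toy "documents"
query = np.array([0.0, 1.0, 0.0], dtype=np.float32)
ids, scores = brute_force_search(corpus, query, top_k=2)
print(ids[0])  # 1 — document 1 is the best match
```

This is O(N) per query and only sensible for small corpora; beyond a few hundred thousand vectors, use faiss or a vector DB.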
D. DeepSpeed and distributed training
- Install:
pip install deepspeed
- Generate a ds_config with this script:
python rag_toolkit_complete.py gen-deepspeed --zero-stage 3 --out ds_config.json
- Launch (example single-node 4 GPUs):
deepspeed --num_gpus=4 train.py --deepspeed_config ds_config.json
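The gen-deepspeed helper writes a ZeRO config along these lines. The values below are a minimal sketch (stage 3 with CPU offload), not a tuned production config; batch sizes must match your cluster:

```python
import json

# Minimal ZeRO stage-3 config with CPU offload of params and optimizer state.
ds_config = {
    "train_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu"},
        "offload_optimizer": {"device": "cpu"},
    },
}
print(json.dumps(ds_config, indent=2))
```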
E. Optimum / ONNX / TensorRT / Triton (export & serving)
- Install optimum and onnxruntime:
pip install optimum onnxruntime onnx
- Use optimum for better conversion of Transformers to ONNX/TensorRT (check Optimum docs).
- Build Triton model repo using this script helper:
python rag_toolkit_complete.py build-triton --onnx models/gpt2.onnx --model-name gpt2
- Deploy Triton server (follow NVIDIA Triton docs) and mount model repo.
F. BentoML (serving bundle)
- Install BentoML:
pip install bentoml
- Use save-bento command:
python rag_toolkit_complete.py save-bento --model gpt2 --out-dir bento_bundle
- Implement a BentoML Service class that wraps tokenizer+model and exposes REST/gRPC.
G. Vector DB (Milvus) for persistent large indices
- Install Milvus server or use hosted service.
- Install client:
pip install pymilvus
- Use Milvus to store embeddings and build distributed indexes for >10M vectors.
H. Observability and CI/CD
- MLflow:
pip install mlflow
- Prometheus / Grafana: instrument the service; expose /metrics via prometheus_client.
- CI: build Docker image in CI, run unit tests and lightweight integration tests, publish image to registry.
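As a sketch of the /metrics idea, a handler only needs to emit Prometheus text exposition format; the metric name below is hypothetical, and in practice prometheus_client handles this for you:

```python
def render_metrics(latencies_s):
    # Render one gauge in Prometheus text exposition format.
    mean = sum(latencies_s) / len(latencies_s)
    return (
        "# HELP rag_mean_latency_seconds Mean generation latency\n"
        "# TYPE rag_mean_latency_seconds gauge\n"
        f"rag_mean_latency_seconds {mean:.6f}\n"
    )

print(render_metrics([0.1, 0.3]), end="")
```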
3) How to test locally (quickstart)
- Create venv and install minimal deps:
python -m venv .venv && source .venv/bin/activate
pip install -U pip
pip install transformers sentence-transformers torch faiss-cpu psutil
- Prepare small corpus:
mkdir -p data
printf "Document one\nDocument two\nDocument three\n" > data/corpus.txt
printf "Write a summary of Document one\nWhat is Document two about?\n" > data/prompts.txt
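The benchmark accepts prompts either as a JSON array or as one prompt per line; this small loader mirrors the parsing done in main():

```python
import json

def load_prompts(raw: str):
    # A JSON array ("[...]") or newline-separated plain text are both accepted.
    raw = raw.strip()
    if raw.startswith("["):
        return json.loads(raw)
    return [line.strip() for line in raw.splitlines() if line.strip()]

print(load_prompts('["a", "b"]'))    # ['a', 'b']
print(load_prompts("one\n\ntwo\n"))  # ['one', 'two']
```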
- Build index:
python rag_toolkit_complete.py build-index --corpus data/corpus.txt --index-out data/my.index
- Benchmark:
python rag_toolkit_complete.py benchmark --model gpt2 --prompts data/prompts.txt --index data/my.index --corpus data/corpus.txt --iters 5
- Export ONNX (CPU test):
python rag_toolkit_complete.py export-onnx --model gpt2 --out models/gpt2.onnx
4) Validation checklist (what to verify)
- Model loaded and responds to prompts (sanity check).
- Index built and retrieval returns plausible documents.
- Benchmark produces reasonable latencies and writes output JSON.
- ONNX file exists and can be inspected with onnx.checker.
- Optional: MLflow/Hub artifacts logged or uploaded successfully.
5) Production recommendations (short)
- Split services: embedding service, retriever, generator (each containerized).
- Use a vector DB (Milvus/Weaviate) for persistence and scaling.
- Use Triton or BentoML for high-throughput serving; put generator behind an API Gateway and rate limiting.
- Integrate monitoring: Prometheus metrics, Grafana dashboards, traces via OpenTelemetry.
- Implement safety pipelines: automatic toxicity checks, model cards, license scanning.
- Use DeepSpeed/ZeRO and quantization (bitsandbytes) to reduce cost for large models.
- Implement canary/blue-green deploys and autoscaling policies for GPU-backed pods.
6) Cost measurement & optimization
- Measure cost per 1M requests using benchmark QPS and cloud instance pricing.
- Try quantized models (8-bit/4-bit), lower precision (fp16), or smaller distilled checkpoints for cheaper inference.
- Use batching, caching, and shard models for throughput improvements.
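As a worked example of the first bullet, cost per 1M requests follows directly from steady-state QPS and hourly instance price (the numbers below are illustrative, not real pricing):

```python
def cost_per_million_requests(qps: float, price_per_hour: float) -> float:
    # Requests served per hour at steady-state throughput.
    requests_per_hour = qps * 3600
    return price_per_hour / requests_per_hour * 1_000_000

# e.g. 2 QPS sustained on a $1.20/hour instance
print(round(cost_per_million_requests(2.0, 1.20), 2))  # 166.67
```

Plug in the benchmark's throughput_qps and your provider's price to compare quantized vs. full-precision deployments.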
7) Troubleshooting common issues
- Out-of-memory when loading model: try load_in_8bit, device_map="auto", or move to a machine with more GPU memory.
- ONNX export fails: try smaller model or use Optimum exporter; ensure matching torch/onnx/onnxruntime versions.
- faiss-gpu import errors: ensure CUDA and conda package version compatibility.
End of embedded guide.
-------------------------
The code implementation follows below.
"""
from dataclasses import dataclass, asdict
from pathlib import Path
import argparse
import json
import logging
import os
import subprocess
import sys
import time
from typing import List, Optional, Tuple
import numpy as np
import psutil
import torch
# Optional components detection
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    TRANSFORMERS_AVAILABLE = True
except Exception:
    TRANSFORMERS_AVAILABLE = False
try:
    from sentence_transformers import SentenceTransformer
    SENTENCE_TRANSFORMERS_AVAILABLE = True
except Exception:
    SentenceTransformer = None
    SENTENCE_TRANSFORMERS_AVAILABLE = False
try:
    import faiss
    FAISS_AVAILABLE = True
except Exception:
    faiss = None
    FAISS_AVAILABLE = False
try:
    import bitsandbytes as bnb  # noqa: F401
    BNB_AVAILABLE = True
except Exception:
    BNB_AVAILABLE = False
try:
    import mlflow
    MLFLOW_AVAILABLE = True
except Exception:
    mlflow = None
    MLFLOW_AVAILABLE = False
try:
    import bentoml
    BENTOML_AVAILABLE = True
except Exception:
    bentoml = None
    BENTOML_AVAILABLE = False
try:
    from huggingface_hub import HfApi
    HF_HUB_AVAILABLE = True
except Exception:
    HfApi = None
    HF_HUB_AVAILABLE = False
logger = logging.getLogger("rag_toolkit_complete")
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)
logger.setLevel(logging.INFO)
@dataclass
class BenchmarkResult:
    count: int
    mean: float
    median: float
    p50: float
    p90: float
    p99: float
    throughput_qps: Optional[float]
    mem_samples: List[dict]
    power_avg: Optional[float]
# -----------------------
# System utilities
# -----------------------
def is_cuda_available() -> bool:
    return torch.cuda.is_available()

def measure_system_memory() -> dict:
    vm = psutil.virtual_memory()
    return {"total_gb": vm.total / (1024 ** 3), "available_gb": vm.available / (1024 ** 3), "used_gb": vm.used / (1024 ** 3), "percent": vm.percent}

def measure_gpu_memory(device: int = 0) -> dict:
    if not is_cuda_available():
        return {"cuda_available": False}
    try:
        allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
        reserved = torch.cuda.memory_reserved(device) / (1024 ** 3)
        return {"cuda_available": True, "allocated_gb": allocated, "reserved_gb": reserved}
    except Exception as e:
        return {"cuda_available": True, "error": str(e)}

def get_nvidia_power_draw() -> Optional[float]:
    try:
        out = subprocess.check_output(["nvidia-smi", "--query-gpu=power.draw", "--format=csv,noheader,nounits"], stderr=subprocess.DEVNULL)
        vals = [float(x.strip()) for x in out.decode().splitlines() if x.strip()]
        return float(sum(vals) / len(vals)) if vals else None
    except Exception:
        return None
# -----------------------
# Embeddings & FAISS
# -----------------------
def load_sentence_transformer(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    if not SENTENCE_TRANSFORMERS_AVAILABLE:
        raise RuntimeError("sentence-transformers not installed. Install it to use embeddings.")
    logger.info(f"Loading SentenceTransformer: {model_name}")
    return SentenceTransformer(model_name)

def encode_texts(emb_model, texts: List[str], batch_size: int = 128) -> np.ndarray:
    logger.info(f"Encoding {len(texts)} texts (batch_size={batch_size})")
    embeddings = emb_model.encode(texts, batch_size=batch_size, convert_to_numpy=True, normalize_embeddings=True)
    return np.asarray(embeddings, dtype=np.float32)

def build_faiss_index(embeddings: np.ndarray, index_path: Optional[str] = None, n_list: int = 128):
    if not FAISS_AVAILABLE:
        raise RuntimeError("faiss is required. Install faiss-cpu or faiss-gpu.")
    d = embeddings.shape[1]
    # IVF training needs at least n_list vectors; clamp so small quickstart corpora still work.
    n_list = max(1, min(n_list, embeddings.shape[0]))
    logger.info(f"Building FAISS IVF index (d={d}, nlist={n_list})")
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, n_list, faiss.METRIC_L2)
    index.train(embeddings)
    index.add(embeddings)
    if index_path:
        faiss.write_index(index, index_path)
        logger.info(f"Saved FAISS index to {index_path}")
    return index

def load_faiss_index(index_path: str):
    if not FAISS_AVAILABLE:
        raise RuntimeError("faiss is required.")
    if not os.path.exists(index_path):
        raise FileNotFoundError(index_path)
    idx = faiss.read_index(index_path)
    logger.info(f"Loaded FAISS index from {index_path}")
    return idx
# -----------------------
# Model loading & generation
# -----------------------
def load_generation_model(model_name: str, device: str = "cuda", load_in_8bit: bool = False, trust_remote_code: bool = False):
    if not TRANSFORMERS_AVAILABLE:
        raise RuntimeError("transformers not installed.")
    logger.info(f"Loading generation model {model_name} (device={device}, 8bit={load_in_8bit})")
    load_kwargs = {}
    if load_in_8bit and BNB_AVAILABLE:
        load_kwargs["load_in_8bit"] = True
        load_kwargs["device_map"] = "auto"
    elif is_cuda_available() and device.startswith("cuda"):
        load_kwargs["torch_dtype"] = torch.float16
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=trust_remote_code)
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs, trust_remote_code=trust_remote_code)
    if "device_map" not in load_kwargs:
        dev = torch.device(device if is_cuda_available() else "cpu")
        model.to(dev)
    model.eval()
    logger.info("Model loaded.")
    return tokenizer, model

def generate_with_context(tokenizer, model, prompt: str, context_docs: Optional[List[str]] = None, max_new_tokens: int = 128, device: str = "cuda", **gen_kwargs):
    context = ("\n\n".join(context_docs) + "\n\n") if context_docs else ""
    full = context + prompt
    inputs = tokenizer(full, return_tensors="pt", truncation=True)
    if is_cuda_available() and device.startswith("cuda"):
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id, **gen_kwargs)
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    # Strip the echoed prompt so only the generated continuation is returned.
    if text.startswith(full):
        return text[len(full):].strip()
    return text
# -----------------------
# ONNX & Triton helpers
# -----------------------
def export_model_to_onnx(tokenizer, model, output_path: str, opset: int = 13):
    logger.info(f"Exporting model to ONNX at {output_path} (best-effort)")
    model_cpu = model.to("cpu")
    model_cpu.eval()
    dummy = "This is a dummy input"
    inputs = tokenizer(dummy, return_tensors="pt")
    input_ids = inputs["input_ids"]
    try:
        torch.onnx.export(model_cpu, (input_ids,), output_path, input_names=["input_ids"], output_names=["logits"], opset_version=opset, dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}, "logits": {0: "batch", 1: "sequence"}})
        logger.info("ONNX export succeeded.")
    except Exception:
        logger.exception("ONNX export failed. Consider using Optimum or vendor exporters.")
        raise
    finally:
        if is_cuda_available():
            try:
                model.to("cuda")
            except Exception:
                pass
    return output_path

def build_triton_model_repo(onnx_path: str, model_name: str, model_version: int = 1, repo_dir: str = "triton_model_repo"):
    logger.info("Creating Triton model repo (ONNX) ...")
    model_dir = Path(repo_dir) / model_name / str(model_version)
    model_dir.mkdir(parents=True, exist_ok=True)
    dest = model_dir / Path(onnx_path).name
    dest.write_bytes(Path(onnx_path).read_bytes())
    # Tokenizers emit int64 ids, so the exported graph expects TYPE_INT64 input.
    cfg = f"""name: "{model_name}"
platform: "onnxruntime_onnx"
max_batch_size: 8
input [
  {{
    name: "input_ids"
    data_type: TYPE_INT64
    dims: [-1]
  }}
]
output [
  {{
    name: "logits"
    data_type: TYPE_FP32
    dims: [-1, -1]
  }}
]
"""
    cfgp = Path(repo_dir) / model_name / "config.pbtxt"
    cfgp.write_text(cfg)
    logger.info(f"Triton model repo created at {Path(repo_dir) / model_name}")
    return str(Path(repo_dir) / model_name)
# -----------------------
# BentoML placeholder
# -----------------------
def save_bentoml_bundle(tokenizer, model, save_dir: str = "bento_bundle"):
    if not BENTOML_AVAILABLE:
        raise RuntimeError("BentoML not installed.")
    out = Path(save_dir)
    out.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), out / "model_state.pt")
    tokenizer.save_pretrained(out / "tokenizer")
    logger.info(f"BentoML bundle saved at {out}")
    return str(out)
# -----------------------
# MLflow & HF Hub helpers
# -----------------------
def mlflow_log_artifact_local(path: str):
    if not MLFLOW_AVAILABLE:
        logger.warning("MLflow not installed; skip logging.")
        return
    mlflow.log_artifact(path)
    logger.info(f"Logged {path} to MLflow.")

def push_file_to_hf(file_path: str, repo_id: str, path_in_repo: Optional[str] = None, token: Optional[str] = None):
    if not HF_HUB_AVAILABLE:
        raise RuntimeError("huggingface_hub not available.")
    api = HfApi()
    token = token or os.environ.get("HF_TOKEN")
    dest = path_in_repo or Path(file_path).name
    logger.info(f"Pushing {file_path} to HF repo {repo_id}/{dest}")
    api.upload_file(path_or_fileobj=file_path, path_in_repo=dest, repo_id=repo_id, token=token)
    logger.info("Upload succeeded.")
# -----------------------
# DeepSpeed helpers
# -----------------------
def generate_deepspeed_config(zero_stage: int = 3, offload_type: str = "cpu", fp16_enabled: bool = True, train_batch_size: int = 1, gradient_accumulation_steps: int = 1):
    cfg = {
        "train_batch_size": train_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "fp16": {"enabled": fp16_enabled},
        "zero_optimization": {
            "stage": zero_stage,
            "offload_param": {"device": offload_type} if offload_type else {},
            "offload_optimizer": {"device": offload_type} if offload_type else {},
        },
    }
    return cfg

def write_deepspeed_config(cfg: dict, out_path: str = "ds_config.json"):
    Path(out_path).write_text(json.dumps(cfg, indent=2))
    logger.info(f"Wrote DeepSpeed config to {out_path}")
    return out_path
# -----------------------
# RAG pipeline & benchmark
# -----------------------
class RAGPipeline:
    def __init__(self, embed_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        if not SENTENCE_TRANSFORMERS_AVAILABLE:
            raise RuntimeError("sentence-transformers required.")
        self.embed_model = load_sentence_transformer(embed_model_name)
        self.faiss_index = None
        self.corpus: List[str] = []

    def build_index(self, corpus: List[str], index_out: Optional[str] = None, n_list: int = 128):
        self.corpus = corpus
        emb = encode_texts(self.embed_model, corpus)
        self.faiss_index = build_faiss_index(emb, index_path=index_out, n_list=n_list)
        logger.info("FAISS index built.")

    def load_index(self, index_path: str, corpus_path: Optional[str] = None):
        self.faiss_index = load_faiss_index(index_path)
        if corpus_path:
            self.corpus = [l.strip() for l in Path(corpus_path).read_text(encoding="utf-8").splitlines() if l.strip()]
        logger.info("Index and corpus loaded.")

    def retrieve(self, query: str, top_k: int = 5):
        if self.faiss_index is None:
            raise RuntimeError("No index loaded. Call build_index or load_index first.")
        q_emb = self.embed_model.encode([query], convert_to_numpy=True, normalize_embeddings=True).astype(np.float32)
        dists, idxs = self.faiss_index.search(q_emb, top_k)
        results = []
        for idx, dist in zip(idxs[0], dists[0]):
            if idx < 0:
                continue
            results.append({"id": int(idx), "score": float(dist), "text": self.corpus[int(idx)]})
        return results

    def generate(self, tokenizer, model, query: str, top_k: int = 3, max_new_tokens: int = 128, device: str = "cuda", **gen_kwargs):
        docs = self.retrieve(query, top_k=top_k)
        ctx = [d["text"] for d in docs]
        return generate_with_context(tokenizer, model, query, context_docs=ctx, max_new_tokens=max_new_tokens, device=device, **gen_kwargs)
def run_benchmark(tokenizer, model, prompts: List[str], rag: Optional[RAGPipeline], iters: int = 50, warmup: int = 3, top_k: int = 3, device: str = "cuda") -> BenchmarkResult:
    logger.info("Starting benchmark (warmup then timed runs)...")
    for _ in range(warmup):
        q = prompts[0]
        if rag:
            _ = rag.generate(tokenizer, model, q, top_k=top_k, device=device)
        else:
            _ = generate_with_context(tokenizer, model, q, None, device=device)
    times = []
    mem_samples = []
    power_samples = []
    for i in range(iters):
        q = prompts[i % len(prompts)]
        start = time.time()
        if rag:
            _ = rag.generate(tokenizer, model, q, top_k=top_k, device=device)
        else:
            _ = generate_with_context(tokenizer, model, q, None, device=device)
        lat = time.time() - start
        times.append(lat)
        mem = measure_system_memory()
        gpu = measure_gpu_memory()
        mem_samples.append({"iter": i, "latency_s": lat, "cpu": mem, "gpu": gpu})
        pw = get_nvidia_power_draw()
        if pw is not None:
            power_samples.append(pw)
        if (i + 1) % max(1, iters // 5) == 0:
            logger.info(f"Iteration {i + 1}/{iters} latency {lat:.3f}s")
    arr = np.array(times)
    res = BenchmarkResult(
        count=len(arr),
        mean=float(arr.mean()),
        median=float(np.median(arr)),
        p50=float(np.percentile(arr, 50)),
        p90=float(np.percentile(arr, 90)),
        p99=float(np.percentile(arr, 99)),
        throughput_qps=float(1.0 / arr.mean()) if arr.mean() > 0 else None,
        mem_samples=mem_samples,
        power_avg=float(np.mean(power_samples)) if power_samples else None,
    )
    if MLFLOW_AVAILABLE:
        # Guard against None: mlflow.log_metrics only accepts numeric values.
        metrics = {"mean_latency": res.mean, "p90": res.p90}
        if res.throughput_qps is not None:
            metrics["throughput_qps"] = res.throughput_qps
        mlflow.log_metrics(metrics)
    return res
# -----------------------
# CLI
# -----------------------
def parse_args():
    p = argparse.ArgumentParser(description="RAG Toolkit Complete CLI")
    sub = p.add_subparsers(dest="cmd")
    bi = sub.add_parser("build-index", help="Build FAISS index from corpus")
    bi.add_argument("--corpus", required=True)
    bi.add_argument("--index-out", required=True)
    bi.add_argument("--embed-model", default="sentence-transformers/all-MiniLM-L6-v2")
    bi.add_argument("--n-list", type=int, default=128)
    bm = sub.add_parser("benchmark", help="Benchmark a model (optionally with RAG)")
    bm.add_argument("--model", required=True)
    bm.add_argument("--prompts", required=True)
    bm.add_argument("--corpus", help="corpus file path for retrieval")
    bm.add_argument("--index", help="faiss index path")
    bm.add_argument("--embed-model", default="sentence-transformers/all-MiniLM-L6-v2")
    bm.add_argument("--use-8bit", action="store_true")
    bm.add_argument("--iters", type=int, default=50)
    bm.add_argument("--device", default="cuda" if is_cuda_available() else "cpu")
    bm.add_argument("--output", default="benchmark_result.json")
    ex = sub.add_parser("export-onnx", help="Export model to ONNX")
    ex.add_argument("--model", required=True)
    ex.add_argument("--out", required=True)
    tr = sub.add_parser("build-triton", help="Build Triton model repo from ONNX")
    tr.add_argument("--onnx", required=True)
    tr.add_argument("--model-name", required=True)
    tr.add_argument("--repo-dir", default="triton_model_repo")
    hf = sub.add_parser("push-hf", help="Push artifact to Hugging Face Hub")
    hf.add_argument("--artifact", required=True)
    hf.add_argument("--repo-id", required=True)
    hf.add_argument("--path-in-repo", help="path inside repo")
    hf.add_argument("--token", help="HF token")
    ds = sub.add_parser("gen-deepspeed", help="Generate DeepSpeed config")
    ds.add_argument("--zero-stage", type=int, default=3)
    ds.add_argument("--offload-type", default="cpu")
    ds.add_argument("--fp16", action="store_true")
    ds.add_argument("--out", default="ds_config.json")
    sb = sub.add_parser("save-bento", help="Save BentoML bundle (placeholder)")
    sb.add_argument("--model", required=True)
    sb.add_argument("--out-dir", default="bento_bundle")
    return p.parse_args()
def main():
    args = parse_args()
    if args.cmd == "build-index":
        texts = [l.strip() for l in Path(args.corpus).read_text(encoding="utf-8").splitlines() if l.strip()]
        rag = RAGPipeline(embed_model_name=args.embed_model)
        rag.build_index(texts, index_out=args.index_out, n_list=args.n_list)
        Path(args.corpus + ".saved").write_text("\n".join(texts), encoding="utf-8")
        logger.info("Index built.")
    elif args.cmd == "benchmark":
        raw = Path(args.prompts).read_text(encoding="utf-8").strip()
        if raw.startswith("["):
            prompts = json.loads(raw)
        else:
            prompts = [l.strip() for l in raw.splitlines() if l.strip()]
        rag_pipeline = None
        if args.corpus and args.index:
            rag_pipeline = RAGPipeline(embed_model_name=args.embed_model)
            rag_pipeline.load_index(args.index, corpus_path=args.corpus)
        tokenizer, model = load_generation_model(args.model, device=args.device, load_in_8bit=args.use_8bit)
        res = run_benchmark(tokenizer, model, prompts, rag_pipeline, iters=args.iters, device=args.device)
        Path(args.output).write_text(json.dumps(asdict(res), indent=2), encoding="utf-8")
        logger.info(f"Saved benchmark to {args.output}")
        if MLFLOW_AVAILABLE:
            mlflow.log_artifact(args.output)
    elif args.cmd == "export-onnx":
        tokenizer, model = load_generation_model(args.model, device="cpu", load_in_8bit=False)
        out = export_model_to_onnx(tokenizer, model, args.out)
        logger.info(f"ONNX exported to {out}")
    elif args.cmd == "build-triton":
        repo = build_triton_model_repo(args.onnx, args.model_name, repo_dir=args.repo_dir)
        logger.info(f"Triton repo ready at {repo}")
    elif args.cmd == "push-hf":
        push_file_to_hf(args.artifact, args.repo_id, path_in_repo=args.path_in_repo, token=args.token)
    elif args.cmd == "gen-deepspeed":
        cfg = generate_deepspeed_config(zero_stage=args.zero_stage, offload_type=args.offload_type, fp16_enabled=args.fp16)
        write_deepspeed_config(cfg, out_path=args.out)
    elif args.cmd == "save-bento":
        tokenizer, model = load_generation_model(args.model, device="cpu", load_in_8bit=False)
        save_bentoml_bundle(tokenizer, model, save_dir=args.out_dir)
    else:
        print("No command specified. Use --help.")

if __name__ == "__main__":
    main()