diff --git a/routers/predict.py b/routers/predict.py index e0c560c..ab2c4dc 100644 --- a/routers/predict.py +++ b/routers/predict.py @@ -1,15 +1,19 @@ from fastapi import APIRouter, HTTPException from datetime import datetime from schemas import TextItem, Prediction -from services.prediction import predict_emotions +from services.prediction import ( + predict_emotion_full, + predict_emotion_split_avg, + predict_emotion_overall_avg, +) from db import collection router = APIRouter(prefix="/predict", tags=["Prediction"]) -@router.post("/", response_model=Prediction) -async def create_prediction(item: TextItem): +@router.post("/full", response_model=Prediction) +async def predict_full(item: TextItem): try: - probs = predict_emotions(item.text) + probs = predict_emotion_full(item.text) record = { "text": item.text, "probabilities": probs, @@ -20,10 +24,30 @@ async def create_prediction(item: TextItem): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@router.get("/history", response_model=list[Prediction]) -async def list_predictions(limit: int = 20): - cursor = collection.find().sort("timestamp", -1).limit(limit) - results = [] - async for doc in cursor: - results.append(Prediction(**doc)) - return results \ No newline at end of file +@router.post("/split_avg", response_model=Prediction) +async def predict_split_avg(item: TextItem): + try: + probs = predict_emotion_split_avg(item.text) + record = { + "text": item.text, + "probabilities": probs, + "timestamp": datetime.utcnow() + } + await collection.insert_one(record) + return record + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.post("/overall_avg", response_model=Prediction) +async def predict_overall_avg(item: TextItem): + try: + probs = predict_emotion_overall_avg(item.text) + record = { + "text": item.text, + "probabilities": probs, + "timestamp": datetime.utcnow() + } + await collection.insert_one(record) + return record + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/services/prediction.py b/services/prediction.py index 2ff94cb..df1baa3 100644 --- a/services/prediction.py +++ b/services/prediction.py @@ -8,9 +8,62 @@ _model = AutoModelForSequenceClassification.from_pretrained(settings.model_dir) _model.eval() -def predict_emotions(text: str) -> dict[str, float]: - inputs = _tokenizer(text, return_tensors="pt", truncation=True, max_length=128) + +def get_avg_emotion(results: list[dict[str, float]]) -> dict[str, float]: + if not results: + return {label: 0.0 for label in LABELS_5} + + # 각 레이블별 합산 + sums = {label: 0.0 for label in LABELS_5} + for prob in results: + for label, score in prob.items(): + sums[label] += score + + # 평균 계산 + n = len(results) + return {label: sums[label] / n for label in LABELS_5} + +def _predict_batch(texts: list[str]) -> list[dict[str, float]]: + inputs = _tokenizer( + texts, + return_tensors="pt", + truncation=True, + padding=True, + max_length=128 + ) + with torch.no_grad(): - outputs = _model(**inputs) - probs = torch.softmax(outputs.logits[0], dim=-1) - return { LABELS_5[i]: float(probs[i]) for i in range(len(LABELS_5)) } \ No newline at end of file + logits = _model(**inputs).logits # [batch_size, num_labels] + probs = torch.softmax(logits, dim=-1) + + return [ + { LABELS_5[i]: float(prob[i]) for i in range(len(LABELS_5)) } + for prob in probs + ] + + +def _format_percent(probs: dict[str, float]) -> dict[str, float]: + return {label: round(probs[label] * 100, 2) for label in LABELS_5} + +def predict_emotion_split_avg(text: str) -> dict[str, float]: + sentences = [s.strip() for s in text.split('.') if s.strip()] + if not sentences: + sentences = [text.strip()] + raw_probs = _predict_batch(sentences) + avg_raw = get_avg_emotion(raw_probs) + return _format_percent(avg_raw) + + +def predict_emotion_overall_avg(text: str) -> dict[str, float]: + sentences = [s.strip() for s in text.split('.') if s.strip()] + if not sentences: + sentences = [text.strip()] + all_texts = [text.strip()] + sentences + raw_probs = _predict_batch(all_texts) + avg_raw = get_avg_emotion(raw_probs) + return _format_percent(avg_raw) + + +def predict_emotion_full(text: str) -> dict[str, float]: + [raw] = _predict_batch([text.strip()]) + return _format_percent(raw)