Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions routers/predict.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from fastapi import APIRouter, HTTPException
from datetime import datetime
from schemas import TextItem, Prediction
from services.prediction import predict_emotions
from services.prediction import (
predict_emotion_full,
predict_emotion_split_avg,
predict_emotion_overall_avg,
)
from db import collection

# Router for the prediction endpoints: every route registered on it is served
# under the "/predict" prefix and carries the "Prediction" tag.
router = APIRouter(prefix="/predict", tags=["Prediction"])

@router.post("/", response_model=Prediction)
async def create_prediction(item: TextItem):
@router.post("/full", response_model=Prediction)
async def predict_full(item: TextItem):
try:
probs = predict_emotions(item.text)
probs = predict_emotion_full(item.text)
record = {
"text": item.text,
"probabilities": probs,
Expand All @@ -20,10 +24,30 @@ async def create_prediction(item: TextItem):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@router.get("/history", response_model=list[Prediction])
async def list_predictions(limit: int = 20):
    """Return stored predictions ordered by descending ``timestamp``.

    Args:
        limit: Value passed to the cursor's ``limit`` (20 by default).

    Returns:
        A list of ``Prediction`` objects, one per document read from the
        collection.
    """
    newest_first = collection.find().sort("timestamp", -1)
    return [Prediction(**doc) async for doc in newest_first.limit(limit)]
@router.post("/split_avg", response_model=Prediction)
async def predict_split_avg(item: TextItem):
    """Score ``item.text`` with ``predict_emotion_split_avg`` and store the result.

    The record holds the submitted text, the returned probabilities and the
    current UTC time (``datetime.utcnow()``); it is inserted into
    ``collection`` and then returned.

    Raises:
        HTTPException: Status 500, with the original exception's message as
            the detail, if prediction or the insert fails.
    """
    try:
        scores = predict_emotion_split_avg(item.text)
        document = dict(
            text=item.text,
            probabilities=scores,
            timestamp=datetime.utcnow(),
        )
        await collection.insert_one(document)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    else:
        return document

@router.post("/overall_avg", response_model=Prediction)
async def predict_overall_avg(item: TextItem):
    """Score ``item.text`` with ``predict_emotion_overall_avg`` and store the result.

    The record holds the submitted text, the returned probabilities and the
    current UTC time (``datetime.utcnow()``); it is inserted into
    ``collection`` and then returned.

    Raises:
        HTTPException: Status 500, with the original exception's message as
            the detail, if prediction or the insert fails.
    """
    try:
        scores = predict_emotion_overall_avg(item.text)
        document = dict(
            text=item.text,
            probabilities=scores,
            timestamp=datetime.utcnow(),
        )
        await collection.insert_one(document)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    else:
        return document
63 changes: 58 additions & 5 deletions services/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,62 @@
# Load the classification model from settings.model_dir once, at import time,
# so that all prediction helpers in this module share the same instance.
_model = AutoModelForSequenceClassification.from_pretrained(settings.model_dir)
# Put the module in evaluation mode (torch.nn.Module.eval()); this module only
# runs inference on it.
_model.eval()

def predict_emotions(text: str) -> dict[str, float]:
inputs = _tokenizer(text, return_tensors="pt", truncation=True, max_length=128)

def get_avg_emotion(results: list[dict[str, float]]) -> dict[str, float]:
    """Average a list of per-label score dicts, label by label.

    Args:
        results: Score dicts whose keys are labels from ``LABELS_5``.

    Returns:
        A dict with one entry per label in ``LABELS_5``: the sum of that
        label's scores divided by ``len(results)``. When ``results`` is
        empty, every label maps to 0.0.

    Raises:
        KeyError: If a score dict contains a key that is not in ``LABELS_5``.
    """
    totals = dict.fromkeys(LABELS_5, 0.0)
    if not results:
        return totals

    # Accumulate the per-label totals across all score dicts.
    for entry in results:
        for name in entry:
            totals[name] += entry[name]

    # Divide by the number of dicts, not by how often each label appeared.
    count = len(results)
    return {name: totals[name] / count for name in LABELS_5}

def _predict_batch(texts: list[str]) -> list[dict[str, float]]:
inputs = _tokenizer(
texts,
return_tensors="pt",
truncation=True,
padding=True,
max_length=128
)

with torch.no_grad():
outputs = _model(**inputs)
probs = torch.softmax(outputs.logits[0], dim=-1)
return { LABELS_5[i]: float(probs[i]) for i in range(len(LABELS_5)) }
logits = _model(**inputs).logits # [batch_size, num_labels]
probs = torch.softmax(logits, dim=-1)

return [
{ LABELS_5[i]: float(prob[i]) for i in range(len(LABELS_5)) }
for prob in probs
]


def _format_percent(probs: dict[str, float]) -> dict[str, float]:
    """Convert probabilities to percentages rounded to two decimals.

    Args:
        probs: Mapping that contains a value for every label in ``LABELS_5``.

    Returns:
        A new dict keyed by the labels of ``LABELS_5``, each value being
        ``round(probability * 100, 2)``.
    """
    percentages = {}
    for name in LABELS_5:
        percentages[name] = round(probs[name] * 100, 2)
    return percentages

def predict_emotion_split_avg(text: str) -> dict[str, float]:
    """Average the model's scores over the '.'-separated sentences of ``text``.

    The text is split on '.', each piece is stripped and empty pieces are
    dropped. If nothing is left, the stripped text itself is used as the only
    sentence.

    Returns:
        The per-label average from ``get_avg_emotion``, formatted by
        ``_format_percent``.
    """
    # NOTE(review): only '.' counts as a sentence boundary here; '!' and '?'
    # do not split, and decimal points do.
    fragments = [part for part in map(str.strip, text.split('.')) if part]
    batch = fragments or [text.strip()]
    return _format_percent(get_avg_emotion(_predict_batch(batch)))


def predict_emotion_overall_avg(text: str) -> dict[str, float]:
    """Average the model's scores over the whole text plus each of its sentences.

    Sentences are the non-empty, stripped pieces obtained by splitting
    ``text`` on '.'. If there are none, the stripped text stands in as the
    single sentence. The stripped full text is always the first batch entry.

    Returns:
        The per-label average from ``get_avg_emotion``, formatted by
        ``_format_percent``.
    """
    whole = text.strip()
    fragments = [part for part in map(str.strip, text.split('.')) if part]
    # Full text first, then its sentences; every entry has equal weight in
    # the average.
    batch = [whole, *(fragments or [whole])]
    return _format_percent(get_avg_emotion(_predict_batch(batch)))


def predict_emotion_full(text: str) -> dict[str, float]:
    """Score the stripped text as a single input.

    Returns:
        The one result from ``_predict_batch``, formatted by
        ``_format_percent``.

    Raises:
        ValueError: If ``_predict_batch`` does not return exactly one result.
    """
    batch_scores = _predict_batch([text.strip()])
    # Single-element unpacking doubles as a check that exactly one result came back.
    (only_scores,) = batch_scores
    return _format_percent(only_scores)