From cdac87c4c174457ee026115c780ce69f1e0f876c Mon Sep 17 00:00:00 2001 From: Gumraze-git Date: Sun, 13 Jul 2025 12:24:11 +0900 Subject: [PATCH 1/4] =?UTF-8?q?[FEAT]=20=EA=B0=90=EC=A0=95=20=EC=98=88?= =?UTF-8?q?=EC=B8=A1=20=EB=A1=9C=EC=A7=81=20=ED=99=95=EC=9E=A5=20=EB=B0=8F?= =?UTF-8?q?=20=ED=8F=89=EA=B7=A0=20=EA=B3=84=EC=82=B0=20=ED=95=A8=EC=88=98?= =?UTF-8?q?=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 클래스별 예측 확률 평균 계산 함수 추가 - 배치 예측 및 문장 분리 기반 감정 예측 로직 추가 - 다양한 감정 예측 함수 제공 (split_avg, overall_avg, full) --- services/prediction.py | 60 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/services/prediction.py b/services/prediction.py index 2ff94cb..b91fa72 100644 --- a/services/prediction.py +++ b/services/prediction.py @@ -8,9 +8,59 @@ _model = AutoModelForSequenceClassification.from_pretrained(settings.model_dir) _model.eval() -def predict_emotions(text: str) -> dict[str, float]: - inputs = _tokenizer(text, return_tensors="pt", truncation=True, max_length=128) + +def get_avg_emotion(results: list[dict[str, float]]) -> dict[str, float]: + if not results: + return {label: 0.0 for label in LABELS_5} + + # 각 레이블별 합산 + sums = {label: 0.0 for label in LABELS_5} + for prob in results: + for label, score in prob.items(): + sums[label] += score + + # 평균 계산 + n = len(results) + return {label: sums[label] / n for label in LABELS_5} + +def _predict_batch(texts: list[str]) -> list[dict[str, float]]: + inputs = _tokenizer( + texts, + return_tensors="pt", + truncation=True, + padding=True, + max_length=128 + ) + with torch.no_grad(): - outputs = _model(**inputs) - probs = torch.softmax(outputs.logits[0], dim=-1) - return { LABELS_5[i]: float(probs[i]) for i in range(len(LABELS_5)) } \ No newline at end of file + logits = _model(**inputs).logits # [batch_size, num_labels] + probs = torch.softmax(logits, dim=-1) + + return [ + { LABELS_5[i]: float(prob[i]) 
for i in range(len(LABELS_5)) } + for prob in probs + ] + + +def predict_emotion_split_avg(text: str) -> dict[str, float]: + sentences = [s.strip() for s in text.split('.') if s.strip()] + if not sentences: + sentences = [text.strip()] + + probs_list = _predict_batch(sentences) + return get_avg_emotion(probs_list) + + +def predict_emotion_overall_avg(text: str) -> dict[str, float]: + sentences = [s.strip() for s in text.split('.') if s.strip()] + if not sentences: + sentences = [text.strip()] + all_texts = [text.strip()] + sentences + + probs_list = _predict_batch(all_texts) + return get_avg_emotion(probs_list) + + +def predict_emotion_full(text: str) -> dict[str, float]: + [prob] = _predict_batch([text.strip()]) + return prob \ No newline at end of file From d69aa9a4d6bb0ba512e66e85a616f15b851b0768 Mon Sep 17 00:00:00 2001 From: Gumraze-git Date: Sun, 13 Jul 2025 12:24:19 +0900 Subject: [PATCH 2/4] =?UTF-8?q?[FEAT]=20=EA=B0=90=EC=A0=95=20=EC=98=88?= =?UTF-8?q?=EC=B8=A1=20=EC=97=94=EB=93=9C=ED=8F=AC=EC=9D=B8=ED=8A=B8=20?= =?UTF-8?q?=ED=99=95=EC=9E=A5=20=EB=B0=8F=20=EC=84=B8=20=EA=B0=80=EC=A7=80?= =?UTF-8?q?=20=EC=83=88=EB=A1=9C=EC=9A=B4=20=EC=98=88=EC=B8=A1=20=EB=A1=9C?= =?UTF-8?q?=EC=A7=81=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 새로운 엔드포인트 (/full, /split_avg, /overall_avg) 추가 - 감정 예측 로직 다양화 (full, split_avg, overall_avg) - 예측 결과를 MongoDB에 저장하도록 수정 --- routers/predict.py | 40 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/routers/predict.py b/routers/predict.py index e0c560c..53257bd 100644 --- a/routers/predict.py +++ b/routers/predict.py @@ -1,15 +1,47 @@ from fastapi import APIRouter, HTTPException from datetime import datetime from schemas import TextItem, Prediction -from services.prediction import predict_emotions +from services.prediction import ( + predict_emotion_full, + predict_emotion_split_avg, + predict_emotion_overall_avg, 
+) from db import collection router = APIRouter(prefix="/predict", tags=["Prediction"]) -@router.post("/", response_model=Prediction) -async def create_prediction(item: TextItem): +@router.post("/full", response_model=Prediction) +async def predict_full(item: TextItem): try: - probs = predict_emotions(item.text) + probs = predict_emotion_full(item.text) + record = { + "text": item.text, + "probabilities": probs, + "timestamp": datetime.utcnow() + } + await collection.insert_one(record) + return record + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.post("/split_avg", response_model=Prediction) +async def predict_split_avg(item: TextItem): + try: + probs = predict_emotion_split_avg(item.text) + record = { + "text": item.text, + "probabilities": probs, + "timestamp": datetime.utcnow() + } + await collection.insert_one(record) + return record + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.post("/overall_avg", response_model=Prediction) +async def predict_overall_avg(item: TextItem): + try: + probs = predict_emotion_overall_avg(item.text) record = { "text": item.text, "probabilities": probs, From a36001ca4b64a14582f7c3536488754fd9ad9980 Mon Sep 17 00:00:00 2001 From: Gumraze-git Date: Sun, 13 Jul 2025 12:30:14 +0900 Subject: [PATCH 3/4] =?UTF-8?q?[FEAT]=20=EA=B0=90=EC=A0=95=20=EC=98=88?= =?UTF-8?q?=EC=B8=A1=20=ED=99=95=EB=A5=A0=EC=9D=84=20=EB=B0=B1=EB=B6=84?= =?UTF-8?q?=EC=9C=A8=20=ED=98=95=EC=8B=9D=EC=9C=BC=EB=A1=9C=20=EB=B3=80?= =?UTF-8?q?=ED=99=98=ED=95=98=EB=8A=94=20=EB=A1=9C=EC=A7=81=20=EC=B6=94?= =?UTF-8?q?=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 예측 결과를 소수점에서 백분율 형식으로 변환하는 `_format_percent` 함수 추가 - split_avg, overall_avg, full 예측 함수에 변환 로직 적용 --- services/prediction.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/services/prediction.py b/services/prediction.py index 
b91fa72..df1baa3 100644 --- a/services/prediction.py +++ b/services/prediction.py @@ -42,13 +42,16 @@ def _predict_batch(texts: list[str]) -> list[dict[str, float]]: ] +def _format_percent(probs: dict[str, float]) -> dict[str, float]: + return {label: round(probs[label] * 100, 2) for label in LABELS_5} + def predict_emotion_split_avg(text: str) -> dict[str, float]: sentences = [s.strip() for s in text.split('.') if s.strip()] if not sentences: sentences = [text.strip()] - - probs_list = _predict_batch(sentences) - return get_avg_emotion(probs_list) + raw_probs = _predict_batch(sentences) + avg_raw = get_avg_emotion(raw_probs) + return _format_percent(avg_raw) def predict_emotion_overall_avg(text: str) -> dict[str, float]: @@ -56,11 +59,11 @@ def predict_emotion_overall_avg(text: str) -> dict[str, float]: if not sentences: sentences = [text.strip()] all_texts = [text.strip()] + sentences - - probs_list = _predict_batch(all_texts) - return get_avg_emotion(probs_list) + raw_probs = _predict_batch(all_texts) + avg_raw = get_avg_emotion(raw_probs) + return _format_percent(avg_raw) def predict_emotion_full(text: str) -> dict[str, float]: - [prob] = _predict_batch([text.strip()]) - return prob \ No newline at end of file + [raw] = _predict_batch([text.strip()]) + return _format_percent(raw) From eb4c53fe3c49dddf2cadbff326dd05c4d7220db6 Mon Sep 17 00:00:00 2001 From: Gumraze-git Date: Sun, 13 Jul 2025 12:44:53 +0900 Subject: [PATCH 4/4] =?UTF-8?q?[CHORE]=20=EC=98=88=EC=B8=A1=20=EA=B8=B0?= =?UTF-8?q?=EB=A1=9D=20=EC=A1=B0=ED=9A=8C=20=EC=97=94=EB=93=9C=ED=8F=AC?= =?UTF-8?q?=EC=9D=B8=ED=8A=B8=20=EC=A0=9C=EA=B1=B0=20=EB=B0=8F=20=EC=BD=94?= =?UTF-8?q?=EB=93=9C=20=EC=A0=95=EB=A6=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - history는 comment_id와 함께 전달 필요함. 
- 따라서 Spring server에서 post시 모델 결과, comment_id 등과 함께 mongoDB에 저장하는 로직으로 변경하기 위해 삭제 --- routers/predict.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/routers/predict.py b/routers/predict.py index 53257bd..ab2c4dc 100644 --- a/routers/predict.py +++ b/routers/predict.py @@ -50,12 +50,4 @@ async def predict_overall_avg(item: TextItem): await collection.insert_one(record) return record except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@router.get("/history", response_model=list[Prediction]) -async def list_predictions(limit: int = 20): - cursor = collection.find().sort("timestamp", -1).limit(limit) - results = [] - async for doc in cursor: - results.append(Prediction(**doc)) - return results \ No newline at end of file + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file