Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions src/Components/.env_example
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,31 @@
# === Stripe Payment Gateway ===
STRIPE_PRIVATE_KEY=sk_test_your_stripe_key_here

# === MongoDB Connection ===
# Replace mongodb+srv://<username>:<password>@<cluster>.mongodb.net/echoDB with "mongodb://modelUser:EchoNetAccess2023@ts-mongodb-cont:27017/EchoNet"
MONGODB_URI=mongodb+srv://<username>:<password>@<cluster>.mongodb.net/echoDB
# === MongoDB ===
DB_HOST=ts-mongodb-cont
DB_NAME=EchoNet
DB_USER=modelUser
DB_USER_PASS=EchoNetAccess2023
DB_ROOT_USER=root
DB_ROOT_USER_PASS=root_password

# Please go through "Environment Setup & Stripe Integration Guide" uploaded in the Project Echo team files.
# === MongoDB URI ===
MONGODB_URI=mongodb://modelUser:EchoNetAccess2023@ts-mongodb-cont:27017/EchoNet
USER_MONGODB_URI=mongodb://root:root_password@ts-mongodb-cont:27017/UserSample?authSource=admin

# === Redis ===
NODE_ENV=development
REDIS_HOST=echo-redis

# === API ===
API_HOST=ts-api-cont

# === Mail ===
MAIL_STARTTLS=true
MAIL_SSL_TLS=false

# Model Adapter Configuration
MODEL_SERVER_URL=http://ts-echo-model-cont:8501/v1/models/echo_model/versions/1:predict
MODEL_REQUEST_TIMEOUT_SECONDS=10
MODEL_SIGNATURE_NAME=serving_default
MODEL_CLASS_NAMES=
13 changes: 13 additions & 0 deletions src/Components/API/app/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from app.adapters.model_adapter import (
AdapterError,
AdapterTimeoutError,
HttpModelAdapter,
PredictionResult,
)

__all__ = [
"AdapterError",
"AdapterTimeoutError",
"HttpModelAdapter",
"PredictionResult",
]
120 changes: 120 additions & 0 deletions src/Components/API/app/adapters/model_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import requests


class AdapterError(Exception):
pass


class AdapterTimeoutError(AdapterError):
pass


@dataclass
class PredictionResult:
species: str
confidence: float
raw: Optional[Dict[str, Any]] = None


class HttpModelAdapter:
def __init__(
self,
model_server_url: Optional[str] = None,
timeout_seconds: Optional[float] = None,
signature_name: Optional[str] = None,
class_names: Optional[List[str]] = None,
) -> None:
self.model_server_url = model_server_url or os.getenv(
"MODEL_SERVER_URL",
"http://ts-echo-model-cont:8501/v1/models/echo_model/versions/1:predict",
)
self.timeout_seconds = timeout_seconds or float(
os.getenv("MODEL_REQUEST_TIMEOUT_SECONDS", "10")
)
self.signature_name = signature_name or os.getenv(
"MODEL_SIGNATURE_NAME", "serving_default"
)
self.class_names = class_names or self._parse_class_names(
os.getenv("MODEL_CLASS_NAMES", "")
)

@staticmethod
def _parse_class_names(value: str) -> List[str]:
if not value:
return []
return [item.strip() for item in value.split(",") if item.strip()]

def predict(self, model_inputs: Any) -> PredictionResult:
payload = {
"signature_name": self.signature_name,
"inputs": model_inputs,
}

try:
response = requests.post(
self.model_server_url,
json=payload,
timeout=self.timeout_seconds,
)
except requests.Timeout as exc:
raise AdapterTimeoutError("Model server request timed out") from exc
except requests.RequestException as exc:
raise AdapterError(f"Failed to connect model server: {exc}") from exc

if response.status_code >= 400:
raise AdapterError(
f"Model server returned status {response.status_code}: {response.text}"
)

try:
body = response.json()
except ValueError as exc:
raise AdapterError("Model server returned non-JSON response") from exc

return self._map_response(body)

def _map_response(self, body: Dict[str, Any]) -> PredictionResult:
vector = self._extract_scores(body)
if not vector:
raise AdapterError("Model server response does not contain prediction scores")

max_index = max(range(len(vector)), key=lambda idx: vector[idx])
confidence = float(vector[max_index])
species = (
self.class_names[max_index]
if self.class_names and max_index < len(self.class_names)
else f"class_{max_index}"
)
return PredictionResult(species=species, confidence=confidence, raw=body)

def _extract_scores(self, body: Dict[str, Any]) -> List[float]:
candidates: List[Any] = []
if "outputs" in body:
candidates.append(body["outputs"])
if "predictions" in body:
candidates.append(body["predictions"])

for candidate in candidates:
vector = self._flatten_first_numeric_vector(candidate)
if vector:
return vector
return []

def _flatten_first_numeric_vector(self, value: Any) -> List[float]:
if isinstance(value, list):
if value and all(isinstance(item, (int, float)) for item in value):
return [float(item) for item in value]
for item in value:
nested = self._flatten_first_numeric_vector(item)
if nested:
return nested
elif isinstance(value, dict):
for nested_value in value.values():
nested = self._flatten_first_numeric_vector(nested_value)
if nested:
return nested
return []
3 changes: 2 additions & 1 deletion src/Components/API/app/routers/detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from app import detections as detections_service
from app.middleware.pause_guard import pause_guard
from app.services.budget import enforce_and_consume
from app.services.predictions import predict_from_payload

router = APIRouter(
prefix="/detections",
Expand Down Expand Up @@ -110,4 +111,4 @@ def delete_detection_endpoint(
def predict_endpoint(payload: Dict[str, Any] = Body(...)):
enforce_and_consume("species_predictor", cost=5)

return {"status": "not_implemented_here", "received": payload}
return predict_from_payload(payload)
37 changes: 14 additions & 23 deletions src/Components/API/app/routers/species_predictor.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
from fastapi import APIRouter, UploadFile, File, HTTPException, Form
from fastapi.responses import JSONResponse
from typing import Optional
from app.database import Predictions
import datetime
import re
from app.database import Predictions
from app.services.predictions import predict_uploaded_audio

router = APIRouter()

# Placeholder prediction
def predict_species(audio_file: UploadFile):
return {
"species": "Crimson Rosella",
"confidence": 0.92
}

@router.post("/predict")
async def predict(
audio: UploadFile = File(...),
Expand All @@ -31,23 +25,20 @@ async def predict(
else:
raise HTTPException(status_code=400, detail="Invalid upload_id format")

prediction = predict_species(audio)

# Persist prediction result
try:
doc = {
"filename": audio.filename,
"predicted_species": prediction["species"],
"confidence": prediction["confidence"],
"timestamp": datetime.datetime.utcnow(),
"user_id": user_id,
}
if valid_upload_id:
doc["upload_id"] = valid_upload_id

Predictions.insert_one(doc)
audio_bytes = await audio.read()
if not audio_bytes:
raise HTTPException(status_code=400, detail="Empty audio file provided")
prediction = predict_uploaded_audio(
filename=audio.filename,
audio_bytes=audio_bytes,
user_id=user_id,
upload_id=valid_upload_id,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to store prediction: {e}")
raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")

return JSONResponse(content=prediction)

Expand Down
90 changes: 90 additions & 0 deletions src/Components/API/app/services/predictions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import base64
import datetime
from typing import Any, Dict, Optional

from fastapi import HTTPException

from app.adapters import AdapterError, AdapterTimeoutError, HttpModelAdapter
from app.database import Predictions


def _raise_prediction_http_error(exc: Exception) -> None:
if isinstance(exc, AdapterTimeoutError):
raise HTTPException(status_code=504, detail="Model service timeout")
if isinstance(exc, AdapterError):
raise HTTPException(status_code=502, detail=str(exc))
raise HTTPException(status_code=500, detail=f"Prediction failed: {exc}")


def _build_inputs_from_audio_bytes(audio_bytes: bytes, filename: str) -> Dict[str, Any]:
return {
"audio_b64": base64.b64encode(audio_bytes).decode("utf-8"),
"filename": filename,
}


def _build_inputs_from_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
if "inputs" in payload:
return payload["inputs"]
if "audio_base64" in payload:
return {"audio_b64": payload["audio_base64"]}
if "audioClip" in payload:
return {"audio_b64": payload["audioClip"]}
raise HTTPException(
status_code=400,
detail="Payload must include either 'inputs', 'audio_base64', or 'audioClip'",
)


def predict_uploaded_audio(
*,
filename: str,
audio_bytes: bytes,
user_id: Optional[str] = None,
upload_id: Optional[str] = None,
) -> Dict[str, Any]:
adapter = HttpModelAdapter()
model_inputs = _build_inputs_from_audio_bytes(audio_bytes, filename)

try:
prediction = adapter.predict(model_inputs)
except Exception as exc:
_raise_prediction_http_error(exc)
raise

doc = {
"filename": filename,
"predicted_species": prediction.species,
"confidence": prediction.confidence,
"timestamp": datetime.datetime.utcnow(),
"user_id": user_id,
}
if upload_id:
doc["upload_id"] = upload_id

try:
Predictions.insert_one(doc)
except Exception as exc:
raise HTTPException(status_code=500, detail=f"Failed to store prediction: {exc}")

return {
"species": prediction.species,
"confidence": prediction.confidence,
}


def predict_from_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
adapter = HttpModelAdapter()
model_inputs = _build_inputs_from_payload(payload)

try:
prediction = adapter.predict(model_inputs)
except Exception as exc:
_raise_prediction_http_error(exc)
raise

return {
"species": prediction.species,
"confidence": prediction.confidence,
"model_response": prediction.raw,
}
Loading
Loading