diff --git a/src/Components/.env_example b/src/Components/.env_example index d24bbcfa1..bcc3b134c 100644 --- a/src/Components/.env_example +++ b/src/Components/.env_example @@ -2,8 +2,31 @@ # === Stripe Payment Gateway === STRIPE_PRIVATE_KEY=sk_test_your_stripe_key_here -# === MongoDB Connection === -# Replace mongodb+srv://:@.mongodb.net/echoDB with "mongodb://modelUser:EchoNetAccess2023@ts-mongodb-cont:27017/EchoNet" -MONGODB_URI=mongodb+srv://:@.mongodb.net/echoDB +# === MongoDB === +DB_HOST=ts-mongodb-cont +DB_NAME=EchoNet +DB_USER=modelUser +DB_USER_PASS=EchoNetAccess2023 +DB_ROOT_USER=root +DB_ROOT_USER_PASS=root_password -# Please go through "Environment Setup & Stripe Integration Guide" uploaded in the Project Echo team files. \ No newline at end of file +# === MongoDB URI === +MONGODB_URI=mongodb://modelUser:EchoNetAccess2023@ts-mongodb-cont:27017/EchoNet +USER_MONGODB_URI=mongodb://root:root_password@ts-mongodb-cont:27017/UserSample?authSource=admin + +# === Redis === +NODE_ENV=development +REDIS_HOST=echo-redis + +# === API === +API_HOST=ts-api-cont + +# === Mail === +MAIL_STARTTLS=true +MAIL_SSL_TLS=false + +# Model Adapter Configuration +MODEL_SERVER_URL=http://ts-echo-model-cont:8501/v1/models/echo_model/versions/1:predict +MODEL_REQUEST_TIMEOUT_SECONDS=10 +MODEL_SIGNATURE_NAME=serving_default +MODEL_CLASS_NAMES= \ No newline at end of file diff --git a/src/Components/API/app/adapters/__init__.py b/src/Components/API/app/adapters/__init__.py new file mode 100644 index 000000000..de81f4550 --- /dev/null +++ b/src/Components/API/app/adapters/__init__.py @@ -0,0 +1,13 @@ +from app.adapters.model_adapter import ( + AdapterError, + AdapterTimeoutError, + HttpModelAdapter, + PredictionResult, +) + +__all__ = [ + "AdapterError", + "AdapterTimeoutError", + "HttpModelAdapter", + "PredictionResult", +] diff --git a/src/Components/API/app/adapters/model_adapter.py b/src/Components/API/app/adapters/model_adapter.py new file mode 100644 index 000000000..003a58b9d --- /dev/null +++ b/src/Components/API/app/adapters/model_adapter.py @@ -0,0 +1,120 @@ +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import requests + + +class AdapterError(Exception): + pass + + +class AdapterTimeoutError(AdapterError): + pass + + +@dataclass +class PredictionResult: + species: str + confidence: float + raw: Optional[Dict[str, Any]] = None + + +class HttpModelAdapter: + def __init__( + self, + model_server_url: Optional[str] = None, + timeout_seconds: Optional[float] = None, + signature_name: Optional[str] = None, + class_names: Optional[List[str]] = None, + ) -> None: + self.model_server_url = model_server_url or os.getenv( + "MODEL_SERVER_URL", + "http://ts-echo-model-cont:8501/v1/models/echo_model/versions/1:predict", + ) + self.timeout_seconds = timeout_seconds or float( + os.getenv("MODEL_REQUEST_TIMEOUT_SECONDS", "10") + ) + self.signature_name = signature_name or os.getenv( + "MODEL_SIGNATURE_NAME", "serving_default" + ) + self.class_names = class_names or self._parse_class_names( + os.getenv("MODEL_CLASS_NAMES", "") + ) + + @staticmethod + def _parse_class_names(value: str) -> List[str]: + if not value: + return [] + return [item.strip() for item in value.split(",") if item.strip()] + + def predict(self, model_inputs: Any) -> PredictionResult: + payload = { + "signature_name": self.signature_name, + "inputs": model_inputs, + } + + try: + response = requests.post( + self.model_server_url, + json=payload, + timeout=self.timeout_seconds, + ) + except requests.Timeout as exc: + raise AdapterTimeoutError("Model server request timed out") from exc + except requests.RequestException as exc: + raise AdapterError(f"Failed to connect model server: {exc}") from exc + + if response.status_code >= 400: + raise AdapterError( + f"Model server returned status {response.status_code}: {response.text}" + ) + + try: + body = response.json() + except ValueError as exc: + raise AdapterError("Model server returned non-JSON response") from exc + + return self._map_response(body) + + def _map_response(self, body: Dict[str, Any]) -> PredictionResult: + vector = self._extract_scores(body) + if not vector: + raise AdapterError("Model server response does not contain prediction scores") + + max_index = max(range(len(vector)), key=lambda idx: vector[idx]) + confidence = float(vector[max_index]) + species = ( + self.class_names[max_index] + if self.class_names and max_index < len(self.class_names) + else f"class_{max_index}" + ) + return PredictionResult(species=species, confidence=confidence, raw=body) + + def _extract_scores(self, body: Dict[str, Any]) -> List[float]: + candidates: List[Any] = [] + if "outputs" in body: + candidates.append(body["outputs"]) + if "predictions" in body: + candidates.append(body["predictions"]) + + for candidate in candidates: + vector = self._flatten_first_numeric_vector(candidate) + if vector: + return vector + return [] + + def _flatten_first_numeric_vector(self, value: Any) -> List[float]: + if isinstance(value, list): + if value and all(isinstance(item, (int, float)) for item in value): + return [float(item) for item in value] + for item in value: + nested = self._flatten_first_numeric_vector(item) + if nested: + return nested + elif isinstance(value, dict): + for nested_value in value.values(): + nested = self._flatten_first_numeric_vector(nested_value) + if nested: + return nested + return [] diff --git a/src/Components/API/app/routers/detections.py b/src/Components/API/app/routers/detections.py index 0156e4caf..e06602400 100644 --- a/src/Components/API/app/routers/detections.py +++ b/src/Components/API/app/routers/detections.py @@ -7,6 +7,7 @@ from app import detections as detections_service from app.middleware.pause_guard import pause_guard from app.services.budget import enforce_and_consume +from app.services.predictions import predict_from_payload router = APIRouter( prefix="/detections", @@ -110,4 +111,4 @@ def delete_detection_endpoint( def predict_endpoint(payload: Dict[str, Any] = Body(...)): enforce_and_consume("species_predictor", cost=5) - return {"status": "not_implemented_here", "received": payload} + return predict_from_payload(payload) diff --git a/src/Components/API/app/routers/species_predictor.py b/src/Components/API/app/routers/species_predictor.py index 81e8373ce..51d0abdba 100644 --- a/src/Components/API/app/routers/species_predictor.py +++ b/src/Components/API/app/routers/species_predictor.py @@ -1,19 +1,13 @@ from fastapi import APIRouter, UploadFile, File, HTTPException, Form from fastapi.responses import JSONResponse from typing import Optional -from app.database import Predictions import datetime import re +from app.database import Predictions +from app.services.predictions import predict_uploaded_audio router = APIRouter() -# Placeholder prediction -def predict_species(audio_file: UploadFile): - return { - "species": "Crimson Rosella", - "confidence": 0.92 - } - @router.post("/predict") async def predict( audio: UploadFile = File(...), @@ -31,23 +25,20 @@ async def predict( else: raise HTTPException(status_code=400, detail="Invalid upload_id format") - prediction = predict_species(audio) - - # Persist prediction result try: - doc = { - "filename": audio.filename, - "predicted_species": prediction["species"], - "confidence": prediction["confidence"], - "timestamp": datetime.datetime.utcnow(), - "user_id": user_id, - } - if valid_upload_id: - doc["upload_id"] = valid_upload_id - - Predictions.insert_one(doc) + audio_bytes = await audio.read() + if not audio_bytes: + raise HTTPException(status_code=400, detail="Empty audio file provided") + prediction = predict_uploaded_audio( + filename=audio.filename, + audio_bytes=audio_bytes, + user_id=user_id, + upload_id=valid_upload_id, + ) + except HTTPException: + raise except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to store prediction: {e}") + raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") return JSONResponse(content=prediction) diff --git a/src/Components/API/app/services/predictions.py b/src/Components/API/app/services/predictions.py new file mode 100644 index 000000000..9cb464335 --- /dev/null +++ b/src/Components/API/app/services/predictions.py @@ -0,0 +1,90 @@ +import base64 +import datetime +from typing import Any, Dict, Optional + +from fastapi import HTTPException + +from app.adapters import AdapterError, AdapterTimeoutError, HttpModelAdapter +from app.database import Predictions + + +def _raise_prediction_http_error(exc: Exception) -> None: + if isinstance(exc, AdapterTimeoutError): + raise HTTPException(status_code=504, detail="Model service timeout") + if isinstance(exc, AdapterError): + raise HTTPException(status_code=502, detail=str(exc)) + raise HTTPException(status_code=500, detail=f"Prediction failed: {exc}") + + +def _build_inputs_from_audio_bytes(audio_bytes: bytes, filename: str) -> Dict[str, Any]: + return { + "audio_b64": base64.b64encode(audio_bytes).decode("utf-8"), + "filename": filename, + } + + +def _build_inputs_from_payload(payload: Dict[str, Any]) -> Dict[str, Any]: + if "inputs" in payload: + return payload["inputs"] + if "audio_base64" in payload: + return {"audio_b64": payload["audio_base64"]} + if "audioClip" in payload: + return {"audio_b64": payload["audioClip"]} + raise HTTPException( + status_code=400, + detail="Payload must include either 'inputs', 'audio_base64', or 'audioClip'", + ) + + +def predict_uploaded_audio( + *, + filename: str, + audio_bytes: bytes, + user_id: Optional[str] = None, + upload_id: Optional[str] = None, +) -> Dict[str, Any]: + adapter = HttpModelAdapter() + model_inputs = _build_inputs_from_audio_bytes(audio_bytes, filename) + + try: + prediction = adapter.predict(model_inputs) + except Exception as exc: + _raise_prediction_http_error(exc) + raise + + doc = { + "filename": filename, + "predicted_species": prediction.species, + "confidence": prediction.confidence, + "timestamp": datetime.datetime.utcnow(), + "user_id": user_id, + } + if upload_id: + doc["upload_id"] = upload_id + + try: + Predictions.insert_one(doc) + except Exception as exc: + raise HTTPException(status_code=500, detail=f"Failed to store prediction: {exc}") + + return { + "species": prediction.species, + "confidence": prediction.confidence, + } + + +def predict_from_payload(payload: Dict[str, Any]) -> Dict[str, Any]: + adapter = HttpModelAdapter() + model_inputs = _build_inputs_from_payload(payload) + + try: + prediction = adapter.predict(model_inputs) + except Exception as exc: + _raise_prediction_http_error(exc) + raise + + return { + "species": prediction.species, + "confidence": prediction.confidence, + "model_response": prediction.raw, + } diff --git a/src/Components/docker-compose.yml b/src/Components/docker-compose.yml index e56098976..5cfe9b2a6 100644 --- a/src/Components/docker-compose.yml +++ b/src/Components/docker-compose.yml @@ -11,6 +11,8 @@ services: image: ts-echo-model container_name: ts-echo-model-cont command: --model_config_file=/models/models.config + env_file: + - .env networks: - echo-net ports: @@ -26,6 +28,8 @@ services: dockerfile: Engine.Dockerfile image: ts-echo-engine container_name: ts-echo-engine-cont + env_file: + - .env networks: - echo-net volumes: @@ -40,8 +44,8 @@ services: dockerfile: HMI.Dockerfile image: ts-echo-hmi container_name: ts-echo-hmi-cont - # env_file: - # - .env + env_file: + - .env networks: - echo-net volumes: @@ -55,9 +59,7 @@ services: - echo-redis stdin_open: true tty: true - environment: - - NODE_ENV=development - - API_HOST=ts-api-cont + echo_mqtt: @@ -105,9 +107,8 @@ services: - ./API:/app # Mount the API's src directory to /app in the container - credentials_volume:/root/.config/gcloud/ - weather_data:/app/weather_data # Add a volume for weather data - environment: - MAIL_STARTTLS: "true" - MAIL_SSL_TLS: "false" + env_file: + - .env stdin_open: false @@ -120,11 +121,13 @@ services: restart: always ports: - 8888:8081 + env_file: + - .env environment: - ME_CONFIG_MONGODB_SERVER: ts-mongodb-cont + ME_CONFIG_MONGODB_SERVER: ${DB_HOST} ME_CONFIG_MONGODB_ENABLE_ADMIN: "true" - ME_CONFIG_MONGODB_ADMINUSERNAME: root - ME_CONFIG_MONGODB_ADMINPASSWORD: root_password + ME_CONFIG_MONGODB_ADMINUSERNAME: ${DB_ROOT_USER} + ME_CONFIG_MONGODB_ADMINPASSWORD: ${DB_ROOT_USER_PASS} depends_on: - echo_store networks: @@ -139,10 +142,12 @@ services: container_name: ts-mongodb-cont networks: - echo-net + env_file: + - .env environment: - MONGO_INITDB_ROOT_USERNAME: root - MONGO_INITDB_ROOT_PASSWORD: root_password - MONGO_INITDB_DATABASE: EchoNet + MONGO_INITDB_ROOT_USERNAME: ${DB_ROOT_USER} + MONGO_INITDB_ROOT_PASSWORD: ${DB_ROOT_USER_PASS} + MONGO_INITDB_DATABASE: ${DB_NAME} ports: - "27017:27017" volumes: