-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
117 lines (95 loc) · 4.17 KB
/
app.py
File metadata and controls
117 lines (95 loc) · 4.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import sys
import os
from dotenv import load_dotenv
# Load variables from a local .env file (if present) before anything below
# reads the environment. This runs ahead of the project imports, presumably so
# modules that read settings at import time see them too — TODO confirm.
load_dotenv()

# Presumably the MongoDB connection string; MONGODB_URL_KEY is accepted as a
# fallback variable name. Not referenced elsewhere in this part of the file.
mongo_db_url = os.getenv("MONGO_DB_URL") or os.getenv("MONGODB_URL_KEY")

# Port for the uvicorn server started under ``__main__``. An unset *or empty*
# PORT falls back to 8000; ``int("")`` would otherwise crash at import time.
app_port = int(os.getenv("PORT") or "8000")

# Shared secret expected in the ``x-api-key`` header of POST /train. When it is
# unset or empty, the training route does not check any key.
train_api_key = os.getenv("TRAIN_API_KEY")
from networksecurity.exception.exception import NetworkSecurityException
from networksecurity.logging.logger import logging
from networksecurity.pipeline.training_pipeline import TrainingPipeline
from networksecurity.cloud.s3_syncer import S3Sync
from networksecurity.constant.training_pipeline import TRAINING_BUCKET_NAME
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, File, UploadFile,Request, Header, HTTPException, status
from uvicorn import run as app_run
from fastapi.responses import Response
from starlette.responses import RedirectResponse
import pandas as pd
from networksecurity.utils.main_utils.utils import load_object
from networksecurity.utils.ml_utils.model.estimator import NetworkModel
# Trained artifacts served by /predict. Paths are relative to the process's
# current working directory.
_FINAL_MODELS_DIR = "final_models"
PREPROCESSOR_FILE_PATH = f"{_FINAL_MODELS_DIR}/preprocessor.pkl"
MODEL_FILE_PATH = f"{_FINAL_MODELS_DIR}/model.pkl"
# S3 prefix that load_network_model syncs into the local final_models
# directory when the artifacts above are missing.
LATEST_MODEL_S3_URI = f"s3://{TRAINING_BUCKET_NAME}/{_FINAL_MODELS_DIR}/latest"
def _prediction_artifacts_present() -> bool:
    """Return True when both the preprocessor and the model file exist locally."""
    return all(
        os.path.exists(path) for path in (PREPROCESSOR_FILE_PATH, MODEL_FILE_PATH)
    )


def load_network_model() -> NetworkModel | None:
    """Load the preprocessor/model pair used for predictions.

    If either artifact is missing locally, a best-effort sync of the
    ``final_models`` folder from ``LATEST_MODEL_S3_URI`` is attempted first.
    A failed sync is only logged; it never raises.

    NOTE(review): the artifacts are ``.pkl`` files and ``load_object``
    presumably unpickles them — only load them from a trusted bucket/path.

    Returns:
        A ``NetworkModel`` wrapping both artifacts, or ``None`` if they are
        still unavailable after the sync attempt.

    Raises:
        NetworkSecurityException: if loading the artifacts or building the
            ``NetworkModel`` fails (the original error is chained as the cause).
    """
    try:
        if not _prediction_artifacts_present():
            logging.info(
                "Prediction artifacts are missing locally. Attempting to sync from %s",
                LATEST_MODEL_S3_URI,
            )
            try:
                S3Sync().sync_folder_from_s3(
                    folder="final_models", aws_bucket_url=LATEST_MODEL_S3_URI
                )
            except Exception as sync_error:
                # Best effort: an unreachable bucket must not crash the caller;
                # the re-check below decides whether a model can be served.
                logging.warning("Unable to sync prediction artifacts from S3: %s", sync_error)
            if not _prediction_artifacts_present():
                logging.warning("Prediction artifacts are not available locally or in S3 sync target.")
                return None
        preprocessor = load_object(PREPROCESSOR_FILE_PATH)
        final_model = load_object(MODEL_FILE_PATH)
        return NetworkModel(preprocessor=preprocessor, model=final_model)
    except Exception as e:
        raise NetworkSecurityException(e, sys) from e
def authorize_train_request(api_key: str | None) -> None:
    """Reject a /train call whose ``x-api-key`` does not match TRAIN_API_KEY.

    The check is only enforced when ``train_api_key`` is set to a non-empty
    value; otherwise every request is allowed through.

    Args:
        api_key: Value of the ``x-api-key`` request header, or ``None`` when
            the header is absent.

    Raises:
        HTTPException: 401 when a key is configured and *api_key* is missing
            or does not match.
    """
    import hmac

    if not train_api_key:
        return
    # Constant-time comparison, so response timing does not reveal how much of
    # the key was guessed correctly. Comparing UTF-8 bytes also avoids
    # compare_digest's TypeError on non-ASCII str input.
    if api_key is None or not hmac.compare_digest(
        api_key.encode("utf-8"), train_api_key.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing training API key.",
        )
app = FastAPI()

# Allowed browser origins, as a comma-separated CORS_ALLOW_ORIGINS value.
# Defaults to "*" (any origin), matching the previous hard-coded setting.
origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Credentials may only be paired with an explicit origin list. The CORS spec
# forbids "*" together with credentials, and Starlette works around that by
# echoing the caller's Origin. That would let any website make credentialed
# requests. This app authenticates /train via a header, not cookies.
cors_allow_credentials = "*" not in origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=cors_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)

from fastapi.templating import Jinja2Templates

# Jinja2 templates (table.html is rendered by /predict), resolved relative to
# the current working directory.
templates = Jinja2Templates(directory="./templates")
@app.on_event("startup")
async def startup_event():
    """Load the prediction model once, when the application starts.

    The result is stored on ``app.state.network_model``. It is ``None`` when
    the artifacts are available neither locally nor via the S3 sync.
    """
    loaded_model = load_network_model()
    app.state.network_model = loaded_model
@app.get("/", tags=["authentication"])
async def index():
    """Redirect requests for the site root to the interactive API docs."""
    docs_path = "/docs"
    return RedirectResponse(url=docs_path)
import asyncio

# Only one training run at a time. The pipeline runs in a worker thread, so
# other requests are still served meanwhile. Without this lock, two /train
# calls could overlap and race on app.state.network_model, and presumably on
# the pipeline's output artifacts too (TODO confirm).
_train_lock = asyncio.Lock()


@app.post("/train")
async def train_route(x_api_key: str | None = Header(default=None, alias="x-api-key")):
    """Run the training pipeline, then reload the model served by /predict.

    Args:
        x_api_key: Value of the ``x-api-key`` header, checked by
            ``authorize_train_request``.

    Returns:
        A plain-text ``"Training is successful"`` response.

    Raises:
        HTTPException: 401 from ``authorize_train_request``. It is raised
            before the ``try``, so it reaches the client as a 401 rather than
            being re-wrapped into a server error.
        NetworkSecurityException: wraps any failure while training or reloading.
    """
    authorize_train_request(x_api_key)
    try:
        async with _train_lock:
            train_pipeline = TrainingPipeline()
            # Both calls block (training, then S3 sync and unpickling).
            # Running them off the event loop keeps /predict and /docs
            # responsive during training.
            await asyncio.to_thread(train_pipeline.run_pipeline)
            app.state.network_model = await asyncio.to_thread(load_network_model)
        return Response("Training is successful")
    except Exception as e:
        raise NetworkSecurityException(e, sys) from e
@app.post("/predict")
async def predict_route(request: Request,file: UploadFile = File(...)):
    """Score an uploaded CSV and render the rows with predictions as HTML.

    The input rows plus a ``predicted_column`` are also written to
    ``prediction_output/output.csv``, which is overwritten on every call.

    Args:
        request: The incoming request, passed to the template response.
        file: Uploaded CSV of model inputs. The columns are assumed to match
            the training schema; this is not validated here.

    Returns:
        ``table.html`` rendered with the prediction table as ``table``.

    Raises:
        HTTPException: 503 if no model is loaded; 400 if the upload is empty
            or cannot be parsed as CSV.
        NetworkSecurityException: wraps any other failure (cause chained).
    """
    network_model = getattr(app.state, "network_model", None)
    if network_model is None:
        # A missing model is a service-availability problem, not a server crash.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Prediction model is not loaded. Train the model or add artifacts to final_models.",
        )

    try:
        df = pd.read_csv(file.file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as parse_error:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded file could not be parsed as CSV.",
        ) from parse_error
    except Exception as e:
        raise NetworkSecurityException(e, sys) from e

    if df.empty:
        # A header-only CSV has nothing to score (and df.iloc[0] would fail).
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded CSV contains no data rows.",
        )

    try:
        logging.info("Predicting %d rows; first row:\n%s", len(df), df.iloc[0])
        y_pred = network_model.predict(df)
        df["predicted_column"] = y_pred
        logging.info("Predictions:\n%s", df["predicted_column"])
        # The output directory may not exist, e.g. in a fresh container.
        os.makedirs("prediction_output", exist_ok=True)
        df.to_csv("prediction_output/output.csv")
        table_html = df.to_html(classes="table table-striped")
        return templates.TemplateResponse(
            request,
            "table.html",
            {"request": request, "table": table_html},
        )
    except Exception as e:
        raise NetworkSecurityException(e, sys) from e
if __name__ == "__main__":
    # Serve the app with uvicorn, listening on every interface, on the port
    # taken from the PORT environment variable (8000 when unset).
    bind_host = "0.0.0.0"
    app_run(app, port=app_port, host=bind_host)