From 4f414b02ff9378c5c3a581e43854a0159723fce5 Mon Sep 17 00:00:00 2001
From: BabaKaElijah
Date: Thu, 22 Jan 2026 13:34:49 +0200
Subject: [PATCH] Add SHAP explanations to API and Streamlit

---
Reviewer notes (this section is not part of the commit message):
- Line structure was rebuilt from a whitespace-collapsed copy of this
  patch; indentation of context lines is a best-effort reconstruction.
- explain() now has a docstring, rejects empty input, handles the 3-D
  array shap>=0.45 returns for multi-output models, and no longer drops
  entries from the end of the list when top_n is negative.
- The "index" blob ids below are those of the original patch.

 requirements.txt        |  1 +
 src/app.py              | 19 ++++++++++++++++++-
 src/predict_pipeline.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 src/streamlit_app.py    |  6 +++++-
 4 files changed, 76 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index a59aa40..4271230 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,6 +9,7 @@ fastapi>=0.103
 uvicorn>=0.23
 jinja2>=3.1
 streamlit>=1.26
+shap>=0.45
 
 # Dev
 pytest>=7.4
diff --git a/src/app.py b/src/app.py
index 4ea01d1..5dbaaa0 100644
--- a/src/app.py
+++ b/src/app.py
@@ -4,7 +4,7 @@
 import pandas as pd
 
 from logger import get_logger
-from predict_pipeline import predict
+from predict_pipeline import predict, explain
 
 app = FastAPI(title="Inventory Analysis API")
 
@@ -76,3 +76,20 @@ def predict_units_sold(payload: PredictRequest):
     except Exception as exc:
         logger.exception("Prediction failed")
         raise HTTPException(status_code=500, detail=str(exc))
+
+
+@app.post("/explain")
+def explain_prediction(payload: PredictRequest, top_n: int = 10):
+    """Return the prediction and its top_n SHAP feature contributions."""
+    try:
+        features = payload.to_feature_dict()
+        df = pd.DataFrame([features])
+        preds = predict(df)
+        prediction = float(preds[0])
+        contributions = explain(df, top_n=top_n)
+        return {"prediction": prediction, "contributions": contributions}
+    except FileNotFoundError:
+        raise HTTPException(status_code=404, detail="Model not found. Train the model first.")
+    except Exception as exc:
+        logger.exception("Explain failed")
+        raise HTTPException(status_code=500, detail=str(exc))
diff --git a/src/predict_pipeline.py b/src/predict_pipeline.py
index 69149c7..3b23ad7 100644
--- a/src/predict_pipeline.py
+++ b/src/predict_pipeline.py
@@ -12,3 +12,55 @@ def load_model():
 def predict(input_df: pd.DataFrame):
     model = load_model()
     return model.predict(input_df)
+
+
+def explain(input_df: pd.DataFrame, top_n: int = 10):
+    """Return the strongest SHAP feature contributions for the first row.
+
+    Args:
+        input_df: Frame handed to the pipeline's "preprocess" step; only
+            its first row is explained.
+        top_n: Maximum number of contributions to return. Negative values
+            are treated as zero.
+
+    Returns:
+        A list of {"feature", "value", "impact"} dicts ordered by
+        descending absolute impact, truncated to ``top_n`` entries.
+
+    Raises:
+        ValueError: If ``input_df`` has no rows.
+    """
+    model = load_model()
+    if len(input_df) == 0:
+        raise ValueError("input_df must contain at least one row")
+
+    preprocessor = model.named_steps["preprocess"]
+    estimator = model.named_steps["model"]
+
+    X = preprocessor.transform(input_df)
+    feature_names = preprocessor.get_feature_names_out()
+
+    # Local import: shap is only needed on the explanation path.
+    import shap
+
+    explainer = shap.TreeExplainer(estimator)
+    shap_values = explainer.shap_values(X)
+    # Multi-output models: older shap returns a list with one array per
+    # output, shap>=0.45 returns (rows, features, outputs). Keep output 0.
+    if isinstance(shap_values, list):
+        shap_values = shap_values[0]
+    elif getattr(shap_values, "ndim", 2) == 3:
+        shap_values = shap_values[:, :, 0]
+
+    if hasattr(X, "toarray"):
+        row_values = X[0].toarray().ravel()
+    else:
+        row_values = X[0]
+
+    contributions = [
+        {"feature": str(name), "value": float(value), "impact": float(impact)}
+        for name, value, impact in zip(feature_names, row_values, shap_values[0])
+    ]
+    contributions.sort(key=lambda item: abs(item["impact"]), reverse=True)
+    # A negative top_n would otherwise slice entries off the end of the list.
+    return contributions[: max(top_n, 0)]
diff --git a/src/streamlit_app.py b/src/streamlit_app.py
index b72f9b0..77b5cb7 100644
--- a/src/streamlit_app.py
+++ b/src/streamlit_app.py
@@ -1,7 +1,7 @@
 import pandas as pd
 import streamlit as st
 
-from predict_pipeline import predict
+from predict_pipeline import explain, predict
 
 st.set_page_config(page_title="Inventory Demand Forecast", layout="centered")
 
@@ -73,6 +73,10 @@
         df = pd.DataFrame([payload])
         prediction = float(predict(df)[0])
         st.success(f"Predicted Units Sold: {prediction:.2f}")
+
+        st.subheader("Top Feature Contributions")
+        contributions = explain(df, top_n=10)
+        st.dataframe(contributions, use_container_width=True)
     except FileNotFoundError:
         st.error("Model not found. Train the model first.")
     except Exception as exc: