Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ fastapi>=0.103
uvicorn>=0.23
jinja2>=3.1
streamlit>=1.26
shap>=0.45

# Dev
pytest>=7.4
18 changes: 17 additions & 1 deletion src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd

from logger import get_logger
from predict_pipeline import predict
from predict_pipeline import predict, explain


app = FastAPI(title="Inventory Analysis API")
Expand Down Expand Up @@ -76,3 +76,19 @@ def predict_units_sold(payload: PredictRequest):
except Exception as exc:
logger.exception("Prediction failed")
raise HTTPException(status_code=500, detail=str(exc))


@app.post("/explain")
def explain_prediction(payload: PredictRequest, top_n: int = 10):
    """Return a prediction for one request together with its feature contributions.

    Args:
        payload: Request body; turned into a single-row DataFrame via
            ``payload.to_feature_dict()``.
        top_n: Maximum number of contributions to return. Must not be negative.

    Returns:
        A dict with ``prediction`` (the first predicted value as a float) and
        ``contributions`` (whatever ``explain`` returns for the same row).

    Raises:
        HTTPException: 422 when ``top_n`` is negative, 404 when a
            ``FileNotFoundError`` is raised while predicting/explaining, and
            500 for any other failure.
    """
    # Reject negative limits up front: passed through unchanged they would be
    # used as a slice bound downstream and silently drop entries from the end
    # of the ranking instead of limiting its length.
    if top_n < 0:
        raise HTTPException(status_code=422, detail="top_n must be a non-negative integer.")

    try:
        features = payload.to_feature_dict()
        df = pd.DataFrame([features])
        prediction = float(predict(df)[0])
        contributions = explain(df, top_n=top_n)
        return {"prediction": prediction, "contributions": contributions}
    except FileNotFoundError as exc:
        # Chain the cause so the original traceback is kept for debugging.
        raise HTTPException(
            status_code=404, detail="Model not found. Train the model first."
        ) from exc
    except Exception as exc:
        logger.exception("Explain failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
34 changes: 34 additions & 0 deletions src/predict_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,37 @@ def load_model():
def predict(input_df: pd.DataFrame):
    """Return the predictions of the model from ``load_model()`` for ``input_df``."""
    return load_model().predict(input_df)


def explain(input_df: pd.DataFrame, top_n: int = 10):
    """Rank the SHAP feature contributions for the first row of ``input_df``.

    The loaded pipeline is split into its ``"preprocess"`` and ``"model"``
    steps; the input is transformed, SHAP values are computed with
    ``shap.TreeExplainer`` on the estimator, and the contributions of the
    first row are sorted by absolute impact, largest first.

    Args:
        input_df: Raw feature rows. Only the first row is explained; any
            further rows are transformed but otherwise ignored.
        top_n: Maximum number of contributions to return. Must not be
            negative; ``None`` returns every contribution.

    Returns:
        A list of dicts with keys ``feature`` (str), ``value`` (float, the
        transformed feature value) and ``impact`` (float, the SHAP value).

    Raises:
        ValueError: If ``top_n`` is negative.
    """
    # A negative bound would make the final slice drop entries from the end
    # of the ranking rather than limit its length, so reject it explicitly.
    if top_n is not None and top_n < 0:
        raise ValueError("top_n must be a non-negative integer.")

    model = load_model()
    preprocessor = model.named_steps["preprocess"]
    estimator = model.named_steps["model"]

    X = preprocessor.transform(input_df)
    feature_names = preprocessor.get_feature_names_out()

    # Imported lazily so that only callers of explain() need shap installed.
    # NOTE(review): a reviewer flagged (P2) that shap was added to
    # requirements.txt only, and that pyproject.toml [project].dependencies
    # still omits it, so a `pip install .` install would hit
    # ModuleNotFoundError here -- confirm and declare it there (or as an
    # optional extra).
    import shap

    explainer = shap.TreeExplainer(estimator)
    shap_values = explainer.shap_values(X)
    if isinstance(shap_values, list):
        # List form: one array per model output; keep the first output.
        shap_values = shap_values[0]
    elif getattr(shap_values, "ndim", 2) == 3:
        # Presumably (samples, features, outputs), as newer shap releases
        # return for multi-output models -- keep the first output to mirror
        # the list branch. TODO confirm against the pinned shap version.
        shap_values = shap_values[:, :, 0]

    # Sparse matrices must be densified before their values can be zipped.
    if hasattr(X, "toarray"):
        row_values = X[0].toarray().ravel()
    else:
        row_values = X[0]

    contributions = [
        {
            "feature": str(name),
            "value": float(value),
            "impact": float(impact),
        }
        for name, value, impact in zip(feature_names, row_values, shap_values[0])
    ]

    # sorted() is stable, so ties keep their original feature order.
    ranked = sorted(contributions, key=lambda item: abs(item["impact"]), reverse=True)
    return ranked[:top_n]
6 changes: 5 additions & 1 deletion src/streamlit_app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd
import streamlit as st

from predict_pipeline import predict
from predict_pipeline import explain, predict


st.set_page_config(page_title="Inventory Demand Forecast", layout="centered")
Expand Down Expand Up @@ -73,6 +73,10 @@
df = pd.DataFrame([payload])
prediction = float(predict(df)[0])
st.success(f"Predicted Units Sold: {prediction:.2f}")

st.subheader("Top Feature Contributions")
contributions = explain(df, top_n=10)
st.dataframe(contributions, use_container_width=True)
except FileNotFoundError:
st.error("Model not found. Train the model first.")
except Exception as exc:
Expand Down