Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions examples/tutorials/forecasting_with_workflow_presets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,64 @@
"fig.show()"
]
},
{
"cell_type": "markdown",
"id": "609b7509",
"metadata": {},
"source": [
"## 🔬 Step 9: Visualize Feature Contributions (SHAP)\n",
"\n",
"While feature importance shows **which** features matter overall, **contributions**\n",
"show how each feature pushed the prediction up or down for every individual timestep.\n",
"GBLinear models expose these as SHAP values via `predict_contributions()`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1dddfe95",
"metadata": {},
"outputs": [],
"source": [
"# Compute per-timestep feature contributions for the forecast period\n",
"from openstef_models.explainability import ContributionsPlotter\n",
"\n",
"contributions = workflow.model.predict_contributions(forecast_dataset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "539e6156",
"metadata": {},
"outputs": [],
"source": [
"# Heatmap: contributions over time with prediction line\n",
"ContributionsPlotter.plot_heatmap(contributions, top_n=10, show_prediction=True).show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d68f4bbc",
"metadata": {},
"outputs": [],
"source": [
"# Waterfall: decompose a single timestep's prediction\n",
"ContributionsPlotter.plot_waterfall(contributions, timestep=0, top_n=10).show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19582854",
"metadata": {},
"outputs": [],
"source": [
"# Bar chart: mean absolute contribution per feature\n",
"ContributionsPlotter.plot_bar(contributions, top_n=10).show()"
]
},
{
"cell_type": "markdown",
"id": "1f53f172",
Expand Down
25 changes: 25 additions & 0 deletions examples/tutorials/forecasting_with_workflow_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,31 @@
fig.update_layout(title="🔍 Feature Importance Treemap")
fig.show()

# %% [markdown]
# ## 🔬 Step 9: Visualize Feature Contributions (SHAP)
#
# While feature importance shows **which** features matter overall, **contributions**
# show how each feature pushed the prediction up or down for every individual timestep.
# GBLinear models expose these as SHAP values via `predict_contributions()`.

# %%
# Compute per-timestep feature contributions for the forecast period
from openstef_models.explainability import ContributionsPlotter

contributions = workflow.model.predict_contributions(forecast_dataset)

# %%
# Heatmap: contributions over time with prediction line
ContributionsPlotter.plot_heatmap(contributions, top_n=10, show_prediction=True).show()

# %%
# Waterfall: decompose a single timestep's prediction
ContributionsPlotter.plot_waterfall(contributions, timestep=0, top_n=10).show()

# %%
# Bar chart: mean absolute contribution per feature
ContributionsPlotter.plot_bar(contributions, top_n=10).show()

# %% [markdown]
# ---
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
"""

from .mixins import ContributionsMixin, ExplainableForecaster
from .plotters import FeatureImportancePlotter
from .plotters import ContributionsPlotter, FeatureImportancePlotter

__all__ = [
"ContributionsMixin",
"ContributionsPlotter",
"ExplainableForecaster",
"FeatureImportancePlotter",
]
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
"""

from abc import ABC, abstractmethod
from typing import Any, Literal

import pandas as pd
import plotly.graph_objects as go

from openstef_core.datasets import ForecastInputDataset, TimeSeriesDataset
from openstef_core.types import Q, Quantile
from openstef_models.explainability.plotters.contributions_plotter import ContributionsPlotter
from openstef_models.explainability.plotters.feature_importance_plotter import FeatureImportancePlotter


Expand Down Expand Up @@ -88,3 +90,37 @@ def predict_contributions(self, data: ForecastInputDataset) -> TimeSeriesDataset
rows are timesteps. A ``bias`` column may be included for the
model intercept/base value.
"""

def plot_contributions(
self,
data: ForecastInputDataset,
kind: Literal["heatmap", "waterfall", "bar"] = "heatmap",
**kwargs: Any,
) -> go.Figure:
"""Plot per-sample feature contributions.

Calls ``predict_contributions()`` and visualizes the result using the
requested chart type.

Args:
data: Preprocessed input data.
kind: Chart type — ``"heatmap"``, ``"waterfall"``, or ``"bar"``.
**kwargs: Forwarded to the corresponding plotter method
(e.g. ``top_n``, ``timestep``).

Returns:
Plotly Figure.

Raises:
ValueError: If *kind* is not one of the supported chart types.
"""
contributions = self.predict_contributions(data)
plotters = {
"heatmap": ContributionsPlotter.plot_heatmap,
"waterfall": ContributionsPlotter.plot_waterfall,
"bar": ContributionsPlotter.plot_bar,
}
if kind not in plotters:
msg = f"Unknown plot kind {kind!r}. Choose from {list(plotters)}"
raise ValueError(msg)
return plotters[kind](contributions=contributions, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
scores and other model explanation outputs.
"""

from .contributions_plotter import ContributionsPlotter
from .feature_importance_plotter import FeatureImportancePlotter

__all__ = [
"ContributionsPlotter",
"FeatureImportancePlotter",
]
Loading
Loading