diff --git a/examples/tutorials/forecasting_with_workflow_presets.ipynb b/examples/tutorials/forecasting_with_workflow_presets.ipynb index cdf4a3800..dd0910ee4 100644 --- a/examples/tutorials/forecasting_with_workflow_presets.ipynb +++ b/examples/tutorials/forecasting_with_workflow_presets.ipynb @@ -335,6 +335,64 @@ "fig.show()" ] }, + { + "cell_type": "markdown", + "id": "609b7509", + "metadata": {}, + "source": [ + "## 🔬 Step 9: Visualize Feature Contributions (SHAP)\n", + "\n", + "While feature importance shows **which** features matter overall, **contributions**\n", + "show how each feature pushed the prediction up or down for every individual timestep.\n", + "GBLinear models expose these as SHAP values via `predict_contributions()`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1dddfe95", + "metadata": {}, + "outputs": [], + "source": [ + "# Compute per-timestep feature contributions for the forecast period\n", + "from openstef_models.explainability import ContributionsPlotter\n", + "\n", + "contributions = workflow.model.predict_contributions(forecast_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "539e6156", + "metadata": {}, + "outputs": [], + "source": [ + "# Heatmap: contributions over time with prediction line\n", + "ContributionsPlotter.plot_heatmap(contributions, top_n=10, show_prediction=True).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d68f4bbc", + "metadata": {}, + "outputs": [], + "source": [ + "# Waterfall: decompose a single timestep's prediction\n", + "ContributionsPlotter.plot_waterfall(contributions, timestep=0, top_n=10).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19582854", + "metadata": {}, + "outputs": [], + "source": [ + "# Bar chart: mean absolute contribution per feature\n", + "ContributionsPlotter.plot_bar(contributions, top_n=10).show()" + ] + }, { "cell_type": "markdown", "id": "1f53f172", diff --git a/examples/tutorials/forecasting_with_workflow_presets.py b/examples/tutorials/forecasting_with_workflow_presets.py index aaba18a08..c1d5526b3 100644 --- a/examples/tutorials/forecasting_with_workflow_presets.py +++ b/examples/tutorials/forecasting_with_workflow_presets.py @@ -234,6 +234,31 @@ fig.update_layout(title="🔍 Feature Importance Treemap") fig.show() +# %% [markdown] +# ## 🔬 Step 9: Visualize Feature Contributions (SHAP) +# +# While feature importance shows **which** features matter overall, **contributions** +# show how each feature pushed the prediction up or down for every individual timestep. +# GBLinear models expose these as SHAP values via `predict_contributions()`. + +# %% +# Compute per-timestep feature contributions for the forecast period +from openstef_models.explainability import ContributionsPlotter + +contributions = workflow.model.predict_contributions(forecast_dataset) + +# %% +# Heatmap: contributions over time with prediction line +ContributionsPlotter.plot_heatmap(contributions, top_n=10, show_prediction=True).show() + +# %% +# Waterfall: decompose a single timestep's prediction +ContributionsPlotter.plot_waterfall(contributions, timestep=0, top_n=10).show() + +# %% +# Bar chart: mean absolute contribution per feature +ContributionsPlotter.plot_bar(contributions, top_n=10).show() + # %% [markdown] # --- # diff --git a/packages/openstef-models/src/openstef_models/explainability/__init__.py b/packages/openstef-models/src/openstef_models/explainability/__init__.py index c31525b3d..bfc562145 100644 --- a/packages/openstef-models/src/openstef_models/explainability/__init__.py +++ b/packages/openstef-models/src/openstef_models/explainability/__init__.py @@ -8,10 +8,11 @@ """ from .mixins import ContributionsMixin, ExplainableForecaster -from .plotters import FeatureImportancePlotter +from .plotters import ContributionsPlotter, FeatureImportancePlotter __all__ = [ "ContributionsMixin", + "ContributionsPlotter", "ExplainableForecaster", "FeatureImportancePlotter", ] diff --git a/packages/openstef-models/src/openstef_models/explainability/mixins.py b/packages/openstef-models/src/openstef_models/explainability/mixins.py index f9bb717ea..bfdbd4706 100644 --- a/packages/openstef-models/src/openstef_models/explainability/mixins.py +++ b/packages/openstef-models/src/openstef_models/explainability/mixins.py @@ -9,12 +9,14 @@ """ from abc import ABC, abstractmethod +from typing import Any, Literal import pandas as pd import plotly.graph_objects as go from openstef_core.datasets import ForecastInputDataset, TimeSeriesDataset from openstef_core.types import Q, Quantile +from openstef_models.explainability.plotters.contributions_plotter import ContributionsPlotter from openstef_models.explainability.plotters.feature_importance_plotter import FeatureImportancePlotter @@ -88,3 +90,37 @@ def predict_contributions(self, data: ForecastInputDataset) -> TimeSeriesDataset rows are timesteps. A ``bias`` column may be included for the model intercept/base value. """ + + def plot_contributions( + self, + data: ForecastInputDataset, + kind: Literal["heatmap", "waterfall", "bar"] = "heatmap", + **kwargs: Any, + ) -> go.Figure: + """Plot per-sample feature contributions. + + Calls ``predict_contributions()`` and visualizes the result using the + requested chart type. + + Args: + data: Preprocessed input data. + kind: Chart type — ``"heatmap"``, ``"waterfall"``, or ``"bar"``. + **kwargs: Forwarded to the corresponding plotter method + (e.g. ``top_n``, ``timestep``). + + Returns: + Plotly Figure. + + Raises: + ValueError: If *kind* is not one of the supported chart types. + """ + contributions = self.predict_contributions(data) + plotters = { + "heatmap": ContributionsPlotter.plot_heatmap, + "waterfall": ContributionsPlotter.plot_waterfall, + "bar": ContributionsPlotter.plot_bar, + } + if kind not in plotters: + msg = f"Unknown plot kind {kind!r}. Choose from {list(plotters)}" + raise ValueError(msg) + return plotters[kind](contributions=contributions, **kwargs) diff --git a/packages/openstef-models/src/openstef_models/explainability/plotters/__init__.py b/packages/openstef-models/src/openstef_models/explainability/plotters/__init__.py index b3d603eec..027485f79 100644 --- a/packages/openstef-models/src/openstef_models/explainability/plotters/__init__.py +++ b/packages/openstef-models/src/openstef_models/explainability/plotters/__init__.py @@ -8,8 +8,10 @@ scores and other model explanation outputs. """ +from .contributions_plotter import ContributionsPlotter from .feature_importance_plotter import FeatureImportancePlotter __all__ = [ + "ContributionsPlotter", "FeatureImportancePlotter", ] diff --git a/packages/openstef-models/src/openstef_models/explainability/plotters/contributions_plotter.py b/packages/openstef-models/src/openstef_models/explainability/plotters/contributions_plotter.py new file mode 100644 index 000000000..97de74ae1 --- /dev/null +++ b/packages/openstef-models/src/openstef_models/explainability/plotters/contributions_plotter.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: 2026 Contributors to the OpenSTEF project +# +# SPDX-License-Identifier: MPL-2.0 + +"""Visualizations for per-sample feature contributions (SHAP values).""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import plotly.graph_objects as go +from plotly.subplots import make_subplots # pyright: ignore[reportUnknownVariableType] + +from openstef_core.datasets import TimeSeriesDataset # noqa: TC001 # runtime needed for pyright + +if TYPE_CHECKING: + import pandas as pd + + +class ContributionsPlotter: + """Visualizations for per-timestep feature contributions.""" + + @staticmethod + def plot_heatmap( + contributions: TimeSeriesDataset, + top_n: int = 10, + target_column: str = "load", + bias_column: str = "bias", + *, + show_prediction: bool = True, + ) -> go.Figure: + """Create an interactive heatmap of feature contributions over time. + + X-axis is the prediction datetime, Y-axis shows feature names ranked by mean absolute contribution + (most important at top). Color ranges from blue (negative) through white (zero) to red (positive). + When ``show_prediction`` is True a line plot of the model prediction (sum of contributions + bias) + is shown above the heatmap. + + Args: + contributions: Output of ``predict_contributions()``. + top_n: Number of top features to show (ranked by mean absolute contribution). + target_column: Name of the target column to exclude. Default "load". + bias_column: Name of the bias column. Default "bias". + show_prediction: If True, add a prediction line subplot above the heatmap. Default True. + + Returns: + Plotly Figure with a diverging heatmap centered at zero (and optional prediction line). + """ + bias = contributions.data[bias_column] if bias_column in contributions.data.columns else None + cols_to_drop = [c for c in [target_column, bias_column] if c in contributions.data.columns] + df = contributions.data.drop(columns=cols_to_drop) + ranked: list[str] = df.abs().mean().sort_values(ascending=False).head(top_n).index.tolist() + + # Most-important feature at top of Y-axis + y_labels = list(reversed(ranked)) + + heatmap = go.Heatmap( + z=df[y_labels].T.values, + x=df.index, + y=y_labels, + colorscale="RdBu_r", + zmid=0, + colorbar={"title": "Contribution"}, + showlegend=False, + ) + + if show_prediction: + prediction = df.sum(axis=1) + if bias is not None: + prediction += bias + + fig = make_subplots( + rows=2, + cols=1, + shared_xaxes=True, + row_heights=[0.2, 0.8], + vertical_spacing=0.03, + ) + + fig.add_trace( # pyright: ignore[reportUnknownMemberType] + go.Scatter( + x=df.index, + y=prediction, + mode="lines", + name="Prediction", + line={"color": "black", "width": 1.5}, + showlegend=False, + ), + row=1, + col=1, + ) + fig.add_trace(heatmap, row=2, col=1) # pyright: ignore[reportUnknownMemberType] + + fig.update_layout( # pyright: ignore[reportUnknownMemberType] + yaxis_title="Prediction", + yaxis2_title="Feature", + xaxis2_title="Time", + margin={"t": 30, "r": 10, "b": 40, "l": 120}, + ) + else: + fig = go.Figure( + data=heatmap, + layout={ + "xaxis_title": "Time", + "yaxis_title": "Feature", + "margin": {"t": 30, "r": 10, "b": 40, "l": 120}, + }, + ) + + return fig + + @staticmethod + def plot_waterfall( + contributions: TimeSeriesDataset, + timestep: int = 0, + top_n: int = 10, + target_column: str = "load", + bias_column: str = "bias", + ) -> go.Figure: + """Create a waterfall chart decomposing a single timestep's prediction. + + Shows how the bias (base value) is pushed up or down by each feature's + contribution to arrive at the final prediction. + + Args: + contributions: Output of ``predict_contributions()``. + timestep: Row index (0-based) of the timestep to explain. + top_n: Number of top features to show. Remaining features are + aggregated into an "other" bar. + target_column: Name of the target column to exclude. Default "load". + bias_column: Name of the bias column used as base value. Default "bias". + + Returns: + Plotly Figure with waterfall chart. + """ + bias = contributions.data[bias_column] if bias_column in contributions.data.columns else None + cols_to_drop = [c for c in [target_column, bias_column] if c in contributions.data.columns] + df = contributions.data.drop(columns=cols_to_drop) + row = df.iloc[timestep] + base_value = float(bias.iloc[timestep]) if bias is not None else 0.0 + + # Rank by |contribution| for this specific timestep + abs_sorted = row.abs().sort_values(ascending=False) + top = abs_sorted.head(top_n).index.tolist() + remaining = [c for c in abs_sorted.index if c not in top] + + names: list[str] = [bias_column] + values: list[float] = [base_value] + measures: list[str] = ["absolute"] + + for feat in top: + names.append(feat) + values.append(float(row[feat])) # pyright: ignore[reportArgumentType] + measures.append("relative") + + if len(remaining) > 0: + other_sum = float(row[remaining].sum()) + names.append(f"other ({len(remaining)})") + values.append(other_sum) + measures.append("relative") + + names.append("Prediction") + values.append(base_value + float(row.sum())) + measures.append("total") + + timestamp = contributions.data.index[timestep] + return go.Figure( + go.Waterfall( + x=names, + y=values, + measure=measures, + connector={"line": {"color": "grey", "width": 0.5}}, + increasing={"marker": {"color": "#ff4136"}}, + decreasing={"marker": {"color": "#0074d9"}}, + totals={"marker": {"color": "#2ecc40"}}, + textposition="outside", + text=[f"{v:+.4f}" if m == "relative" else f"{v:.4f}" for v, m in zip(values, measures, strict=True)], + ), + layout={ + "title": f"Contributions at {timestamp}", + "yaxis_title": "Contribution", + "margin": {"t": 50, "r": 10, "b": 40, "l": 60}, + "showlegend": False, + }, + ) + + @staticmethod + def plot_bar( + contributions: TimeSeriesDataset, + top_n: int = 10, + target_column: str = "load", + bias_column: str = "bias", + ) -> go.Figure: + """Create a horizontal bar chart of mean absolute contributions per feature. + + Features are ranked from most to least important (top to bottom). + + Args: + contributions: Output of ``predict_contributions()``. + top_n: Number of top features to show. + target_column: Name of the target column to exclude. Default "load". + bias_column: Name of the bias column to exclude. Default "bias". + + Returns: + Plotly Figure with horizontal bar chart. + """ + cols_to_drop = [c for c in [target_column, bias_column] if c in contributions.data.columns] + df = contributions.data.drop(columns=cols_to_drop) + mean_abs: pd.Series = df.abs().mean().sort_values(ascending=False).head(top_n) + + # Reverse for plotly (bottom-to-top rendering) + mean_abs = mean_abs.iloc[::-1] + + return go.Figure( + go.Bar( + x=mean_abs.values, # pyright: ignore[reportArgumentType] + y=mean_abs.index.tolist(), + orientation="h", + marker_color="#1f77b4", + hovertemplate="%{y}
mean |SHAP|: %{x:.4f}", + ), + layout={ + "xaxis_title": "mean |SHAP value|", + "yaxis_title": "Feature", + "margin": {"t": 30, "r": 10, "b": 40, "l": 120}, + "showlegend": False, + }, + ) diff --git a/packages/openstef-models/tests/unit/explainability/__init__.py b/packages/openstef-models/tests/unit/explainability/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/openstef-models/tests/unit/explainability/plotters/__init__.py b/packages/openstef-models/tests/unit/explainability/plotters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/openstef-models/tests/unit/explainability/plotters/test_contributions_plotter.py b/packages/openstef-models/tests/unit/explainability/plotters/test_contributions_plotter.py new file mode 100644 index 000000000..f2315c6d6 --- /dev/null +++ b/packages/openstef-models/tests/unit/explainability/plotters/test_contributions_plotter.py @@ -0,0 +1,219 @@ +# SPDX-FileCopyrightText: 2026 Contributors to the OpenSTEF project +# +# SPDX-License-Identifier: MPL-2.0 + +"""Tests for ContributionsPlotter.""" + +from datetime import timedelta + +import numpy as np +import pandas as pd +import plotly.graph_objects as go +import pytest + +from openstef_core.datasets import TimeSeriesDataset +from openstef_models.explainability.plotters.contributions_plotter import ContributionsPlotter + +TARGET_COLUMN = "load" +BIAS_COLUMN = "bias" + + +@pytest.fixture +def contributions_dataset() -> TimeSeriesDataset: + """5 features with deliberately different magnitudes, plus bias and target.""" + rng = np.random.default_rng(42) + index = pd.date_range("2025-01-01", periods=10, freq="15min") + data = pd.DataFrame( + { + "feat_a": rng.normal(10, 1, 10), # largest + "feat_b": rng.normal(5, 1, 10), + "feat_c": rng.normal(2, 0.5, 10), + "feat_d": rng.normal(1, 0.3, 10), + "feat_e": rng.normal(0.1, 0.05, 10), # smallest + BIAS_COLUMN: np.full(10, 100.0), + TARGET_COLUMN: rng.normal(120, 5, 10), + }, + index=index, + ) + return TimeSeriesDataset(data=data, sample_interval=timedelta(minutes=15)) + + +ALL_PLOT_METHODS = ["plot_heatmap", "plot_waterfall", "plot_bar"] + + +@pytest.mark.parametrize("method", ALL_PLOT_METHODS) +def test_returns_figure( + contributions_dataset: TimeSeriesDataset, + method: str, +) -> None: + # Act + result = getattr(ContributionsPlotter, method)(contributions_dataset) + + # Assert + assert isinstance(result, go.Figure) + + +@pytest.mark.parametrize("method", ALL_PLOT_METHODS) +def test_no_bias_column(method: str) -> None: + # Arrange + index = pd.date_range("2025-01-01", periods=5, freq="15min") + data = pd.DataFrame({"feat_x": [1, 2, 3, 4, 5], "feat_y": [5, 4, 3, 2, 1]}, index=index) + ds = TimeSeriesDataset(data=data, sample_interval=timedelta(minutes=15)) + + # Act + fig = getattr(ContributionsPlotter, method)(ds) + + # Assert + assert isinstance(fig, go.Figure) + + +def test_heatmap_with_prediction_has_two_traces(contributions_dataset: TimeSeriesDataset) -> None: + # Act + fig = ContributionsPlotter.plot_heatmap(contributions_dataset, show_prediction=True) + + # Assert + assert len(fig.data) == 2 + + +def test_heatmap_without_prediction_has_one_trace(contributions_dataset: TimeSeriesDataset) -> None: + # Act + fig = ContributionsPlotter.plot_heatmap(contributions_dataset, show_prediction=False) + + # Assert + assert len(fig.data) == 1 + + +def test_heatmap_excludes_target_and_bias(contributions_dataset: TimeSeriesDataset) -> None: + # Act + fig = ContributionsPlotter.plot_heatmap(contributions_dataset, show_prediction=False) + y_labels = list(fig.data[0].y) + + # Assert + assert TARGET_COLUMN not in y_labels + assert BIAS_COLUMN not in y_labels + + +def test_heatmap_features_sorted_by_importance(contributions_dataset: TimeSeriesDataset) -> None: + # Act + fig = ContributionsPlotter.plot_heatmap(contributions_dataset, top_n=5, show_prediction=False) + y_labels = list(fig.data[0].y) + + # Assert + assert y_labels[-1] == "feat_a" # most important at top (last in list) + assert y_labels[0] == "feat_e" # least important at bottom + + +def test_heatmap_top_n_limits_features(contributions_dataset: TimeSeriesDataset) -> None: + # Act + fig = ContributionsPlotter.plot_heatmap(contributions_dataset, top_n=3, show_prediction=False) + y_labels = list(fig.data[0].y) + + # Assert + assert len(y_labels) == 3 + assert set(y_labels) == {"feat_a", "feat_b", "feat_c"} + + +def test_heatmap_prediction_line_values(contributions_dataset: TimeSeriesDataset) -> None: + """Prediction = sum(feature contributions) + bias.""" + # Arrange + df = contributions_dataset.data + feature_cols = [c for c in df.columns if c not in {TARGET_COLUMN, BIAS_COLUMN}] + expected_prediction = df[feature_cols].sum(axis=1) + df[BIAS_COLUMN] + + # Act + fig = ContributionsPlotter.plot_heatmap(contributions_dataset, show_prediction=True) + prediction_trace = fig.data[0] + + # Assert + np.testing.assert_array_almost_equal(prediction_trace.y, expected_prediction.values) + + +def test_waterfall_starts_with_bias(contributions_dataset: TimeSeriesDataset) -> None: + # Act + fig = ContributionsPlotter.plot_waterfall(contributions_dataset) + wf = fig.data[0] + + # Assert + assert wf.x[0] == BIAS_COLUMN + assert wf.y[0] == pytest.approx(100.0) + assert wf.measure[0] == "absolute" + + +def test_waterfall_ends_with_prediction_total(contributions_dataset: TimeSeriesDataset) -> None: + # Arrange + df = contributions_dataset.data + feature_cols = [c for c in df.columns if c not in {TARGET_COLUMN, BIAS_COLUMN}] + expected_prediction = float(df[BIAS_COLUMN].iloc[0]) + float(df[feature_cols].iloc[0].sum()) + + # Act + fig = ContributionsPlotter.plot_waterfall(contributions_dataset, timestep=0) + wf = fig.data[0] + + # Assert + assert wf.x[-1] == "Prediction" + assert wf.measure[-1] == "total" + assert wf.y[-1] == pytest.approx(expected_prediction) + + +def test_waterfall_top_n_limits_bars(contributions_dataset: TimeSeriesDataset) -> None: + # Act + fig = ContributionsPlotter.plot_waterfall(contributions_dataset, top_n=2) + wf = fig.data[0] + names = list(wf.x) + + # Assert + assert len(names) == 5 # bias + 2 features + "other (...)" + Prediction + assert any("other" in n for n in names) + + +def test_waterfall_title_contains_timestamp(contributions_dataset: TimeSeriesDataset) -> None: + # Arrange + expected_ts = str(contributions_dataset.data.index[3]) + + # Act + fig = ContributionsPlotter.plot_waterfall(contributions_dataset, timestep=3) + + # Assert + assert expected_ts in fig.layout.title.text + + +def test_bar_excludes_target_and_bias(contributions_dataset: TimeSeriesDataset) -> None: + # Act + fig = ContributionsPlotter.plot_bar(contributions_dataset) + y_labels = list(fig.data[0].y) + + # Assert + assert TARGET_COLUMN not in y_labels + assert BIAS_COLUMN not in y_labels + + +def test_bar_top_n_limits_bars(contributions_dataset: TimeSeriesDataset) -> None: + # Act + fig = ContributionsPlotter.plot_bar(contributions_dataset, top_n=3) + y_labels = list(fig.data[0].y) + + # Assert + assert len(y_labels) == 3 + + +def test_bar_values_are_mean_absolute(contributions_dataset: TimeSeriesDataset) -> None: + # Arrange + df = contributions_dataset.data + feature_cols = [c for c in df.columns if c not in {TARGET_COLUMN, BIAS_COLUMN}] + expected = df[feature_cols].abs().mean().sort_values(ascending=True) # reversed for plotly + + # Act + fig = ContributionsPlotter.plot_bar(contributions_dataset, top_n=5) + bar = fig.data[0] + + # Assert + np.testing.assert_array_almost_equal(bar.x, expected.values) + + +def test_bar_most_important_feature_at_top(contributions_dataset: TimeSeriesDataset) -> None: + # Act + fig = ContributionsPlotter.plot_bar(contributions_dataset) + y_labels = list(fig.data[0].y) + + # Assert plotly renders bottom-to-top, so last y label is at the top + assert y_labels[-1] == "feat_a"