Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
)
from openstef_models.utils.data_split import DataSplitter
from openstef_models.utils.feature_selection import Exclude, FeatureSelection, Include
from openstef_models.workflows.callbacks import ModelPerformanceCallback
from openstef_models.workflows.custom_forecasting_workflow import (
CustomForecastingWorkflow,
ForecastingCallback,
Expand Down Expand Up @@ -300,6 +301,17 @@ class EnsembleForecastingWorkflowConfig(BaseConfig):
description="Penalty to apply to the old model's metric to bias selection towards newer models.",
)

model_performance_callback_enabled: bool = Field(
default=False,
description=(
"Whether to enable the ModelPerformanceCallback that evaluates model performance at the end of fitting."
),
)
model_performance_callback_metric_threshold: tuple[QuantileOrGlobal, str, MetricDirection, float] = Field(
default=(Q(0.5), "R2", "higher_is_better", 0.0),
description=("Metric to monitor for model performance threshold at the end of fitting. "),
)

verbosity: Literal[0, 1, 2, 3, True] = Field(
default=0, description="Verbosity level. 0=silent, 1=warning, 2=info, 3=debug"
)
Expand Down Expand Up @@ -594,6 +606,17 @@ def create_ensemble_forecasting_workflow(config: EnsembleForecastingWorkflowConf
)
)

if config.model_performance_callback_enabled:
quantile, metric_name, metric_direction, threshold = config.model_performance_callback_metric_threshold
callbacks.append(
ModelPerformanceCallback(
metric_name=metric_name,
threshold=threshold,
metric_direction=metric_direction,
quantile=quantile,
)
)

return CustomForecastingWorkflow(
model=EnsembleForecastingModel(
preprocessing=common_preprocessing,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from openstef_models.integrations.mlflow import MLFlowStorage, MLFlowStorageCallback
from openstef_models.mixins import ModelIdentifier
from openstef_models.models import ForecastingModel
from openstef_models.models.forecasting.constant_quantile_forecaster import ConstantQuantileForecaster
from openstef_models.models.forecasting.flatliner_forecaster import FlatlinerForecaster
from openstef_models.models.forecasting.gblinear_forecaster import GBLinearForecaster, GBLinearHyperParams
from openstef_models.models.forecasting.lgbm_forecaster import LGBMForecaster, LGBMHyperParams
Expand Down Expand Up @@ -65,6 +66,7 @@
)
from openstef_models.utils.data_split import DataSplitter
from openstef_models.utils.feature_selection import Exclude, FeatureSelection, Include
from openstef_models.workflows.callbacks import ModelPerformanceCallback
from openstef_models.workflows.custom_forecasting_workflow import (
CustomForecastingWorkflow,
ForecastingCallback,
Expand Down Expand Up @@ -116,7 +118,7 @@ class ForecastingWorkflowConfig(BaseConfig): # PredictionJob
)

# Model configuration
model: Literal["xgboost", "gblinear", "flatliner", "median", "lgbm", "lgbmlinear"] = Field(
model: Literal["xgboost", "gblinear", "flatliner", "median", "constant_quantile", "lgbm", "lgbmlinear"] = Field(
description="Type of forecasting model to use."
)
quantiles: list[Quantile] = Field(
Expand Down Expand Up @@ -287,6 +289,17 @@ class ForecastingWorkflowConfig(BaseConfig): # PredictionJob
description="Penalty to apply to the old model's metric to bias selection towards newer models.",
)

model_performance_callback_enabled: bool = Field(
default=False,
description=(
"Whether to enable the ModelPerformanceCallback that evaluates model performance at the end of fitting."
),
)
model_performance_callback_metric_threshold: tuple[QuantileOrGlobal, str, MetricDirection, float] = Field(
default=(Q(0.5), "R2", "higher_is_better", 0.0),
description=("Metric to monitor for model performance threshold at the end of fitting. "),
)

verbosity: Literal[0, 1, 2, 3, True] = Field(
default=0, description="Verbosity level. 0=silent, 1=warning, 2=info, 3=debug"
)
Expand Down Expand Up @@ -492,6 +505,13 @@ def create_forecasting_workflow(
horizons=config.horizons,
)
postprocessing = []
elif config.model == "constant_quantile":
preprocessing = []
forecaster = ConstantQuantileForecaster(
quantiles=config.quantiles,
horizons=config.horizons,
)
postprocessing = []
elif config.model == "flatliner":
preprocessing = []
forecaster = FlatlinerForecaster(
Expand Down Expand Up @@ -530,6 +550,17 @@ def create_forecasting_workflow(
)
)

if config.model_performance_callback_enabled:
quantile, metric_name, metric_direction, threshold = config.model_performance_callback_metric_threshold
callbacks.append(
ModelPerformanceCallback(
metric_name=metric_name,
threshold=threshold,
metric_direction=metric_direction,
quantile=quantile,
)
)

return CustomForecastingWorkflow(
model=ForecastingModel(
preprocessing=TransformPipeline(transforms=preprocessing),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,9 @@
"""Workflow callbacks for data capture, persistence, and debugging."""

from openstef_models.workflows.callbacks.data_save import DataSaveCallback
from openstef_models.workflows.callbacks.model_performance_callback import ModelPerformanceCallback

__all__ = ["DataSaveCallback"]
__all__ = [
"DataSaveCallback",
"ModelPerformanceCallback",
]
Loading