From 93dc1f3f6d1dca60ceca9898ccf0ba331cc54301 Mon Sep 17 00:00:00 2001 From: Gernot Maier Date: Thu, 14 May 2026 12:35:42 +0200 Subject: [PATCH 01/10] defensive hyperparameters --- src/eventdisplay_ml/hyper_parameters.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/eventdisplay_ml/hyper_parameters.py b/src/eventdisplay_ml/hyper_parameters.py index 0a24979..dc52ffe 100644 --- a/src/eventdisplay_ml/hyper_parameters.py +++ b/src/eventdisplay_ml/hyper_parameters.py @@ -33,13 +33,14 @@ "objective": "binary:logistic", "eval_metric": ["logloss", "auc"], "n_estimators": 5000, - "early_stopping_rounds": 50, - "max_depth": 7, - "learning_rate": 0.05, + "early_stopping_rounds": 100, + "max_depth": 4, + "learning_rate": 0.02, + "gamma": 0.2, "subsample": 0.8, - "colsample_bytree": 0.8, + "colsample_bytree": 0.6, "random_state": None, - "n_jobs": 48, + "n_jobs": 96, }, } } From 8e3d1187a5fc351df2bb0ea8094d2f71f792e574 Mon Sep 17 00:00:00 2001 From: Gernot Maier Date: Thu, 14 May 2026 17:19:17 +0200 Subject: [PATCH 02/10] allow to choose energy bin --- .../scripts/plot_classification_performance_metrics.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py b/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py index a46c11e..bdd8d74 100644 --- a/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py +++ b/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py @@ -142,10 +142,18 @@ def main(): parser = argparse.ArgumentParser(description="Plot TMVA and XGBoost metrics.") parser.add_argument("root_dir", help="Path to the TMVA BDT .root file") parser.add_argument("joblib_dir", help="Path to the XGB BDT .joblib file") + parser.add_argument( + "--energy-bin", + type=int, + choices=range(9), + default=None, + help="Plot only a single energy bin (0-8). 
If omitted, all bins are processed.", + ) args = parser.parse_args() # assume energy binning is identical in XGB and TMVA files. - for ebin in range(9): + energy_bins = [args.energy_bin] if args.energy_bin is not None else range(9) + for ebin in energy_bins: x_root, y_effs, y_effb = load_efficiency_tmva(args.root_dir, ebin) x_joblib, y_effs_xgb, y_effb_xgb = load_efficiency_xgb(args.joblib_dir, ebin) From c0d870f76a983d44f1d9f17530a0af772243a000 Mon Sep 17 00:00:00 2001 From: Gernot Maier Date: Thu, 14 May 2026 18:18:17 +0200 Subject: [PATCH 03/10] allow to use model_dir --- .../scripts/plot_training_evaluation.py | 109 +++++++++++++----- 1 file changed, 79 insertions(+), 30 deletions(-) diff --git a/src/eventdisplay_ml/scripts/plot_training_evaluation.py b/src/eventdisplay_ml/scripts/plot_training_evaluation.py index b9c9c1a..6b51138 100644 --- a/src/eventdisplay_ml/scripts/plot_training_evaluation.py +++ b/src/eventdisplay_ml/scripts/plot_training_evaluation.py @@ -119,57 +119,106 @@ def main(): "(stereo or classification)." ) ) - parser.add_argument( + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( "--model_file", - required=True, type=str, - help="Path to the trained model joblib file (e.g., dispdir_bdt.joblib).", + help="Path to a single trained model joblib file (e.g., dispdir_bdt.joblib).", + ) + group.add_argument( + "--model_dir", + type=str, + help="Directory containing multiple joblib model files. All *.joblib files will be processed.", ) parser.add_argument( "--output_file", type=str, default=None, - help="Path to save the output plot (PNG/PDF). If not provided, display interactively.", + help="Path to save the output plot (PNG/PDF). If not provided, display interactively. 
Only for single file mode.", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory to save all output plots (required if --model_dir is used).", ) args = parser.parse_args() - model_path = Path(args.model_file) - if not model_path.exists(): - raise FileNotFoundError(f"Model file not found: {model_path}") + if args.model_file: + model_path = Path(args.model_file) + if not model_path.exists(): + raise FileNotFoundError(f"Model file not found: {model_path}") - _logger.info(f"Loading model from: {model_path}") - model_configs = joblib.load(model_path) + _logger.info(f"Loading model from: {model_path}") + model_configs = joblib.load(model_path) - # Extract the XGBoost model and its evaluation results - if "models" not in model_configs: - raise ValueError("Model file does not contain 'models' key.") + # Extract the XGBoost model and its evaluation results + if "models" not in model_configs: + raise ValueError("Model file does not contain 'models' key.") - if "xgboost" not in model_configs["models"]: - raise ValueError("Model file does not contain 'xgboost' model.") + if "xgboost" not in model_configs["models"]: + raise ValueError("Model file does not contain 'xgboost' model.") - xgb_model = model_configs["models"]["xgboost"]["model"] + xgb_model = model_configs["models"]["xgboost"]["model"] - if not hasattr(xgb_model, "evals_result"): - raise AttributeError( - "XGBoost model does not have 'evals_result' method. " - "Model may not have been trained with eval_set parameter." - ) + if not hasattr(xgb_model, "evals_result"): + raise AttributeError( + "XGBoost model does not have 'evals_result' method. " + "Model may not have been trained with eval_set parameter." 
+ ) + + evals_result = xgb_model.evals_result() - evals_result = xgb_model.evals_result() + _logger.info(f"Model type: {type(xgb_model).__name__}") + _logger.info(f"Number of boosting rounds: {xgb_model.get_booster().num_boosted_rounds()}") - _logger.info(f"Model type: {type(xgb_model).__name__}") - _logger.info(f"Number of boosting rounds: {xgb_model.get_booster().num_boosted_rounds()}") + # Additional model info + if "features" in model_configs: + _logger.info(f"Number of features: {len(model_configs['features'])}") + if "targets" in model_configs: + _logger.info(f"Target variables: {model_configs['targets']}") - # Additional model info - if "features" in model_configs: - _logger.info(f"Number of features: {len(model_configs['features'])}") - if "targets" in model_configs: - _logger.info(f"Target variables: {model_configs['targets']}") + plot_training_curves(evals_result, args.output_file) + _logger.info("Plotting completed successfully.") + + elif args.model_dir: + model_dir = Path(args.model_dir) + if not model_dir.exists() or not model_dir.is_dir(): + raise FileNotFoundError(f"Model directory not found: {model_dir}") + + if not args.output_dir: + raise ValueError("--output_dir must be specified when using --model_dir.") + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + joblib_files = sorted(model_dir.glob("*.joblib")) + if not joblib_files: + raise FileNotFoundError(f"No joblib files found in directory: {model_dir}") + + for model_path in joblib_files: + _logger.info(f"Loading model from: {model_path}") + try: + model_configs = joblib.load(model_path) + except Exception as e: + _logger.error(f"Failed to load {model_path}: {e}") + continue + + if "models" not in model_configs or "xgboost" not in model_configs["models"]: + _logger.error(f"Skipping {model_path}: missing 'models/xgboost' key.") + continue + + xgb_model = model_configs["models"]["xgboost"].get("model") + if not hasattr(xgb_model, "evals_result"): + 
_logger.error(f"Skipping {model_path}: model missing 'evals_result'.") + continue - plot_training_curves(evals_result, args.output_file) + evals_result = xgb_model.evals_result() + output_file = output_dir / (model_path.stem + ".png") + plot_training_curves(evals_result, output_file) + _logger.info(f"Saved plot for {model_path.name} to {output_file}") - _logger.info("Plotting completed successfully.") + _logger.info("Batch plotting completed.") if __name__ == "__main__": From 069c20983634df942d3a4ce6251bbff5bb6c951e Mon Sep 17 00:00:00 2001 From: Gernot Maier Date: Thu, 14 May 2026 18:22:03 +0200 Subject: [PATCH 04/10] training output --- src/eventdisplay_ml/scripts/plot_training_evaluation.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/eventdisplay_ml/scripts/plot_training_evaluation.py b/src/eventdisplay_ml/scripts/plot_training_evaluation.py index 6b51138..e87cd46 100644 --- a/src/eventdisplay_ml/scripts/plot_training_evaluation.py +++ b/src/eventdisplay_ml/scripts/plot_training_evaluation.py @@ -179,7 +179,13 @@ def main(): if "targets" in model_configs: _logger.info(f"Target variables: {model_configs['targets']}") - plot_training_curves(evals_result, args.output_file) + output_file = args.output_file + if output_file is None: + # Save as training_evaluation_.png in current directory + output_file = f"training_evaluation_{model_path.stem}.png" + _logger.info(f"No --output_file given. 
Saving to {output_file}") + + plot_training_curves(evals_result, output_file) _logger.info("Plotting completed successfully.") elif args.model_dir: From cdebd9a2ba50f86376d201ba551efd5a1259b7a1 Mon Sep 17 00:00:00 2001 From: Gernot Maier Date: Thu, 14 May 2026 18:24:27 +0200 Subject: [PATCH 05/10] output file stem --- .../scripts/plot_training_evaluation.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/eventdisplay_ml/scripts/plot_training_evaluation.py b/src/eventdisplay_ml/scripts/plot_training_evaluation.py index e87cd46..af3ca0c 100644 --- a/src/eventdisplay_ml/scripts/plot_training_evaluation.py +++ b/src/eventdisplay_ml/scripts/plot_training_evaluation.py @@ -144,7 +144,6 @@ def main(): ) args = parser.parse_args() - if args.model_file: model_path = Path(args.model_file) if not model_path.exists(): @@ -179,13 +178,7 @@ def main(): if "targets" in model_configs: _logger.info(f"Target variables: {model_configs['targets']}") - output_file = args.output_file - if output_file is None: - # Save as training_evaluation_.png in current directory - output_file = f"training_evaluation_{model_path.stem}.png" - _logger.info(f"No --output_file given. 
Saving to {output_file}") - - plot_training_curves(evals_result, output_file) + plot_training_curves(evals_result, args.output_file) _logger.info("Plotting completed successfully.") elif args.model_dir: @@ -220,7 +213,7 @@ def main(): continue evals_result = xgb_model.evals_result() - output_file = output_dir / (model_path.stem + ".png") + output_file = output_dir / (f"training_evaluation_{model_path.stem}.png") plot_training_curves(evals_result, output_file) _logger.info(f"Saved plot for {model_path.name} to {output_file}") From 58aed457ba8921fc8052a8a0cb1278351fe1f037 Mon Sep 17 00:00:00 2001 From: Gernot Maier Date: Thu, 14 May 2026 18:34:30 +0200 Subject: [PATCH 06/10] joblib --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index cbf08cd..cff930d 100644 --- a/README.md +++ b/README.md @@ -315,6 +315,14 @@ eventdisplay-ml-plot-training-evaluation \ --output_file diagnostics/training_curves.png ``` +or for all joblib files in a directory: + +```bash +eventdisplay-ml-plot-training-evaluation \ + --model_dir models/ \ + --output_dir diagnostics/ +``` + Output: - Figure with one panel per tracked metric (for example `rmse`), showing training and test curves. From 3c3f7293bc4dd65dd5fec53a7e97d4b2b62f7057 Mon Sep 17 00:00:00 2001 From: Gernot Maier Date: Thu, 14 May 2026 18:35:30 +0200 Subject: [PATCH 07/10] changelog --- docs/changes/60.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/changes/60.feature.md diff --git a/docs/changes/60.feature.md b/docs/changes/60.feature.md new file mode 100644 index 0000000..a1a76a8 --- /dev/null +++ b/docs/changes/60.feature.md @@ -0,0 +1 @@ +Improve classification hyperparameters with focus on robustness. 
From 4fc27e289616664b4c6a7a6cb8bd22da23375736 Mon Sep 17 00:00:00 2001 From: Gernot Maier Date: Thu, 14 May 2026 18:45:36 +0200 Subject: [PATCH 08/10] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- docs/changes/60.feature.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changes/60.feature.md b/docs/changes/60.feature.md index a1a76a8..9aa0782 100644 --- a/docs/changes/60.feature.md +++ b/docs/changes/60.feature.md @@ -1 +1 @@ -Improve classification hyperparameters with focus on robustness. +Improve classification hyperparameters with a focus on robustness, and add user-facing plotting CLI options for selecting `--model_dir`/`--output_dir` and `--energy-bin`. From 1365afb4b5dba5fa1226348e5068cf117bfe4849 Mon Sep 17 00:00:00 2001 From: Gernot Maier Date: Thu, 14 May 2026 18:50:00 +0200 Subject: [PATCH 09/10] add joblib gz --- src/eventdisplay_ml/diagnostic_utils.py | 3 +- src/eventdisplay_ml/models.py | 20 +++++++++--- .../scripts/diagnostic_shap_summary.py | 4 ++- .../scripts/optimize_classification.py | 9 ++++-- ...plot_classification_performance_metrics.py | 5 ++- .../scripts/plot_training_evaluation.py | 32 +++++++++++++++---- src/eventdisplay_ml/utils.py | 21 +++++++++++- 7 files changed, 78 insertions(+), 16 deletions(-) diff --git a/src/eventdisplay_ml/diagnostic_utils.py b/src/eventdisplay_ml/diagnostic_utils.py index aae271e..315d175 100644 --- a/src/eventdisplay_ml/diagnostic_utils.py +++ b/src/eventdisplay_ml/diagnostic_utils.py @@ -14,6 +14,7 @@ from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split +from eventdisplay_ml import utils from eventdisplay_ml.data_processing import load_training_data _logger = logging.getLogger(__name__) @@ -21,7 +22,7 @@ def _load_model_cfg(model_file): """Load full model dictionary and the first model configuration entry.""" - model_dict = joblib.load(model_file) + 
model_dict = joblib.load(utils.resolve_joblib_path(model_file)) models = model_dict.get("models", {}) model_cfg = next(iter(models.values())) if models else None return model_dict, model_cfg diff --git a/src/eventdisplay_ml/models.py b/src/eventdisplay_ml/models.py index 1578d7e..059468b 100644 --- a/src/eventdisplay_ml/models.py +++ b/src/eventdisplay_ml/models.py @@ -97,12 +97,24 @@ def load_classification_models(model_prefix, model_name): models = {} par = {} - pattern = f"{model_prefix.name}_ebin*.joblib" - files = sorted(model_dir_path.glob(pattern)) + pattern = re.compile(rf"^{re.escape(model_prefix.name)}_ebin(\d+)\.joblib(?:\.gz)?$") + matched_files = [ + file for file in model_dir_path.iterdir() if file.is_file() and pattern.match(file.name) + ] + files_by_bin = {} + for file in matched_files: + match = pattern.match(file.name) + if not match: + continue + e_bin = int(match.group(1)) + existing = files_by_bin.get(e_bin) + if existing is None or file.name.endswith(".joblib.gz"): + files_by_bin[e_bin] = file + files = [files_by_bin[e_bin] for e_bin in sorted(files_by_bin)] _logger.info(f"Loading classification models from {files}") for file in files: - match = re.search(r"_ebin(\d+)\.joblib$", file.name) + match = pattern.match(file.name) if not match: _logger.warning(f"Could not extract energy bin from filename: {file.name}") continue @@ -206,7 +218,7 @@ def load_regression_models(model_prefix, model_name): dict Model dictionary. 
""" - model_path = Path(model_prefix).with_suffix(".joblib") + model_path = utils.resolve_joblib_path(model_prefix) _logger.info(f"Loading regression model: {model_path}") model_data = joblib.load(model_path) diff --git a/src/eventdisplay_ml/scripts/diagnostic_shap_summary.py b/src/eventdisplay_ml/scripts/diagnostic_shap_summary.py index 67c34e3..e879d64 100644 --- a/src/eventdisplay_ml/scripts/diagnostic_shap_summary.py +++ b/src/eventdisplay_ml/scripts/diagnostic_shap_summary.py @@ -22,13 +22,15 @@ import numpy as np import pandas as pd +from eventdisplay_ml import utils + _logger = logging.getLogger(__name__) def load_model_config(model_file): """Load model configuration with cached feature importances.""" _logger.info(f"Loading model from {model_file}") - model_dict = joblib.load(model_file) + model_dict = joblib.load(utils.resolve_joblib_path(model_file)) models = model_dict.get("models") if not isinstance(models, dict) or not models: diff --git a/src/eventdisplay_ml/scripts/optimize_classification.py b/src/eventdisplay_ml/scripts/optimize_classification.py index 1c75da9..f2dbe31 100644 --- a/src/eventdisplay_ml/scripts/optimize_classification.py +++ b/src/eventdisplay_ml/scripts/optimize_classification.py @@ -33,6 +33,8 @@ from astropy.table import Table from scipy.interpolate import LinearNDInterpolator, RegularGridInterpolator +from eventdisplay_ml import utils + logging.basicConfig(level=logging.INFO) _logger = logging.getLogger(__name__) @@ -106,7 +108,8 @@ def _load_multi_bin_roc(joblib_paths): for path in joblib_paths: try: - data = joblib.load(path) + resolved_path = utils.resolve_joblib_path(path) + data = joblib.load(resolved_path) ebins = data["energy_bins_log10_tev"] e_min = ebins["E_min"] e_max = ebins["E_max"] @@ -457,7 +460,9 @@ def main(): """CLI entry point.""" parser = argparse.ArgumentParser(description="Optimize classification cuts.") parser.add_argument("input_root", help="ROOT file with rate surfaces") - parser.add_argument("roc_files", 
nargs="+", help="List of ebin*.joblib files") + parser.add_argument( + "roc_files", nargs="+", help="List of ebin* model files (.joblib.gz preferred)." + ) parser.add_argument("source_strength", type=float, help="Fraction of Crab (e.g. 0.1 for 10%%)") parser.add_argument( "--source-index", diff --git a/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py b/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py index bdd8d74..23cd5ae 100644 --- a/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py +++ b/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py @@ -21,6 +21,8 @@ import numpy as np import uproot +from eventdisplay_ml import utils + logging.basicConfig(level=logging.INFO) _logger = logging.getLogger(__name__) @@ -127,7 +129,8 @@ def load_efficiency_tmva(path, ebin, zebin=0): def load_efficiency_xgb(path, ebin): """Load efficiencies from XGB files.""" - data_joblib = joblib.load(Path(path) / f"gammahadron_bdt_ebin{ebin}.joblib") + model_file = utils.resolve_joblib_path(Path(path) / f"gammahadron_bdt_ebin{ebin}") + data_joblib = joblib.load(model_file) df_xgboost = data_joblib["models"]["xgboost"]["efficiency"] x_joblib = df_xgboost["threshold"] diff --git a/src/eventdisplay_ml/scripts/plot_training_evaluation.py b/src/eventdisplay_ml/scripts/plot_training_evaluation.py index af3ca0c..16ec84c 100644 --- a/src/eventdisplay_ml/scripts/plot_training_evaluation.py +++ b/src/eventdisplay_ml/scripts/plot_training_evaluation.py @@ -26,10 +26,18 @@ import matplotlib.pyplot as plt import numpy as np +from eventdisplay_ml import utils + logging.basicConfig(level=logging.INFO) _logger = logging.getLogger(__name__) +def _joblib_basename(model_path): + """Return basename without .joblib/.joblib.gz suffixes.""" + name = Path(model_path).name + return name.removesuffix(".joblib.gz").removesuffix(".joblib") + + def plot_training_curves(evals_result, output_file=None): """ Plot training and validation 
curves from XGBoost evaluation results. @@ -145,9 +153,7 @@ def main(): args = parser.parse_args() if args.model_file: - model_path = Path(args.model_file) - if not model_path.exists(): - raise FileNotFoundError(f"Model file not found: {model_path}") + model_path = utils.resolve_joblib_path(args.model_file) _logger.info(f"Loading model from: {model_path}") model_configs = joblib.load(model_path) @@ -178,7 +184,12 @@ def main(): if "targets" in model_configs: _logger.info(f"Target variables: {model_configs['targets']}") - plot_training_curves(evals_result, args.output_file) + output_file = args.output_file + if output_file is None: + output_file = f"training_evaluation_{_joblib_basename(model_path)}.png" + _logger.info(f"No --output_file given. Saving to {output_file}") + + plot_training_curves(evals_result, output_file) _logger.info("Plotting completed successfully.") elif args.model_dir: @@ -191,7 +202,16 @@ def main(): output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - joblib_files = sorted(model_dir.glob("*.joblib")) + discovered_files = sorted( + set(model_dir.glob("*.joblib")).union(model_dir.glob("*.joblib.gz")) + ) + files_by_name = {} + for model_path in discovered_files: + key = _joblib_basename(model_path) + existing = files_by_name.get(key) + if existing is None or model_path.name.endswith(".joblib.gz"): + files_by_name[key] = model_path + joblib_files = [files_by_name[name] for name in sorted(files_by_name)] if not joblib_files: raise FileNotFoundError(f"No joblib files found in directory: {model_dir}") @@ -213,7 +233,7 @@ def main(): continue evals_result = xgb_model.evals_result() - output_file = output_dir / (f"training_evaluation_{model_path.stem}.png") + output_file = output_dir / f"training_evaluation_{_joblib_basename(model_path)}.png" plot_training_curves(evals_result, output_file) _logger.info(f"Saved plot for {model_path.name} to {output_file}") diff --git a/src/eventdisplay_ml/utils.py 
b/src/eventdisplay_ml/utils.py index 39770ef..3489c3f 100644 --- a/src/eventdisplay_ml/utils.py +++ b/src/eventdisplay_ml/utils.py @@ -7,6 +7,25 @@ _logger = logging.getLogger(__name__) +def resolve_joblib_path(path_or_prefix): + """Resolve model path supporting .joblib.gz (preferred) and .joblib.""" + path = Path(path_or_prefix) + path_str = str(path) + + if path_str.endswith(".joblib.gz"): + candidates = [path, Path(path_str.removesuffix(".gz"))] + elif path_str.endswith(".joblib"): + candidates = [Path(f"{path_str}.gz"), path] + else: + candidates = [Path(f"{path_str}.joblib.gz"), Path(f"{path_str}.joblib"), path] + + for candidate in candidates: + if candidate.exists() and candidate.is_file(): + return candidate + + raise FileNotFoundError(f"Could not resolve model file from '{path_or_prefix}'.") + + def read_input_file_list(input_file_list): """ Read a list of input files from a text file. @@ -129,6 +148,6 @@ def output_file_name(model_prefix, n_tel=None, energy_bin_number=None): filename = f"{model_prefix!s}_ntel{n_tel}" if energy_bin_number is not None: filename += f"_ebin{energy_bin_number}" - filename += ".joblib" + filename += ".joblib.gz" _logger.info(f"Output filename: {filename}") return filename From cdc890d146288031f3e94ae7a49ad6cfb331bc8b Mon Sep 17 00:00:00 2001 From: Gernot Maier Date: Thu, 14 May 2026 18:53:14 +0200 Subject: [PATCH 10/10] copilot --- README.md | 5 --- .../scripts/plot_training_evaluation.py | 39 +++++++++++-------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index cff930d..8c7b784 100644 --- a/README.md +++ b/README.md @@ -302,11 +302,6 @@ Analysis type: stereo reconstruction or gamma/hadron separation, depending on th - Plot XGBoost training vs validation metric curves - Useful for checking convergence and overfitting behavior -Required inputs: - -- `--model_file`: trained model `.joblib` containing an XGBoost model -- `--output_file`: output image path (optional; if omitted, plot is 
shown interactively) - Run: ```bash diff --git a/src/eventdisplay_ml/scripts/plot_training_evaluation.py b/src/eventdisplay_ml/scripts/plot_training_evaluation.py index 16ec84c..3f2986a 100644 --- a/src/eventdisplay_ml/scripts/plot_training_evaluation.py +++ b/src/eventdisplay_ml/scripts/plot_training_evaluation.py @@ -136,13 +136,19 @@ def main(): group.add_argument( "--model_dir", type=str, - help="Directory containing multiple joblib model files. All *.joblib files will be processed.", + help=( + "Directory containing multiple joblib model files. " + "All *.joblib and *.joblib.gz files will be processed." + ), ) parser.add_argument( "--output_file", type=str, default=None, - help="Path to save the output plot (PNG/PDF). If not provided, display interactively. Only for single file mode.", + help=( + "Path to save the output plot (PNG/PDF). Defaults to training_evaluation_<model>.png. " + "Only for single file mode." + ), ) parser.add_argument( "--output_dir", @@ -219,24 +225,23 @@ def main(): _logger.info(f"Loading model from: {model_path}") try: model_configs = joblib.load(model_path) + if "models" not in model_configs or "xgboost" not in model_configs["models"]: + _logger.error(f"Skipping {model_path}: missing 'models/xgboost' key.") + continue + + xgb_model = model_configs["models"]["xgboost"].get("model") + if not hasattr(xgb_model, "evals_result"): + _logger.error(f"Skipping {model_path}: model missing 'evals_result'.") + continue + + evals_result = xgb_model.evals_result() + output_file = output_dir / f"training_evaluation_{_joblib_basename(model_path)}.png" + plot_training_curves(evals_result, output_file) + _logger.info(f"Saved plot for {model_path.name} to {output_file}") except Exception as e: - _logger.error(f"Failed to load {model_path}: {e}") + _logger.exception(f"Skipping {model_path}: failed to process model ({e})") continue - if "models" not in model_configs or "xgboost" not in model_configs["models"]: - _logger.error(f"Skipping {model_path}: missing 'models/xgboost' 
key.") - continue - - xgb_model = model_configs["models"]["xgboost"].get("model") - if not hasattr(xgb_model, "evals_result"): - _logger.error(f"Skipping {model_path}: model missing 'evals_result'.") - continue - - evals_result = xgb_model.evals_result() - output_file = output_dir / f"training_evaluation_{_joblib_basename(model_path)}.png" - plot_training_curves(evals_result, output_file) - _logger.info(f"Saved plot for {model_path.name} to {output_file}") - _logger.info("Batch plotting completed.")