diff --git a/README.md b/README.md index cbf08cd..8c7b784 100644 --- a/README.md +++ b/README.md @@ -302,11 +302,6 @@ Analysis type: stereo reconstruction or gamma/hadron separation, depending on th - Plot XGBoost training vs validation metric curves - Useful for checking convergence and overfitting behavior -Required inputs: - -- `--model_file`: trained model `.joblib` containing an XGBoost model -- `--output_file`: output image path (optional; if omitted, plot is shown interactively) - Run: ```bash @@ -315,6 +310,14 @@ eventdisplay-ml-plot-training-evaluation \ --output_file diagnostics/training_curves.png ``` +or for all joblib files in a directory: + +```bash +eventdisplay-ml-plot-training-evaluation \ + --model_dir models/ \ + --output_dir diagnostics/ +``` + Output: - Figure with one panel per tracked metric (for example `rmse`), showing training and test curves. diff --git a/docs/changes/60.feature.md b/docs/changes/60.feature.md new file mode 100644 index 0000000..9aa0782 --- /dev/null +++ b/docs/changes/60.feature.md @@ -0,0 +1 @@ +Improve classification hyperparameters with a focus on robustness, and add user-facing plotting CLI options for selecting `--model_dir`/`--output_dir` and `--energy-bin`. diff --git a/src/eventdisplay_ml/diagnostic_utils.py b/src/eventdisplay_ml/diagnostic_utils.py index aae271e..315d175 100644 --- a/src/eventdisplay_ml/diagnostic_utils.py +++ b/src/eventdisplay_ml/diagnostic_utils.py @@ -14,6 +14,7 @@ from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split +from eventdisplay_ml import utils from eventdisplay_ml.data_processing import load_training_data _logger = logging.getLogger(__name__) @@ -21,7 +22,7 @@ def _load_model_cfg(model_file): """Load full model dictionary and the first model configuration entry.""" - model_dict = joblib.load(model_file) + model_dict = joblib.load(utils.resolve_joblib_path(model_file)) models = model_dict.get("models", {}) model_cfg = next(iter(models.values())) if models else None return model_dict, model_cfg diff --git a/src/eventdisplay_ml/hyper_parameters.py b/src/eventdisplay_ml/hyper_parameters.py index 0a24979..dc52ffe 100644 --- a/src/eventdisplay_ml/hyper_parameters.py +++ b/src/eventdisplay_ml/hyper_parameters.py @@ -33,13 +33,14 @@ "objective": "binary:logistic", "eval_metric": ["logloss", "auc"], "n_estimators": 5000, - "early_stopping_rounds": 50, - "max_depth": 7, - "learning_rate": 0.05, + "early_stopping_rounds": 100, + "max_depth": 4, + "learning_rate": 0.02, + "gamma": 0.2, "subsample": 0.8, - "colsample_bytree": 0.8, + "colsample_bytree": 0.6, "random_state": None, - "n_jobs": 48, + "n_jobs": 96, }, } } diff --git a/src/eventdisplay_ml/models.py b/src/eventdisplay_ml/models.py index 1578d7e..059468b 100644 --- a/src/eventdisplay_ml/models.py +++ b/src/eventdisplay_ml/models.py @@ -97,12 +97,24 @@ def load_classification_models(model_prefix, model_name): models = {} par = {} - pattern = f"{model_prefix.name}_ebin*.joblib" - files = sorted(model_dir_path.glob(pattern)) + pattern = re.compile(rf"^{re.escape(model_prefix.name)}_ebin(\d+)\.joblib(?:\.gz)?$") + matched_files = [ + file for file in model_dir_path.iterdir() if file.is_file() and pattern.match(file.name) + ] + files_by_bin = {} + for file in matched_files: + match = pattern.match(file.name) + if not match: + continue + e_bin = int(match.group(1)) + existing = files_by_bin.get(e_bin) + if existing is None or file.name.endswith(".joblib.gz"): + files_by_bin[e_bin] = file + files = [files_by_bin[e_bin] for e_bin in sorted(files_by_bin)] _logger.info(f"Loading classification models from {files}") for file in files: - match = re.search(r"_ebin(\d+)\.joblib$", file.name) + match = pattern.match(file.name) if not match: _logger.warning(f"Could not extract energy bin from filename: {file.name}") continue @@ -206,7 +218,7 @@ def load_regression_models(model_prefix, model_name): dict Model dictionary. """ - model_path = Path(model_prefix).with_suffix(".joblib") + model_path = utils.resolve_joblib_path(model_prefix) _logger.info(f"Loading regression model: {model_path}") model_data = joblib.load(model_path) diff --git a/src/eventdisplay_ml/scripts/diagnostic_shap_summary.py b/src/eventdisplay_ml/scripts/diagnostic_shap_summary.py index 67c34e3..e879d64 100644 --- a/src/eventdisplay_ml/scripts/diagnostic_shap_summary.py +++ b/src/eventdisplay_ml/scripts/diagnostic_shap_summary.py @@ -22,13 +22,15 @@ import numpy as np import pandas as pd +from eventdisplay_ml import utils + _logger = logging.getLogger(__name__) def load_model_config(model_file): """Load model configuration with cached feature importances.""" _logger.info(f"Loading model from {model_file}") - model_dict = joblib.load(model_file) + model_dict = joblib.load(utils.resolve_joblib_path(model_file)) models = model_dict.get("models") if not isinstance(models, dict) or not models: diff --git a/src/eventdisplay_ml/scripts/optimize_classification.py b/src/eventdisplay_ml/scripts/optimize_classification.py index 1c75da9..f2dbe31 100644 --- a/src/eventdisplay_ml/scripts/optimize_classification.py +++ b/src/eventdisplay_ml/scripts/optimize_classification.py @@ -33,6 +33,8 @@ from astropy.table import Table from scipy.interpolate import LinearNDInterpolator, RegularGridInterpolator +from eventdisplay_ml import utils + logging.basicConfig(level=logging.INFO) _logger = logging.getLogger(__name__) @@ -106,7 +108,8 @@ def _load_multi_bin_roc(joblib_paths): for path in joblib_paths: try: - data = joblib.load(path) + resolved_path = utils.resolve_joblib_path(path) + data = joblib.load(resolved_path) ebins = data["energy_bins_log10_tev"] e_min = ebins["E_min"] e_max = ebins["E_max"] @@ -457,7 +460,9 @@ def main(): """CLI entry point.""" parser = argparse.ArgumentParser(description="Optimize classification cuts.") parser.add_argument("input_root", help="ROOT file with rate surfaces") - parser.add_argument("roc_files", nargs="+", help="List of ebin*.joblib files") + parser.add_argument( + "roc_files", nargs="+", help="List of ebin* model files (.joblib.gz preferred)." + ) parser.add_argument("source_strength", type=float, help="Fraction of Crab (e.g. 0.1 for 10%%)") parser.add_argument( "--source-index", diff --git a/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py b/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py index a46c11e..23cd5ae 100644 --- a/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py +++ b/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py @@ -21,6 +21,8 @@ import numpy as np import uproot +from eventdisplay_ml import utils + logging.basicConfig(level=logging.INFO) _logger = logging.getLogger(__name__) @@ -127,7 +129,8 @@ def load_efficiency_tmva(path, ebin, zebin=0): def load_efficiency_xgb(path, ebin): """Load efficiencies from XGB files.""" - data_joblib = joblib.load(Path(path) / f"gammahadron_bdt_ebin{ebin}.joblib") + model_file = utils.resolve_joblib_path(Path(path) / f"gammahadron_bdt_ebin{ebin}") + data_joblib = joblib.load(model_file) df_xgboost = data_joblib["models"]["xgboost"]["efficiency"] x_joblib = df_xgboost["threshold"] @@ -142,10 +145,18 @@ def main(): parser = argparse.ArgumentParser(description="Plot TMVA and XGBoost metrics.") parser.add_argument("root_dir", help="Path to the TMVA BDT .root file") parser.add_argument("joblib_dir", help="Path to the XGB BDT .joblib file") + parser.add_argument( + "--energy-bin", + type=int, + choices=range(9), + default=None, + help="Plot only a single energy bin (0-8). If omitted, all bins are processed.", + ) args = parser.parse_args() # assume energy binning is identical in XGB and TMVA files. - for ebin in range(9): + energy_bins = [args.energy_bin] if args.energy_bin is not None else range(9) + for ebin in energy_bins: x_root, y_effs, y_effb = load_efficiency_tmva(args.root_dir, ebin) x_joblib, y_effs_xgb, y_effb_xgb = load_efficiency_xgb(args.joblib_dir, ebin) diff --git a/src/eventdisplay_ml/scripts/plot_training_evaluation.py b/src/eventdisplay_ml/scripts/plot_training_evaluation.py index b9c9c1a..3f2986a 100644 --- a/src/eventdisplay_ml/scripts/plot_training_evaluation.py +++ b/src/eventdisplay_ml/scripts/plot_training_evaluation.py @@ -26,10 +26,18 @@ import matplotlib.pyplot as plt import numpy as np +from eventdisplay_ml import utils + logging.basicConfig(level=logging.INFO) _logger = logging.getLogger(__name__) +def _joblib_basename(model_path): + """Return basename without .joblib/.joblib.gz suffixes.""" + name = Path(model_path).name + return name.removesuffix(".joblib.gz").removesuffix(".joblib") + + def plot_training_curves(evals_result, output_file=None): """ Plot training and validation curves from XGBoost evaluation results. @@ -119,57 +127,122 @@ def main(): "(stereo or classification)." ) ) - parser.add_argument( + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( "--model_file", - required=True, type=str, - help="Path to the trained model joblib file (e.g., dispdir_bdt.joblib).", + help="Path to a single trained model joblib file (e.g., dispdir_bdt.joblib).", + ) + group.add_argument( + "--model_dir", + type=str, + help=( + "Directory containing multiple joblib model files. " + "All *.joblib files will be processed.", + ), ) parser.add_argument( "--output_file", type=str, default=None, - help="Path to save the output plot (PNG/PDF). If not provided, display interactively.", + help=( + "Path to save the output plot (PNG/PDF). If not provided, display interactively. " + "Only for single file mode." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory to save all output plots (required if --model_dir is used).", ) args = parser.parse_args() + if args.model_file: + model_path = utils.resolve_joblib_path(args.model_file) - model_path = Path(args.model_file) - if not model_path.exists(): - raise FileNotFoundError(f"Model file not found: {model_path}") + _logger.info(f"Loading model from: {model_path}") + model_configs = joblib.load(model_path) - _logger.info(f"Loading model from: {model_path}") - model_configs = joblib.load(model_path) + # Extract the XGBoost model and its evaluation results + if "models" not in model_configs: + raise ValueError("Model file does not contain 'models' key.") - # Extract the XGBoost model and its evaluation results - if "models" not in model_configs: - raise ValueError("Model file does not contain 'models' key.") + if "xgboost" not in model_configs["models"]: + raise ValueError("Model file does not contain 'xgboost' model.") - if "xgboost" not in model_configs["models"]: - raise ValueError("Model file does not contain 'xgboost' model.") + xgb_model = model_configs["models"]["xgboost"]["model"] - xgb_model = model_configs["models"]["xgboost"]["model"] + if not hasattr(xgb_model, "evals_result"): + raise AttributeError( + "XGBoost model does not have 'evals_result' method. " + "Model may not have been trained with eval_set parameter." + ) - if not hasattr(xgb_model, "evals_result"): - raise AttributeError( - "XGBoost model does not have 'evals_result' method. " - "Model may not have been trained with eval_set parameter." - ) + evals_result = xgb_model.evals_result() + + _logger.info(f"Model type: {type(xgb_model).__name__}") + _logger.info(f"Number of boosting rounds: {xgb_model.get_booster().num_boosted_rounds()}") - evals_result = xgb_model.evals_result() + # Additional model info + if "features" in model_configs: + _logger.info(f"Number of features: {len(model_configs['features'])}") + if "targets" in model_configs: + _logger.info(f"Target variables: {model_configs['targets']}") - _logger.info(f"Model type: {type(xgb_model).__name__}") - _logger.info(f"Number of boosting rounds: {xgb_model.get_booster().num_boosted_rounds()}") + output_file = args.output_file + if output_file is None: + output_file = f"training_evaluation_{_joblib_basename(model_path)}.png" + _logger.info(f"No --output_file given. Saving to {output_file}") - # Additional model info - if "features" in model_configs: - _logger.info(f"Number of features: {len(model_configs['features'])}") - if "targets" in model_configs: - _logger.info(f"Target variables: {model_configs['targets']}") + plot_training_curves(evals_result, output_file) + _logger.info("Plotting completed successfully.") - plot_training_curves(evals_result, args.output_file) + elif args.model_dir: + model_dir = Path(args.model_dir) + if not model_dir.exists() or not model_dir.is_dir(): + raise FileNotFoundError(f"Model directory not found: {model_dir}") + + if not args.output_dir: + raise ValueError("--output_dir must be specified when using --model_dir.") + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + discovered_files = sorted( + set(model_dir.glob("*.joblib")).union(model_dir.glob("*.joblib.gz")) + ) + files_by_name = {} + for model_path in discovered_files: + key = _joblib_basename(model_path) + existing = files_by_name.get(key) + if existing is None or model_path.name.endswith(".joblib.gz"): + files_by_name[key] = model_path + joblib_files = [files_by_name[name] for name in sorted(files_by_name)] + if not joblib_files: + raise FileNotFoundError(f"No joblib files found in directory: {model_dir}") + + for model_path in joblib_files: + _logger.info(f"Loading model from: {model_path}") + try: + model_configs = joblib.load(model_path) + if "models" not in model_configs or "xgboost" not in model_configs["models"]: + _logger.error(f"Skipping {model_path}: missing 'models/xgboost' key.") + continue + + xgb_model = model_configs["models"]["xgboost"].get("model") + if not hasattr(xgb_model, "evals_result"): + _logger.error(f"Skipping {model_path}: model missing 'evals_result'.") + continue + + evals_result = xgb_model.evals_result() + output_file = output_dir / f"training_evaluation_{_joblib_basename(model_path)}.png" + plot_training_curves(evals_result, output_file) + _logger.info(f"Saved plot for {model_path.name} to {output_file}") + except Exception as e: + _logger.exception(f"Skipping {model_path}: failed to process model ({e})") + continue - _logger.info("Plotting completed successfully.") + _logger.info("Batch plotting completed.") if __name__ == "__main__": diff --git a/src/eventdisplay_ml/utils.py b/src/eventdisplay_ml/utils.py index 39770ef..3489c3f 100644 --- a/src/eventdisplay_ml/utils.py +++ b/src/eventdisplay_ml/utils.py @@ -7,6 +7,25 @@ _logger = logging.getLogger(__name__) +def resolve_joblib_path(path_or_prefix): + """Resolve model path supporting .joblib.gz (preferred) and .joblib.""" + path = Path(path_or_prefix) + path_str = str(path) + + if path_str.endswith(".joblib.gz"): + candidates = [path, Path(path_str.removesuffix(".gz"))] + elif path_str.endswith(".joblib"): + candidates = [Path(f"{path_str}.gz"), path] + else: + candidates = [Path(f"{path_str}.joblib.gz"), Path(f"{path_str}.joblib"), path] + + for candidate in candidates: + if candidate.exists() and candidate.is_file(): + return candidate + + raise FileNotFoundError(f"Could not resolve model file from '{path_or_prefix}'.") + + def read_input_file_list(input_file_list): """ Read a list of input files from a text file. @@ -129,6 +148,6 @@ def output_file_name(model_prefix, n_tel=None, energy_bin_number=None): filename = f"{model_prefix!s}_ntel{n_tel}" if energy_bin_number is not None: filename += f"_ebin{energy_bin_number}" - filename += ".joblib" + filename += ".joblib.gz" _logger.info(f"Output filename: {filename}") return filename