Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,6 @@ Analysis type: stereo reconstruction or gamma/hadron separation, depending on th
- Plot XGBoost training vs validation metric curves
- Useful for checking convergence and overfitting behavior

Required inputs:

- `--model_file`: trained model `.joblib` containing an XGBoost model
- `--output_file`: output image path (optional; if omitted, plot is shown interactively)

Run:

```bash
Expand All @@ -315,6 +310,14 @@ eventdisplay-ml-plot-training-evaluation \
--output_file diagnostics/training_curves.png
```

or for all joblib files in a directory:

```bash
eventdisplay-ml-plot-training-evaluation \
--model_dir models/ \
--output_dir diagnostics/
```
Comment thread
GernotMaier marked this conversation as resolved.

Output:

- Figure with one panel per tracked metric (for example `rmse`), showing training and test curves.
Expand Down
1 change: 1 addition & 0 deletions docs/changes/60.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve classification hyperparameters with a focus on robustness, and add user-facing plotting CLI options for selecting `--model_dir`/`--output_dir` and `--energy-bin`.
3 changes: 2 additions & 1 deletion src/eventdisplay_ml/diagnostic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

from eventdisplay_ml import utils
from eventdisplay_ml.data_processing import load_training_data

_logger = logging.getLogger(__name__)


def _load_model_cfg(model_file):
"""Load full model dictionary and the first model configuration entry."""
model_dict = joblib.load(model_file)
model_dict = joblib.load(utils.resolve_joblib_path(model_file))
models = model_dict.get("models", {})
model_cfg = next(iter(models.values())) if models else None
return model_dict, model_cfg
Expand Down
11 changes: 6 additions & 5 deletions src/eventdisplay_ml/hyper_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@
"objective": "binary:logistic",
"eval_metric": ["logloss", "auc"],
"n_estimators": 5000,
"early_stopping_rounds": 50,
"max_depth": 7,
"learning_rate": 0.05,
"early_stopping_rounds": 100,
"max_depth": 4,
"learning_rate": 0.02,
"gamma": 0.2,
"subsample": 0.8,
"colsample_bytree": 0.8,
"colsample_bytree": 0.6,
"random_state": None,
"n_jobs": 48,
"n_jobs": 96,
Comment thread
GernotMaier marked this conversation as resolved.
},
}
}
Expand Down
20 changes: 16 additions & 4 deletions src/eventdisplay_ml/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,24 @@ def load_classification_models(model_prefix, model_name):
models = {}
par = {}

pattern = f"{model_prefix.name}_ebin*.joblib"
files = sorted(model_dir_path.glob(pattern))
pattern = re.compile(rf"^{re.escape(model_prefix.name)}_ebin(\d+)\.joblib(?:\.gz)?$")
matched_files = [
file for file in model_dir_path.iterdir() if file.is_file() and pattern.match(file.name)
]
files_by_bin = {}
for file in matched_files:
match = pattern.match(file.name)
if not match:
continue
e_bin = int(match.group(1))
existing = files_by_bin.get(e_bin)
if existing is None or file.name.endswith(".joblib.gz"):
files_by_bin[e_bin] = file
files = [files_by_bin[e_bin] for e_bin in sorted(files_by_bin)]

_logger.info(f"Loading classification models from {files}")
for file in files:
match = re.search(r"_ebin(\d+)\.joblib$", file.name)
match = pattern.match(file.name)
if not match:
_logger.warning(f"Could not extract energy bin from filename: {file.name}")
continue
Expand Down Expand Up @@ -206,7 +218,7 @@ def load_regression_models(model_prefix, model_name):
dict
Model dictionary.
"""
model_path = Path(model_prefix).with_suffix(".joblib")
model_path = utils.resolve_joblib_path(model_prefix)
_logger.info(f"Loading regression model: {model_path}")

model_data = joblib.load(model_path)
Expand Down
4 changes: 3 additions & 1 deletion src/eventdisplay_ml/scripts/diagnostic_shap_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
import numpy as np
import pandas as pd

from eventdisplay_ml import utils

_logger = logging.getLogger(__name__)


def load_model_config(model_file):
"""Load model configuration with cached feature importances."""
_logger.info(f"Loading model from {model_file}")
model_dict = joblib.load(model_file)
model_dict = joblib.load(utils.resolve_joblib_path(model_file))

models = model_dict.get("models")
if not isinstance(models, dict) or not models:
Expand Down
9 changes: 7 additions & 2 deletions src/eventdisplay_ml/scripts/optimize_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from astropy.table import Table
from scipy.interpolate import LinearNDInterpolator, RegularGridInterpolator

from eventdisplay_ml import utils

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -106,7 +108,8 @@ def _load_multi_bin_roc(joblib_paths):

for path in joblib_paths:
try:
data = joblib.load(path)
resolved_path = utils.resolve_joblib_path(path)
data = joblib.load(resolved_path)
ebins = data["energy_bins_log10_tev"]
e_min = ebins["E_min"]
e_max = ebins["E_max"]
Expand Down Expand Up @@ -457,7 +460,9 @@ def main():
"""CLI entry point."""
parser = argparse.ArgumentParser(description="Optimize classification cuts.")
parser.add_argument("input_root", help="ROOT file with rate surfaces")
parser.add_argument("roc_files", nargs="+", help="List of ebin*.joblib files")
parser.add_argument(
"roc_files", nargs="+", help="List of ebin* model files (.joblib.gz preferred)."
)
parser.add_argument("source_strength", type=float, help="Fraction of Crab (e.g. 0.1 for 10%%)")
parser.add_argument(
"--source-index",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import numpy as np
import uproot

from eventdisplay_ml import utils

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -127,7 +129,8 @@ def load_efficiency_tmva(path, ebin, zebin=0):

def load_efficiency_xgb(path, ebin):
"""Load efficiencies from XGB files."""
data_joblib = joblib.load(Path(path) / f"gammahadron_bdt_ebin{ebin}.joblib")
model_file = utils.resolve_joblib_path(Path(path) / f"gammahadron_bdt_ebin{ebin}")
data_joblib = joblib.load(model_file)
df_xgboost = data_joblib["models"]["xgboost"]["efficiency"]

x_joblib = df_xgboost["threshold"]
Expand All @@ -142,10 +145,18 @@ def main():
parser = argparse.ArgumentParser(description="Plot TMVA and XGBoost metrics.")
parser.add_argument("root_dir", help="Path to the TMVA BDT .root file")
parser.add_argument("joblib_dir", help="Path to the XGB BDT .joblib file")
parser.add_argument(
"--energy-bin",
type=int,
choices=range(9),
default=None,
help="Plot only a single energy bin (0-8). If omitted, all bins are processed.",
)
args = parser.parse_args()

# assume energy binning is identical in XGB and TMVA files.
for ebin in range(9):
energy_bins = [args.energy_bin] if args.energy_bin is not None else range(9)
for ebin in energy_bins:
x_root, y_effs, y_effb = load_efficiency_tmva(args.root_dir, ebin)
x_joblib, y_effs_xgb, y_effb_xgb = load_efficiency_xgb(args.joblib_dir, ebin)

Expand Down
133 changes: 103 additions & 30 deletions src/eventdisplay_ml/scripts/plot_training_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,18 @@
import matplotlib.pyplot as plt
import numpy as np

from eventdisplay_ml import utils

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)


def _joblib_basename(model_path):
"""Return basename without .joblib/.joblib.gz suffixes."""
name = Path(model_path).name
return name.removesuffix(".joblib.gz").removesuffix(".joblib")


def plot_training_curves(evals_result, output_file=None):
"""
Plot training and validation curves from XGBoost evaluation results.
Expand Down Expand Up @@ -119,57 +127,122 @@ def main():
"(stereo or classification)."
)
)
parser.add_argument(
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"--model_file",
required=True,
type=str,
help="Path to the trained model joblib file (e.g., dispdir_bdt.joblib).",
help="Path to a single trained model joblib file (e.g., dispdir_bdt.joblib).",
)
group.add_argument(
"--model_dir",
type=str,
help=(
"Directory containing multiple joblib model files. "
"All *.joblib files will be processed.",
),
)
parser.add_argument(
"--output_file",
type=str,
default=None,
help="Path to save the output plot (PNG/PDF). If not provided, display interactively.",
help=(
"Path to save the output plot (PNG/PDF). If not provided, display interactively. "
"Only for single file mode."
),
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory to save all output plots (required if --model_dir is used).",
)

args = parser.parse_args()
if args.model_file:
model_path = utils.resolve_joblib_path(args.model_file)

model_path = Path(args.model_file)
if not model_path.exists():
raise FileNotFoundError(f"Model file not found: {model_path}")
_logger.info(f"Loading model from: {model_path}")
model_configs = joblib.load(model_path)

_logger.info(f"Loading model from: {model_path}")
model_configs = joblib.load(model_path)
# Extract the XGBoost model and its evaluation results
if "models" not in model_configs:
raise ValueError("Model file does not contain 'models' key.")

# Extract the XGBoost model and its evaluation results
if "models" not in model_configs:
raise ValueError("Model file does not contain 'models' key.")
if "xgboost" not in model_configs["models"]:
raise ValueError("Model file does not contain 'xgboost' model.")

if "xgboost" not in model_configs["models"]:
raise ValueError("Model file does not contain 'xgboost' model.")
xgb_model = model_configs["models"]["xgboost"]["model"]

xgb_model = model_configs["models"]["xgboost"]["model"]
if not hasattr(xgb_model, "evals_result"):
raise AttributeError(
"XGBoost model does not have 'evals_result' method. "
"Model may not have been trained with eval_set parameter."
)

if not hasattr(xgb_model, "evals_result"):
raise AttributeError(
"XGBoost model does not have 'evals_result' method. "
"Model may not have been trained with eval_set parameter."
)
evals_result = xgb_model.evals_result()

_logger.info(f"Model type: {type(xgb_model).__name__}")
_logger.info(f"Number of boosting rounds: {xgb_model.get_booster().num_boosted_rounds()}")

evals_result = xgb_model.evals_result()
# Additional model info
if "features" in model_configs:
_logger.info(f"Number of features: {len(model_configs['features'])}")
if "targets" in model_configs:
_logger.info(f"Target variables: {model_configs['targets']}")

_logger.info(f"Model type: {type(xgb_model).__name__}")
_logger.info(f"Number of boosting rounds: {xgb_model.get_booster().num_boosted_rounds()}")
output_file = args.output_file
if output_file is None:
output_file = f"training_evaluation_{_joblib_basename(model_path)}.png"
_logger.info(f"No --output_file given. Saving to {output_file}")

# Additional model info
if "features" in model_configs:
_logger.info(f"Number of features: {len(model_configs['features'])}")
if "targets" in model_configs:
_logger.info(f"Target variables: {model_configs['targets']}")
plot_training_curves(evals_result, output_file)
_logger.info("Plotting completed successfully.")

plot_training_curves(evals_result, args.output_file)
elif args.model_dir:
model_dir = Path(args.model_dir)
if not model_dir.exists() or not model_dir.is_dir():
raise FileNotFoundError(f"Model directory not found: {model_dir}")

if not args.output_dir:
raise ValueError("--output_dir must be specified when using --model_dir.")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

discovered_files = sorted(
set(model_dir.glob("*.joblib")).union(model_dir.glob("*.joblib.gz"))
)
files_by_name = {}
for model_path in discovered_files:
key = _joblib_basename(model_path)
existing = files_by_name.get(key)
if existing is None or model_path.name.endswith(".joblib.gz"):
files_by_name[key] = model_path
joblib_files = [files_by_name[name] for name in sorted(files_by_name)]
if not joblib_files:
raise FileNotFoundError(f"No joblib files found in directory: {model_dir}")

for model_path in joblib_files:
_logger.info(f"Loading model from: {model_path}")
try:
model_configs = joblib.load(model_path)
if "models" not in model_configs or "xgboost" not in model_configs["models"]:
_logger.error(f"Skipping {model_path}: missing 'models/xgboost' key.")
continue

xgb_model = model_configs["models"]["xgboost"].get("model")
if not hasattr(xgb_model, "evals_result"):
_logger.error(f"Skipping {model_path}: model missing 'evals_result'.")
continue

evals_result = xgb_model.evals_result()
output_file = output_dir / f"training_evaluation_{_joblib_basename(model_path)}.png"
plot_training_curves(evals_result, output_file)
_logger.info(f"Saved plot for {model_path.name} to {output_file}")
except Exception as e:
_logger.exception(f"Skipping {model_path}: failed to process model ({e})")
continue

_logger.info("Plotting completed successfully.")
_logger.info("Batch plotting completed.")


if __name__ == "__main__":
Expand Down
21 changes: 20 additions & 1 deletion src/eventdisplay_ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,25 @@
_logger = logging.getLogger(__name__)


def resolve_joblib_path(path_or_prefix):
"""Resolve model path supporting .joblib.gz (preferred) and .joblib."""
path = Path(path_or_prefix)
path_str = str(path)

if path_str.endswith(".joblib.gz"):
candidates = [path, Path(path_str.removesuffix(".gz"))]
elif path_str.endswith(".joblib"):
candidates = [Path(f"{path_str}.gz"), path]
else:
candidates = [Path(f"{path_str}.joblib.gz"), Path(f"{path_str}.joblib"), path]

for candidate in candidates:
if candidate.exists() and candidate.is_file():
return candidate

raise FileNotFoundError(f"Could not resolve model file from '{path_or_prefix}'.")


def read_input_file_list(input_file_list):
"""
Read a list of input files from a text file.
Expand Down Expand Up @@ -129,6 +148,6 @@ def output_file_name(model_prefix, n_tel=None, energy_bin_number=None):
filename = f"{model_prefix!s}_ntel{n_tel}"
if energy_bin_number is not None:
filename += f"_ebin{energy_bin_number}"
filename += ".joblib"
filename += ".joblib.gz"
_logger.info(f"Output filename: {filename}")
return filename