# Names used by the helpers and by WaldTest3Step below.
import itertools
import warnings
from typing import Optional, Union, Sequence

import numpy as np
import pandas as pd
from scipy.stats import chi2, norm

# Index levels of the three result tables. Keeping them in one place lets the
# empty-result paths build frames with the same layout as the populated ones.
_ESTIMATE_INDEX = ["model_name", "param", "variable", "class_no"]
_PAIRWISE_INDEX = ["model_name", "param", "variable", "class_j", "class_k"]
_OMNIBUS_INDEX = ["model_name", "param", "variable"]
_PAIRWISE_COLUMNS = _PAIRWISE_INDEX + [
    "theta_j",
    "theta_k",
    "delta",
    "se_delta",
    "z",
    "chi2",
    "df",
    "p_value",
]
_OMNIBUS_COLUMNS = _OMNIBUS_INDEX + ["chi2", "df", "p_value", "sig"]


def _fdr_bh(p_values: np.ndarray) -> np.ndarray:
    """Benjamini-Hochberg (step-up) FDR adjustment.

    Parameters
    ----------
    p_values : 1-D array-like of raw p-values.
        Non-finite entries (NaN, inf) are excluded from the family of tests
        and come back as NaN.

    Returns
    -------
    adjusted : 1-D array of BH-adjusted p-values, same length as the input.
    """
    p_values = np.asarray(p_values, dtype=float)
    adjusted = np.full(p_values.shape, np.nan)
    finite = np.isfinite(p_values)
    m = int(finite.sum())
    if m == 0:
        return adjusted

    p = p_values[finite]
    order = np.argsort(p)
    scaled = p[order] * m / np.arange(1, m + 1)
    # Running minimum taken from the largest p-value downwards, so adjusted
    # values never decrease as the raw p-values increase.
    monotone = np.minimum.accumulate(scaled[::-1])[::-1]
    adj = np.empty(m)
    adj[order] = np.minimum(1.0, monotone)
    adjusted[finite] = adj
    return adjusted


def _bonferroni(p_values: np.ndarray) -> np.ndarray:
    """Bonferroni adjustment.

    Each finite p-value is multiplied by the number of finite p-values and
    capped at 1. Non-finite entries do not count as tests (the same
    convention as ``_fdr_bh``) and come back as NaN.

    Parameters
    ----------
    p_values : 1-D array-like of raw p-values.

    Returns
    -------
    adjusted : 1-D array of adjusted p-values, same length as the input.
    """
    p_values = np.asarray(p_values, dtype=float)
    finite = np.isfinite(p_values)
    adjusted = np.full(p_values.shape, np.nan)
    adjusted[finite] = np.minimum(1.0, p_values[finite] * int(finite.sum()))
    return adjusted


def _build_contrast_matrix(K: int) -> np.ndarray:
    """Build a (K-1) × K contrast matrix that compares each of the first
    K-1 classes to the last class.

    C @ θ = [θ₀ - θ_{K-1}, θ₁ - θ_{K-1}, ..., θ_{K-2} - θ_{K-1}]

    Raises
    ------
    ValueError
        If ``K`` is smaller than 1.
    """
    if K < 1:
        raise ValueError(f"K must be at least 1, got {K}.")
    return np.hstack([np.eye(K - 1), -np.ones((K - 1, 1))])


def _stars(p: float) -> str:
    """Return the significance code for a p-value (NaN gives a blank code)."""
    if np.isnan(p):
        return " "
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "** "
    if p < 0.05:
        return "* "
    if p < 0.1:
        return ". "
    return " "


# ---------------------------------------------------------------------------
# Core class
# ---------------------------------------------------------------------------

class WaldTest3Step:
    """Wald tests and p-values for 3-step LCA distal outcome analysis.

    After fitting a :class:`stepmix.stepmix.StepMix` model with a structural
    (outcome) component, this class runs a non-parametric bootstrap over the
    full 3-step procedure and then uses the resulting sampling distribution to
    compute standard errors, confidence intervals, and Wald tests.

    Parameters
    ----------
    model : fitted StepMix instance
        Must have been fitted with a structural model (``n_steps=3`` or via
        manual ``m_step_structural``).
    ci_level : float, default=0.95
        Confidence level for bootstrap percentile CIs. E.g. 0.95 → 95 % CI.

    Attributes
    ----------
    estimates_ : pd.DataFrame
        Point estimates, bootstrap SE, and confidence intervals per
        (outcome variable, latent class).
    pairwise_ : pd.DataFrame
        Pairwise Wald test results (Δ, SE_Δ, z, χ², df, p-value, …) for
        every (outcome variable, class-pair) combination.
    omnibus_ : pd.DataFrame
        Global omnibus Wald test result (χ², df, p-value) per outcome variable.
    bootstrap_samples_ : pd.DataFrame
        Raw bootstrap parameter draws (long-form, as returned by
        :func:`stepmix.bootstrap.bootstrap`).
    """

    def __init__(self, model, ci_level: float = 0.95):
        check_is_fitted(model)
        if not hasattr(model, "_sm"):
            raise ValueError(
                "The StepMix model has no structural model. Fit it with Y data "
                "and a structural specification before using WaldTest3Step."
            )
        # Outside [0, 1] the percentile calls in _build_estimates would fail
        # only after the (expensive) bootstrap has run; fail up front instead.
        if not 0.0 <= ci_level <= 1.0:
            raise ValueError(f"ci_level must lie in [0, 1], got {ci_level!r}.")
        self.model = model
        self.ci_level = ci_level

        self._is_fitted = False
        self.bootstrap_samples_: Optional[pd.DataFrame] = None
        self.estimates_: Optional[pd.DataFrame] = None
        self.pairwise_: Optional[pd.DataFrame] = None
        self.omnibus_: Optional[pd.DataFrame] = None

    # ------------------------------------------------------------------
    # Fitting
    # ------------------------------------------------------------------

    def fit_bootstrap(
        self,
        X,
        Y,
        n_repetitions: int = 500,
        sample_weight=None,
        progress_bar: bool = True,
        random_state: Optional[int] = None,
        correction: Optional[str] = None,
    ) -> "WaldTest3Step":
        """Run the bootstrap and compute all test statistics.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Measurement data (same as used to fit the model).
        Y : array-like of shape (n_samples, n_outcome_features)
            Outcome data (same as used to fit the structural model).
        n_repetitions : int, default=500
            Number of bootstrap replications.
        sample_weight : array-like of shape (n_samples,), default=None
            Per-sample weights forwarded to the bootstrap.
        progress_bar : bool, default=True
            Show a tqdm progress bar.
        random_state : int, default=None
            Random seed for the bootstrap.
        correction : {None, "Bonferroni", "BH"}, default=None
            Multiple-comparison correction applied to pairwise p-values
            (matched case-insensitively; ``"FDR"`` is accepted as an alias of
            ``"BH"``). ``None`` → no correction; ``"Bonferroni"`` → plain
            Bonferroni; ``"BH"`` → Benjamini-Hochberg FDR. Any other value
            triggers a warning and leaves the p-values uncorrected.

        Returns
        -------
        self
        """
        boot_df, _ = bootstrap(
            self.model,
            X,
            Y,
            n_repetitions=n_repetitions,
            sample_weight=sample_weight,
            progress_bar=progress_bar,
            random_state=random_state,
            identify_classes=True,
        )
        self.bootstrap_samples_ = boot_df

        self._build_estimates()
        self._build_pairwise(correction=correction)
        self._build_omnibus()
        self._is_fitted = True
        return self

    # ------------------------------------------------------------------
    # Building result tables (internal)
    # ------------------------------------------------------------------

    def _get_structural_boot(self) -> pd.DataFrame:
        """Return a copy of the bootstrap draws for the structural model only.

        Raises
        ------
        ValueError
            If the bootstrap samples contain no ``"structural"`` entries.
        """
        if "structural" not in self.bootstrap_samples_.index.get_level_values("model"):
            raise ValueError(
                "No structural parameters found in bootstrap samples. "
                "Make sure the model has a structural component."
            )
        return self.bootstrap_samples_.loc["structural"].copy()

    def _build_estimates(self) -> None:
        """Build ``estimates_``: point estimate, bootstrap SE and percentile CI
        for every structural parameter reported by the fitted model."""
        sm_df = self._get_structural_boot()
        alpha = 1.0 - self.ci_level
        ci_lo_label = f"CI_{self.ci_level:.0%}_lo"
        ci_hi_label = f"CI_{self.ci_level:.0%}_hi"

        # Point estimates: long-form parameters table (class_no as a column).
        point_df = (
            self.model.get_parameters_df()
            .loc["structural"]
            .reset_index()
        )  # columns: model_name, param, class_no, variable, value

        # Look the index levels up once instead of once per parameter. A
        # missing level now raises KeyError here rather than being swallowed
        # and then failing in _get_boot_matrix.
        index = sm_df.index
        lv_model = index.get_level_values("model_name")
        lv_param = index.get_level_values("param")
        lv_class = index.get_level_values("class_no")
        lv_variable = index.get_level_values("variable")
        values = sm_df["value"].to_numpy()

        records = []
        for _, row in point_df.iterrows():
            class_no = int(row["class_no"])
            mask = (
                (lv_model == row["model_name"])
                & (lv_param == row["param"])
                & (lv_class == class_no)
                & (lv_variable == row["variable"])
            )
            draws = values[mask]

            if len(draws) >= 2:
                se = float(np.std(draws, ddof=1))
                ci_lo = float(np.percentile(draws, 100 * alpha / 2))
                ci_hi = float(np.percentile(draws, 100 * (1 - alpha / 2)))
            else:
                # A spread cannot be estimated from fewer than two draws.
                se = ci_lo = ci_hi = np.nan

            records.append(
                {
                    "model_name": row["model_name"],
                    "param": row["param"],
                    "variable": row["variable"],
                    "class_no": class_no,
                    "estimate": row["value"],
                    "se": se,
                    ci_lo_label: ci_lo,
                    ci_hi_label: ci_hi,
                }
            )

        columns = _ESTIMATE_INDEX + ["estimate", "se", ci_lo_label, ci_hi_label]
        df = pd.DataFrame(records, columns=columns)
        df.set_index(_ESTIMATE_INDEX, inplace=True)
        df.sort_index(inplace=True)  # required for partial MultiIndex lookups
        self.estimates_ = df

    def _get_boot_matrix(
        self,
        model_name: str,
        param: str,
        variable: str,
        sm_df: Optional[pd.DataFrame] = None,
    ) -> np.ndarray:
        """Return a (n_bootstrap × K) matrix of draws for a given variable.

        Rows = bootstrap replications (ordered by ``rep``), columns = latent
        classes (0 … K-1). Draws are placed by their ``rep`` value, so a class
        that is missing some replications leaves NaN in those cells instead of
        raising a shape error or shifting later draws.

        Parameters
        ----------
        model_name, param, variable
            Identify the structural parameter to extract.
        sm_df : pd.DataFrame, optional
            Structural bootstrap draws; fetched with
            ``_get_structural_boot()`` when omitted.
        """
        if sm_df is None:
            sm_df = self._get_structural_boot()
        K = self.model.n_components
        reps = np.sort(sm_df["rep"].unique())

        index = sm_df.index
        selected = (
            (index.get_level_values("model_name") == model_name)
            & (index.get_level_values("param") == param)
            & (index.get_level_values("variable") == variable)
        )
        class_no = index.get_level_values("class_no")

        mat = np.full((len(reps), K), np.nan)
        for k in range(K):
            sub = sm_df.loc[selected & (class_no == k)]
            if len(sub) > 0:
                # reps is sorted and unique, so searchsorted maps each rep
                # value to its row position.
                rows = np.searchsorted(reps, sub["rep"].to_numpy())
                mat[rows, k] = sub["value"].to_numpy(dtype=float)
        return mat  # (B, K)

    def _iter_wald_inputs(self):
        """Yield ``(model_name, param, variable, theta, cov)`` for every
        structural parameter that has a complete bootstrap matrix.

        ``theta`` holds the K point estimates (ordered by class) and ``cov``
        is the K × K bootstrap covariance. Parameters with missing draws or
        missing point estimates are skipped.
        """
        K = self.model.n_components
        sm_df = self._get_structural_boot()
        keys = (
            sm_df.reset_index()[_OMNIBUS_INDEX]
            .drop_duplicates()
            .values.tolist()
        )
        for model_name, param, variable in keys:
            mat = self._get_boot_matrix(model_name, param, variable, sm_df=sm_df)
            if np.any(np.isnan(mat)):
                continue
            try:
                theta = self.estimates_.loc[
                    (model_name, param, variable), "estimate"
                ].to_numpy()
            except KeyError:
                continue
            if len(theta) != K:
                # Not every class has a point estimate; theta[j] would be
                # misaligned with the covariance columns.
                continue
            # atleast_2d: np.cov returns a 0-d array when K == 1.
            cov = np.atleast_2d(np.cov(mat.T, ddof=1))
            yield model_name, param, variable, theta, cov

    def _build_pairwise(self, correction: Optional[str] = None) -> None:
        """Build ``pairwise_``: one Wald test per (parameter, class pair j < k).

        Parameters
        ----------
        correction : {None, "Bonferroni", "BH"/"FDR"}
            Multiple-comparison correction for the ``p_adj`` column. Unknown
            values emit a warning and leave the p-values uncorrected.
        """
        K = self.model.n_components
        pairs = list(itertools.combinations(range(K), 2))
        records = []

        for model_name, param, variable, theta, cov in self._iter_wald_inputs():
            for j, k in pairs:
                delta = float(theta[j] - theta[k])
                var_delta = float(cov[j, j] + cov[k, k] - 2 * cov[j, k])
                if var_delta > 0:
                    se_delta = float(np.sqrt(var_delta))
                    z = delta / se_delta
                    chi2_stat = z ** 2
                    # Kept as 2 * (1 - Φ(|z|)): the accompanying tests compare
                    # against exactly this expression.
                    p_val = float(2.0 * (1.0 - norm.cdf(abs(z))))
                else:
                    # Zero, negative or NaN variance: no test can be formed.
                    se_delta = z = chi2_stat = p_val = np.nan

                records.append(
                    dict(
                        model_name=model_name,
                        param=param,
                        variable=variable,
                        class_j=j,
                        class_k=k,
                        theta_j=float(theta[j]),
                        theta_k=float(theta[k]),
                        delta=delta,
                        se_delta=se_delta,
                        z=z,
                        chi2=chi2_stat,
                        df=1,
                        p_value=p_val,
                    )
                )

        # Explicit columns keep the table layout even when no test was run,
        # so the accessors and summary() work on an empty result.
        df = pd.DataFrame(records, columns=_PAIRWISE_COLUMNS)

        # Multiple-comparison correction
        raw_p = df["p_value"].to_numpy(dtype=float)
        method = correction.lower() if isinstance(correction, str) else correction
        if method is None:
            df["p_adj"] = raw_p
            df["correction"] = "none"
        elif method == "bonferroni":
            df["p_adj"] = _bonferroni(raw_p)
            df["correction"] = "Bonferroni"
        elif method in ("bh", "fdr"):
            df["p_adj"] = _fdr_bh(raw_p)
            df["correction"] = "BH (FDR)"
        else:
            warnings.warn(f"Unknown correction '{correction}'. Skipping.")
            df["p_adj"] = raw_p
            df["correction"] = "none"

        df["sig"] = [_stars(p) for p in df["p_adj"]]
        df.set_index(_PAIRWISE_INDEX, inplace=True)
        self.pairwise_ = df

    def _build_omnibus(self) -> None:
        """Build ``omnibus_``: one global Wald test per structural parameter.

        The statistic is ``(Cθ)ᵀ (C Σ Cᵀ)⁻¹ (Cθ)`` with ``C`` the
        last-class contrast matrix and ``Σ`` the bootstrap covariance; the
        p-value is taken from a χ² distribution with K-1 degrees of freedom.
        """
        K = self.model.n_components
        if K < 2:
            warnings.warn("Only one class – omnibus test not applicable.")
            self.omnibus_ = pd.DataFrame(columns=_OMNIBUS_COLUMNS).set_index(
                _OMNIBUS_INDEX
            )
            return

        C = _build_contrast_matrix(K)  # (K-1) × K
        df_test = K - 1
        records = []

        for model_name, param, variable, theta, Sigma in self._iter_wald_inputs():
            # Contrast: C @ theta ~ N(0, C @ Sigma @ C^T) under H0
            C_theta = C @ theta  # (K-1,)
            C_Sigma_CT = C @ Sigma @ C.T  # (K-1) × (K-1)

            try:
                # solve() avoids forming the explicit inverse.
                W = float(C_theta @ np.linalg.solve(C_Sigma_CT, C_theta))
            except np.linalg.LinAlgError:
                W = np.nan

            p_val = float(1.0 - chi2.cdf(W, df=df_test)) if np.isfinite(W) else np.nan

            records.append(
                dict(
                    model_name=model_name,
                    param=param,
                    variable=variable,
                    chi2=W,
                    df=df_test,
                    p_value=p_val,
                    sig=_stars(p_val),
                )
            )

        df = pd.DataFrame(records, columns=_OMNIBUS_COLUMNS)
        df.set_index(_OMNIBUS_INDEX, inplace=True)
        self.omnibus_ = df

    # ------------------------------------------------------------------
    # Public accessors
    # ------------------------------------------------------------------

    def _check_fitted(self):
        """Raise RuntimeError unless ``fit_bootstrap`` has completed."""
        if not self._is_fitted:
            raise RuntimeError(
                "Call fit_bootstrap() before accessing test results."
            )

    def get_estimates(
        self,
        variable: Optional[str] = None,
        class_no: Optional[int] = None,
    ) -> pd.DataFrame:
        """Return the estimates table, optionally filtered.

        Parameters
        ----------
        variable : str, optional
            Filter to a specific outcome variable (e.g. ``"feature_0"``).
        class_no : int, optional
            Filter to a specific latent class index.

        Returns
        -------
        pd.DataFrame with columns [estimate, se, CI_xx%_lo, CI_xx%_hi].

        Raises
        ------
        RuntimeError
            If ``fit_bootstrap`` has not been called.
        """
        self._check_fitted()
        df = self.estimates_.reset_index()
        if variable is not None:
            df = df[df["variable"] == variable]
        if class_no is not None:
            df = df[df["class_no"] == class_no]
        return df.set_index(_ESTIMATE_INDEX)

    def get_pairwise(
        self,
        variable: Optional[str] = None,
        classes: Optional[tuple] = None,
    ) -> pd.DataFrame:
        """Return pairwise test table.

        Parameters
        ----------
        variable : str, optional
            Filter to a specific outcome variable.
        classes : tuple (j, k), optional
            Filter to a specific class pair.

        Returns
        -------
        pd.DataFrame with columns [theta_j, theta_k, delta, se_delta, z, chi2,
        df, p_value, p_adj, correction, sig].

        Raises
        ------
        RuntimeError
            If ``fit_bootstrap`` has not been called.
        """
        self._check_fitted()
        df = self.pairwise_.reset_index()
        if variable is not None:
            df = df[df["variable"] == variable]
        if classes is not None:
            j, k = classes
            df = df[(df["class_j"] == j) & (df["class_k"] == k)]
        return df.set_index(_PAIRWISE_INDEX)

    def get_omnibus(self, variable: Optional[str] = None) -> pd.DataFrame:
        """Return omnibus test table.

        Parameters
        ----------
        variable : str, optional
            Filter to a specific outcome variable.

        Returns
        -------
        pd.DataFrame with columns [chi2, df, p_value, sig].

        Raises
        ------
        RuntimeError
            If ``fit_bootstrap`` has not been called.
        """
        self._check_fitted()
        df = self.omnibus_.reset_index()
        if variable is not None:
            df = df[df["variable"] == variable]
        return df.set_index(_OMNIBUS_INDEX)

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------

    def summary(
        self,
        digits: int = 4,
        show_omnibus: bool = True,
        show_pairwise: bool = True,
        show_estimates: bool = True,
    ) -> None:
        """Print a concise human-readable summary of all inference results.

        Parameters
        ----------
        digits : int
            Number of decimal places for numeric output.
        show_omnibus : bool
            Print the omnibus test table.
        show_pairwise : bool
            Print the pairwise test table.
        show_estimates : bool
            Print the estimates / SE / CI table.

        Raises
        ------
        RuntimeError
            If ``fit_bootstrap`` has not been called.
        """
        self._check_fitted()
        K = self.model.n_components
        n_boot = self.bootstrap_samples_["rep"].nunique()
        sep = "=" * 70

        print(sep)
        print(" 3-Step LCA Outcome Analysis – Wald Tests & P-values")
        print(sep)
        print(f" Latent classes (K) : {K}")
        print(f" Bootstrap reps (B) : {n_boot}")
        print(f" CI level : {self.ci_level:.0%}")
        print(
            f" Significance codes : "
            "*** p<0.001 | ** p<0.01 | * p<0.05 | . p<0.1"
        )

        if show_estimates:
            self._print_estimates(sep, digits)
        if show_omnibus:
            self._print_omnibus(sep, digits)
        if show_pairwise:
            self._print_pairwise(sep, digits)

        print()
        print(sep)

    def _print_estimates(self, sep: str, digits: int) -> None:
        """Print the per-class estimate / SE / CI section of the summary."""
        ci_lo_label = [c for c in self.estimates_.columns if c.endswith("_lo")][0]
        ci_hi_label = [c for c in self.estimates_.columns if c.endswith("_hi")][0]

        print()
        print(sep)
        print(" OUTCOME PARAMETER ESTIMATES (per latent class)")
        print(sep)
        est = self.estimates_.reset_index()
        for var in est["variable"].unique():
            sub = est[est["variable"] == var].sort_values("class_no")
            print(f"\n Variable: {var}")
            header = (
                f" {'Class':>6} {'Estimate':>10} {'SE':>9} "
                f"{'CI lo':>10} {'CI hi':>10}"
            )
            print(header)
            print(" " + "-" * 60)
            for _, row in sub.iterrows():
                print(
                    f" {int(row['class_no']):>6} "
                    f"{row['estimate']:>10.{digits}f} "
                    f"{row['se']:>9.{digits}f} "
                    f"{row[ci_lo_label]:>10.{digits}f} "
                    f"{row[ci_hi_label]:>10.{digits}f}"
                )

    def _print_omnibus(self, sep: str, digits: int) -> None:
        """Print the omnibus Wald test section of the summary."""
        print()
        print(sep)
        print(" OMNIBUS WALD TEST (H₀: all class means / probs are equal)")
        print(sep)
        omn = self.omnibus_.reset_index()
        print(f"\n {'Variable':>14} {'χ²':>10} {'df':>4} {'p-value':>10} {'sig':>4}")
        print(" " + "-" * 50)
        for _, row in omn.iterrows():
            p = row["p_value"]
            print(
                f" {row['variable']:>14} "
                f"{row['chi2']:>10.{digits}f} "
                f"{int(row['df']):>4} "
                f"{p:>10.{digits}f} "
                f"{_stars(p)}"
            )

    def _print_pairwise(self, sep: str, digits: int) -> None:
        """Print the pairwise Wald test section of the summary."""
        print()
        print(sep)
        print(" PAIRWISE WALD TESTS (H₀: θⱼ = θₖ for each class pair j < k)")
        pair_df = self.pairwise_.reset_index()
        correction_label = pair_df["correction"].iloc[0] if len(pair_df) > 0 else "none"
        print(f" Multiple-comparison correction: {correction_label}")
        print(sep)
        for var in pair_df["variable"].unique():
            sub = pair_df[pair_df["variable"] == var]
            print(f"\n Variable: {var}")
            print(
                f" {'j':>4} {'k':>4} "
                f"{'θⱼ':>10} {'θₖ':>10} "
                f"{'Δ':>10} {'SE(Δ)':>10} "
                f"{'z':>8} {'χ²(1)':>8} "
                f"{'p-value':>9} {'p_adj':>9} {'sig':>4}"
            )
            print(" " + "-" * 100)
            for _, row in sub.iterrows():
                p = row["p_adj"]
                print(
                    f" {int(row['class_j']):>4} {int(row['class_k']):>4} "
                    f"{row['theta_j']:>10.{digits}f} "
                    f"{row['theta_k']:>10.{digits}f} "
                    f"{row['delta']:>10.{digits}f} "
                    f"{row['se_delta']:>10.{digits}f} "
                    f"{row['z']:>8.{digits}f} "
                    f"{row['chi2']:>8.{digits}f} "
                    f"{row['p_value']:>9.{digits}f} "
                    f"{p:>9.{digits}f} "
                    f"{_stars(p)}"
                )
+ """ + from stepmix.bootstrap import WaldTest3Step + wt = WaldTest3Step(self, ci_level=ci_level) + wt.fit_bootstrap( + X, Y, n_repetitions=n_repetitions, correction=correction, + progress_bar=progress_bar, random_state=random_state + ) + return wt + def bootstrap_stats( self, X, diff --git a/test/test_bootstrap_inference.py b/test/test_bootstrap_inference.py new file mode 100644 index 0000000..d9c458c --- /dev/null +++ b/test/test_bootstrap_inference.py @@ -0,0 +1,595 @@ +""" +test_bootstrap_inference.py +========================= +Unit and integration tests for test_bootstrap_inference.py. + +Run with: + python test_bootstrap_inference.py +or (if pytest is installed): + pytest test_bootstrap_inference.py -v +""" + +import warnings +import numpy as np +import pandas as pd +import pytest + +from stepmix.stepmix import StepMix +from stepmix.datasets import data_bakk_complete + +from stepmix.bootstrap import ( + WaldTest3Step, + _fdr_bh, + _bonferroni, + _build_contrast_matrix, +) + +# --------------------------------------------------------------------------- +# Fixtures / shared helpers +# --------------------------------------------------------------------------- +N_BOOT = 10 # keep CI fast; use 500+ for real analyses +SEED = 42 + + +def make_continuous_model(n_samples=400): + """Fitted 3-step model with Gaussian distal outcome.""" + X, Y, _ = data_bakk_complete(n_samples=n_samples, sep_level=0.9, random_state=SEED) + model = StepMix( + n_components=3, + n_steps=3, + measurement="bernoulli", + structural="gaussian_unit", + random_state=SEED, + verbose=0, + ) + model.fit(X, Y) + return model, X, Y + + +def make_binary_model(n_samples=400): + """Fitted 3-step model with binary distal outcome.""" + X, Y, _ = data_bakk_complete(n_samples=n_samples, sep_level=0.9, random_state=SEED) + Yb = (Y > Y.mean()).astype(float) + model = StepMix( + n_components=3, + n_steps=3, + measurement="bernoulli", + structural="bernoulli", + random_state=SEED, + verbose=0, + ) + 
model.fit(X, Yb) + return model, X, Yb + + +def make_soft_model(n_samples=400): + """Fitted 3-step soft-assignment model.""" + X, Y, _ = data_bakk_complete(n_samples=n_samples, sep_level=0.9, random_state=SEED) + model = StepMix( + n_components=3, + n_steps=3, + assignment="soft", + measurement="bernoulli", + structural="gaussian_unit", + random_state=SEED, + verbose=0, + ) + model.fit(X, Y) + return model, X, Y + + +# --------------------------------------------------------------------------- +# Unit tests – helper functions +# --------------------------------------------------------------------------- + + +class TestHelpers: + def test_contrast_matrix_shape(self): + for K in range(2, 6): + C = _build_contrast_matrix(K) + assert C.shape == (K - 1, K), f"Bad shape for K={K}" + + def test_contrast_matrix_values(self): + C = _build_contrast_matrix(3) + # Row 0: compares class 0 vs class 2 + np.testing.assert_array_equal(C[0], [1, 0, -1]) + # Row 1: compares class 1 vs class 2 + np.testing.assert_array_equal(C[1], [0, 1, -1]) + + def test_contrast_matrix_k2(self): + C = _build_contrast_matrix(2) + np.testing.assert_array_equal(C, [[1, -1]]) + + def test_bonferroni_correction(self): + p = np.array([0.01, 0.02, 0.05, 0.1]) + adj = _bonferroni(p) + # Each p multiplied by 4 (n tests), capped at 1 + np.testing.assert_allclose(adj, np.minimum(1.0, p * 4)) + + def test_bonferroni_with_nan(self): + p = np.array([0.01, np.nan, 0.05]) + adj = _bonferroni(p) + assert np.isnan(adj[1]) + assert not np.isnan(adj[0]) + assert not np.isnan(adj[2]) + + def test_fdr_bh_basic(self): + # All equal p-values → no inflation + p = np.array([0.05, 0.05, 0.05]) + adj = _fdr_bh(p) + assert all(adj <= 1.0) + assert all(adj >= p) + + def test_fdr_bh_monotone(self): + p = np.array([0.001, 0.01, 0.05, 0.1, 0.5]) + adj = _fdr_bh(p) + # Adjusted p-values must be non-decreasing + assert all(adj[i] <= adj[i + 1] for i in range(len(adj) - 1)) + + def test_fdr_bh_with_nan(self): + p = np.array([0.01, 
np.nan, 0.05]) + adj = _fdr_bh(p) + assert np.isnan(adj[1]) + + +# --------------------------------------------------------------------------- +# Unit tests – WaldTest3Step construction +# --------------------------------------------------------------------------- + + +class TestConstruction: + def test_requires_fitted_model(self): + model = StepMix(n_components=3, measurement="bernoulli", verbose=0) + with pytest.raises(Exception): + WaldTest3Step(model) + + def test_requires_structural_model(self): + X, _, _ = data_bakk_complete(n_samples=200, sep_level=0.9, random_state=SEED) + model = StepMix(n_components=3, measurement="bernoulli", verbose=0) + model.fit(X) + with pytest.raises(ValueError, match="structural"): + WaldTest3Step(model) + + def test_summary_before_fit_raises(self): + model, X, Y = make_continuous_model() + wt = WaldTest3Step(model) + with pytest.raises(RuntimeError): + wt.summary() + + def test_get_estimates_before_fit_raises(self): + model, X, Y = make_continuous_model() + wt = WaldTest3Step(model) + with pytest.raises(RuntimeError): + wt.get_estimates() + + +# --------------------------------------------------------------------------- +# Integration tests – continuous outcome +# --------------------------------------------------------------------------- + + +class TestContinuousOutcome: + @pytest.fixture(scope="class") + def fitted_wt(self): + model, X, Y = make_continuous_model() + wt = WaldTest3Step(model, ci_level=0.95) + wt.fit_bootstrap(X, Y, n_repetitions=N_BOOT, progress_bar=False, random_state=SEED) + return wt, model, X, Y + + def test_bootstrap_samples_shape(self, fitted_wt): + wt, model, X, Y = fitted_wt + assert wt.bootstrap_samples_ is not None + assert "rep" in wt.bootstrap_samples_.columns + n_reps = wt.bootstrap_samples_["rep"].nunique() + assert n_reps == N_BOOT + + def test_estimates_columns(self, fitted_wt): + wt, *_ = fitted_wt + cols = set(wt.estimates_.columns) + assert "estimate" in cols + assert "se" in cols + assert 
any("lo" in c for c in cols) + assert any("hi" in c for c in cols) + + def test_estimates_n_rows(self, fitted_wt): + wt, model, X, Y = fitted_wt + # 3 classes × n_outcome_vars + n_vars = Y.shape[1] + K = model.n_components + assert len(wt.estimates_) == K * n_vars + + def test_estimates_point_estimates_match_model(self, fitted_wt): + """Bootstrap point estimates should equal model's fitted parameters.""" + wt, model, X, Y = fitted_wt + # Use long-form (class_no as column) rather than wide get_sm_df() + model_sm = model.get_parameters_df().loc["structural"].reset_index() + boot_est = wt.estimates_.reset_index() + for _, row in model_sm.iterrows(): + match = boot_est[ + (boot_est["variable"] == row["variable"]) + & (boot_est["class_no"] == row["class_no"]) + ] + assert len(match) == 1 + np.testing.assert_allclose( + match["estimate"].values[0], row["value"], rtol=1e-6 + ) + + def test_se_positive(self, fitted_wt): + wt, *_ = fitted_wt + est = wt.estimates_ + assert (est["se"] > 0).all(), "All SEs should be positive" + + def test_ci_ordering(self, fitted_wt): + wt, *_ = fitted_wt + est = wt.estimates_.reset_index() + lo = [c for c in est.columns if "lo" in c][0] + hi = [c for c in est.columns if "hi" in c][0] + assert (est[hi] > est[lo]).all(), "CI upper must exceed CI lower" + + def test_pairwise_columns(self, fitted_wt): + wt, *_ = fitted_wt + expected = {"theta_j", "theta_k", "delta", "se_delta", "z", "chi2", "df", "p_value"} + cols = set(wt.pairwise_.columns) + assert expected.issubset(cols), f"Missing columns: {expected - cols}" + + def test_pairwise_n_rows(self, fitted_wt): + wt, model, X, Y = fitted_wt + K = model.n_components + n_vars = Y.shape[1] + n_pairs = K * (K - 1) // 2 + assert len(wt.pairwise_) == n_pairs * n_vars + + def test_pairwise_delta_sign(self, fitted_wt): + wt, *_ = fitted_wt + df = wt.pairwise_.reset_index() + # delta = theta_j - theta_k + np.testing.assert_allclose( + df["delta"].values, + df["theta_j"].values - df["theta_k"].values, + 
rtol=1e-6, + ) + + def test_pairwise_chi2_eq_z_squared(self, fitted_wt): + wt, *_ = fitted_wt + df = wt.pairwise_.reset_index() + np.testing.assert_allclose(df["chi2"].values, df["z"].values ** 2, rtol=1e-6) + + def test_pairwise_p_in_01(self, fitted_wt): + wt, *_ = fitted_wt + p = wt.pairwise_["p_value"].values + assert np.all((p >= 0) & (p <= 1)), "All p-values must be in [0, 1]" + + def test_pairwise_df_is_1(self, fitted_wt): + wt, *_ = fitted_wt + assert (wt.pairwise_["df"] == 1).all() + + def test_omnibus_columns(self, fitted_wt): + wt, *_ = fitted_wt + expected = {"chi2", "df", "p_value", "sig"} + assert expected.issubset(set(wt.omnibus_.columns)) + + def test_omnibus_n_rows(self, fitted_wt): + wt, model, X, Y = fitted_wt + n_vars = Y.shape[1] + assert len(wt.omnibus_) == n_vars + + def test_omnibus_df_is_K_minus_1(self, fitted_wt): + wt, model, *_ = fitted_wt + K = model.n_components + assert (wt.omnibus_["df"] == K - 1).all() + + def test_omnibus_p_in_01(self, fitted_wt): + wt, *_ = fitted_wt + p = wt.omnibus_["p_value"].values + assert np.all((p >= 0) & (p <= 1)) + + def test_well_separated_classes_significant(self, fitted_wt): + """With sep_level=0.9 and n=400, the omnibus test should be highly significant.""" + wt, *_ = fitted_wt + # At least one outcome variable should be significant + assert (wt.omnibus_["p_value"] < 0.05).any(), ( + "Expected at least one significant omnibus test with well-separated classes" + ) + + def test_summary_runs_without_error(self, fitted_wt, capsys): + wt, *_ = fitted_wt + wt.summary() + captured = capsys.readouterr() + assert "Wald" in captured.out + assert "p-value" in captured.out.lower() or "p_value" in captured.out.lower() + + def test_get_estimates_filter_variable(self, fitted_wt): + wt, model, X, Y = fitted_wt + est = wt.get_estimates(variable="feature_0") + assert len(est) == model.n_components + + def test_get_estimates_filter_class(self, fitted_wt): + wt, model, X, Y = fitted_wt + est = wt.get_estimates(class_no=0) 
+ assert len(est) == Y.shape[1] + + def test_get_pairwise_filter_variable(self, fitted_wt): + wt, model, *_ = fitted_wt + pw = wt.get_pairwise(variable="feature_0") + K = model.n_components + assert len(pw) == K * (K - 1) // 2 + + def test_get_pairwise_filter_classes(self, fitted_wt): + wt, *_ = fitted_wt + pw = wt.get_pairwise(classes=(0, 1)) + # One row per variable + n_vars = len(wt.omnibus_) + assert len(pw) == n_vars + + def test_get_omnibus_filter_variable(self, fitted_wt): + wt, *_ = fitted_wt + omn = wt.get_omnibus(variable="feature_0") + assert len(omn) == 1 + + +# --------------------------------------------------------------------------- +# Integration tests – binary outcome +# --------------------------------------------------------------------------- + + +class TestBinaryOutcome: + @pytest.fixture(scope="class") + def fitted_wt(self): + model, X, Yb = make_binary_model() + wt = WaldTest3Step(model) + wt.fit_bootstrap(X, Yb, n_repetitions=N_BOOT, progress_bar=False, random_state=SEED) + return wt, model, X, Yb + + def test_estimates_exist(self, fitted_wt): + wt, *_ = fitted_wt + assert wt.estimates_ is not None + assert len(wt.estimates_) > 0 + + def test_pairwise_exists(self, fitted_wt): + wt, *_ = fitted_wt + assert wt.pairwise_ is not None + assert len(wt.pairwise_) > 0 + + def test_probabilities_in_01(self, fitted_wt): + """Binary outcome estimates must be probabilities in [0, 1].""" + wt, *_ = fitted_wt + est = wt.estimates_.reset_index() + assert (est["estimate"] >= 0).all() + assert (est["estimate"] <= 1).all() + + def test_binary_omnibus_df(self, fitted_wt): + wt, model, *_ = fitted_wt + K = model.n_components + assert (wt.omnibus_["df"] == K - 1).all() + + +# --------------------------------------------------------------------------- +# Integration tests – soft assignment +# --------------------------------------------------------------------------- + + +class TestSoftAssignment: + @pytest.fixture(scope="class") + def fitted_wt(self): + model, 
X, Y = make_soft_model() + wt = WaldTest3Step(model) + wt.fit_bootstrap(X, Y, n_repetitions=N_BOOT, progress_bar=False, random_state=SEED) + return wt, model, X, Y + + def test_fits_successfully(self, fitted_wt): + wt, *_ = fitted_wt + assert wt._is_fitted + + def test_se_positive(self, fitted_wt): + wt, *_ = fitted_wt + assert (wt.estimates_["se"] > 0).all() + + +# --------------------------------------------------------------------------- +# Integration tests – multiple-comparison corrections +# --------------------------------------------------------------------------- + + +class TestCorrections: + @pytest.fixture(scope="class") + def model_and_data(self): + return make_continuous_model() + + def test_no_correction(self, model_and_data): + model, X, Y = model_and_data + wt = WaldTest3Step(model) + wt.fit_bootstrap(X, Y, n_repetitions=N_BOOT, progress_bar=False, + random_state=SEED, correction=None) + assert "correction" in wt.pairwise_.columns + # With no correction, p_adj == p_value + np.testing.assert_allclose( + wt.pairwise_["p_adj"].values, + wt.pairwise_["p_value"].values, + rtol=1e-6, + ) + + def test_bonferroni(self, model_and_data): + model, X, Y = model_and_data + wt = WaldTest3Step(model) + wt.fit_bootstrap(X, Y, n_repetitions=N_BOOT, progress_bar=False, + random_state=SEED, correction="Bonferroni") + # Adjusted p ≥ raw p + assert (wt.pairwise_["p_adj"] >= wt.pairwise_["p_value"] - 1e-10).all() + assert (wt.pairwise_["p_adj"] <= 1.0 + 1e-10).all() + + def test_bh_fdr(self, model_and_data): + model, X, Y = model_and_data + wt = WaldTest3Step(model) + wt.fit_bootstrap(X, Y, n_repetitions=N_BOOT, progress_bar=False, + random_state=SEED, correction="BH") + assert (wt.pairwise_["p_adj"] >= wt.pairwise_["p_value"] - 1e-10).all() + assert (wt.pairwise_["p_adj"] <= 1.0 + 1e-10).all() + + def test_unknown_correction_warns(self, model_and_data): + model, X, Y = model_and_data + wt = WaldTest3Step(model) + with warnings.catch_warnings(record=True) as w: + 
warnings.simplefilter("always") + wt.fit_bootstrap(X, Y, n_repetitions=N_BOOT, progress_bar=False, + random_state=SEED, correction="unknown_method") + assert any("correction" in str(warning.message).lower() or + "unknown" in str(warning.message).lower() + for warning in w) + + +# --------------------------------------------------------------------------- +# Integration tests – convenience function +# --------------------------------------------------------------------------- + + +class TestStatisticalSanity: + def test_wald_z_matches_chi2(self): + """χ²(1) = z² must hold exactly.""" + model, X, Y = make_continuous_model() + wt = WaldTest3Step(model) + wt.fit_bootstrap(X, Y, n_repetitions=N_BOOT, progress_bar=False, + random_state=SEED) + df = wt.pairwise_.reset_index() + np.testing.assert_allclose(df["chi2"], df["z"] ** 2, rtol=1e-6) + + def test_omnibus_chi2_larger_than_any_pairwise(self): + """ + The omnibus Wald statistic should generally be ≥ any individual + pairwise Wald statistic (not a strict mathematical guarantee, but + holds when classes are clearly separated). + """ + model, X, Y = make_continuous_model(n_samples=600) + wt = WaldTest3Step(model) + wt.fit_bootstrap(X, Y, n_repetitions=N_BOOT, progress_bar=False, + random_state=SEED) + omn = wt.omnibus_.reset_index() + pw = wt.pairwise_.reset_index() + for var in omn["variable"].unique(): + omn_chi2 = omn[omn["variable"] == var]["chi2"].values[0] + max_pair_chi2 = pw[pw["variable"] == var]["chi2"].max() + # Omnibus on K-1 df cannot be compared directly, but with + # well-separated classes it should be large + assert np.isfinite(omn_chi2), f"Omnibus χ² should be finite for {var}" + assert np.isfinite(max_pair_chi2) + + def test_swap_symmetric(self): + """ + Swapping class_j and class_k should give the same χ² and p-value + but negated z and delta. We verify this by manually re-calling the + pairwise logic on reversed pairs. 
+ """ + model, X, Y = make_continuous_model() + wt = WaldTest3Step(model) + wt.fit_bootstrap(X, Y, n_repetitions=N_BOOT, progress_bar=False, + random_state=SEED) + df = wt.pairwise_.reset_index() + # For each row, find the corresponding entry with j and k swapped + # (which would NOT exist in the table since j < k always, so we + # just verify that negating delta gives the same |z|) + np.testing.assert_allclose(np.abs(df["z"]), np.sqrt(df["chi2"]), rtol=1e-6) + + def test_p_value_from_z(self): + """p-value must equal 2 * (1 - Φ(|z|)).""" + from scipy.stats import norm + model, X, Y = make_continuous_model() + wt = WaldTest3Step(model) + wt.fit_bootstrap(X, Y, n_repetitions=N_BOOT, progress_bar=False, + random_state=SEED) + df = wt.pairwise_.reset_index() + expected_p = 2.0 * (1.0 - norm.cdf(np.abs(df["z"].values))) + np.testing.assert_allclose(df["p_value"].values, expected_p, rtol=1e-6) + + def test_omnibus_p_from_chi2(self): + """Omnibus p must equal 1 - χ²_CDF(W, df=K-1).""" + from scipy.stats import chi2 + model, X, Y = make_continuous_model() + K = model.n_components + wt = WaldTest3Step(model) + wt.fit_bootstrap(X, Y, n_repetitions=N_BOOT, progress_bar=False, + random_state=SEED) + omn = wt.omnibus_.reset_index() + expected_p = 1.0 - chi2.cdf(omn["chi2"].values, df=K - 1) + np.testing.assert_allclose(omn["p_value"].values, expected_p, rtol=1e-6) + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + print("Running stepmix_inference test suite...\n") + + # Collect all test classes + test_classes = [ + TestHelpers, + TestConstruction, + TestContinuousOutcome, + TestBinaryOutcome, + TestSoftAssignment, + TestCorrections, + TestConvenienceFunction, + TestStatisticalSanity, + ] + + passed = 0 + failed = 0 + errors = [] + + for cls in test_classes: + instance = cls() + + # Resolve class-scoped fixtures manually + 
class_fixtures = {} + for name in dir(cls): + attr = getattr(cls, name, None) + if callable(attr) and hasattr(attr, "pytestmark"): + pass + # For scope="class" fixtures we call them directly + fixture_cache = {} + + for name in sorted(dir(cls)): + if not name.startswith("test_"): + continue + method = getattr(instance, name) + + # Detect if the test requires a class-scoped fixture + import inspect + sig = inspect.signature(method) + params = list(sig.parameters.keys()) + # Remove 'self'; check for fixture args + fixture_args = [p for p in params if p not in ("self",)] + + try: + if fixture_args: + # Build fixture if not yet built + fixture_name = fixture_args[0] + if fixture_name not in fixture_cache: + fixture_method = getattr(cls, fixture_name, None) + if fixture_method is not None: + # It's a pytest fixture – call it + gen = fixture_method(instance) + if hasattr(gen, "__next__"): + fixture_cache[fixture_name] = next(gen) + else: + fixture_cache[fixture_name] = gen() + else: + # e.g. capsys – skip + print(f" SKIP {cls.__name__}::{name} (needs pytest fixture '{fixture_name}')") + continue + method(fixture_cache[fixture_name]) + else: + method() + print(f" PASS {cls.__name__}::{name}") + passed += 1 + except (AssertionError, Exception) as e: + print(f" FAIL {cls.__name__}::{name} → {e}") + failed += 1 + errors.append((cls.__name__, name, str(e))) + + print(f"\n{'='*60}") + print(f"Results: {passed} passed, {failed} failed") + if errors: + print("\nFailed tests:") + for cls_name, test_name, msg in errors: + print(f" {cls_name}::{test_name}: {msg}") + print("=" * 60)