From 01d20d7b3989aeeccbace167d201155d5875a15a Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Mon, 13 Oct 2025 17:13:18 -0500 Subject: [PATCH 1/4] Implement calc_mgas_singlegal2 to demonstrate alternative algorithm for calculating gas mass history --- diffstar/mgas_model.py | 19 ++++++++++++++++++- diffstar/tests/test_mgas_model.py | 25 ++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/diffstar/mgas_model.py b/diffstar/mgas_model.py index 67b3e29..46279ec 100644 --- a/diffstar/mgas_model.py +++ b/diffstar/mgas_model.py @@ -1,6 +1,6 @@ from collections import namedtuple -from diffmah.diffmah_kernels import mah_halopop, mah_singlehalo +from diffmah.diffmah_kernels import _log_mah_kern, mah_halopop, mah_singlehalo from jax import jit as jjit from jax import vmap @@ -77,6 +77,23 @@ def calc_mgas_singlegal(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): return GalHistory(sfh, smh, dmgas_dt, mgas) +@jjit +def calc_mgas_singlegal2(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): + log_mah = _log_mah_kern(mah_params, tarr, lgt0) + mgas_inst = fb * 10**log_mah + + ms_params, q_params = sfh_params[:4], sfh_params[4:] + sfh = _sfh_singlegal_kern(tarr, mah_params, ms_params, q_params, lgt0, fb) + smh = cumulative_mstar_formed(tarr, sfh) + + mgas = mgas_inst - smh + + dt = _jax_get_dt_array(tarr) + dmgas_dt = _jax_get_dt_array(mgas) / dt / 1e9 + + return GalHistory(sfh, smh, dmgas_dt, mgas) + + @jjit def calc_mgas_galpop(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): """Calculate the Diffstar SFH and Mgas for a single galaxy diff --git a/diffstar/tests/test_mgas_model.py b/diffstar/tests/test_mgas_model.py index a2f69e2..96f9f66 100644 --- a/diffstar/tests/test_mgas_model.py +++ b/diffstar/tests/test_mgas_model.py @@ -16,7 +16,7 @@ QUParams, get_bounded_diffstar_params, ) -from ..mgas_model import calc_mgas_galpop, calc_mgas_singlegal +from ..mgas_model import calc_mgas_galpop, calc_mgas_singlegal, calc_mgas_singlegal2 def _get_all_default_params(): @@ -85,3 +85,26 @@ def test_calc_mgas_galpop_evaluates(): assert np.all(gal_history.sfh > 0) assert np.all(gal_history.smh > 0) assert np.all(gal_history.mgash > 0) + + +def test_calc_mgas_singlegal2(): + lgt0, mah_params, u_ms_params_init, u_q_params_init = _get_all_default_u_params() + u_sfh_init = np.array((*u_ms_params_init, *u_q_params_init)) + + n_t = 2_000 + tarr = np.linspace(0.1, 10**lgt0, n_t) + + ran_key = jran.PRNGKey(0) + ntests = 20 + ran_keys = jran.split(ran_key, ntests) + for test_key in ran_keys: + + u_ms_params = jran.normal(test_key, shape=(8,)) + np.array(u_sfh_init) + sfh_u_params = DiffstarUParams(*u_ms_params) + sfh_params = get_bounded_diffstar_params(sfh_u_params) + res = calc_mgas_singlegal(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB) + res2 = calc_mgas_singlegal2(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB) + + t_min_compare = 0.5 # Gyr + msk_t_min = tarr > t_min_compare + assert np.allclose(res.mgash[msk_t_min], res2.mgash[msk_t_min], rtol=0.02) From d8c294bd48ff21789a421fdcc3f8167564eb8c70 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Mon, 13 Oct 2025 17:58:22 -0500 Subject: [PATCH 2/4] Add calc_mgas_singlegal3 function that computes dmgas_dt via the grad --- diffstar/mgas_model.py | 60 ++++++++++++++++++++++++++++++- diffstar/tests/test_mgas_model.py | 39 ++++++++++++++++++-- 2 files changed, 96 insertions(+), 3 deletions(-) diff --git a/diffstar/mgas_model.py b/diffstar/mgas_model.py index 46279ec..af5d4d5 100644 --- a/diffstar/mgas_model.py +++ b/diffstar/mgas_model.py @@ -1,10 +1,12 @@ from collections import namedtuple from diffmah.diffmah_kernels import _log_mah_kern, mah_halopop, mah_singlehalo +from jax import grad from jax import jit as jjit +from jax import numpy as jnp from jax import vmap -from .defaults import FB, LGT0 +from .defaults import FB, LGT0, T_TABLE_MIN from .kernels.history_kernel_builders import _sfh_galpop_kern, _sfh_singlegal_kern from .utils import _jax_get_dt_array, cumulative_mstar_formed @@ -13,6 +15,8 @@ GalHistory = namedtuple("GalHistory", ("sfh", "smh", "dmgash", "mgash")) +N_INT_STEPS = 50 + @jjit def calc_mgas_singlegal(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): @@ -94,6 +98,60 @@ def calc_mgas_singlegal2(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): return GalHistory(sfh, smh, dmgas_dt, mgas) +@jjit +def _calc_mgas_kern(sfh_params, mah_params, t_obs, lgt0, fb): + + t_table = jnp.linspace(T_TABLE_MIN, t_obs, N_INT_STEPS) + log_mah_table = _log_mah_kern(mah_params, t_table, lgt0) + mgas_inst_table = fb * 10**log_mah_table + + ms_params, q_params = sfh_params[:4], sfh_params[4:] + sfh_table = _sfh_singlegal_kern(t_table, mah_params, ms_params, q_params, lgt0, fb) + smh_table = cumulative_mstar_formed(t_table, sfh_table) + + mgas_table = mgas_inst_table - smh_table + + mgas_obs = mgas_table[-1] + sfr_obs = sfh_table[-1] + mstar_obs = smh_table[-1] + + return mgas_obs, mstar_obs, sfr_obs + + +@jjit +def _calc_dmgas_dt_kern_wrapper(sfh_params, mah_params, t_obs, lgt0, fb): + mgas_obs = _calc_mgas_kern(sfh_params, mah_params, t_obs, lgt0, fb)[0] + return mgas_obs + + +_calc_dmgas_dt_kern_nonorm = jjit(grad(_calc_dmgas_dt_kern_wrapper, argnums=2)) + + +@jjit +def _calc_mgas_and_dmgas_dt_kern(sfh_params, mah_params, t_obs, lgt0, fb): + mgas_obs, mstar_obs, sfr_obs = _calc_mgas_kern( + sfh_params, mah_params, t_obs, lgt0, fb + ) + dmgas_dt = _calc_dmgas_dt_kern_nonorm(sfh_params, mah_params, t_obs, lgt0, fb) / 1e9 + return sfr_obs, mstar_obs, dmgas_dt, mgas_obs + + +@jjit +def _calc_dmgas_dt_kern(sfh_params, mah_params, t_obs, lgt0, fb): + return _calc_dmgas_dt_kern_nonorm(sfh_params, mah_params, t_obs, lgt0, fb) / 1e9 + + +_TARR = (None, None, 0, None, None) +_calc_mgas_and_dmgas_dt_vmap = jjit(vmap(_calc_mgas_and_dmgas_dt_kern, in_axes=_TARR)) + + +@jjit +def calc_mgas_singlegal3(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): + _res = _calc_mgas_and_dmgas_dt_vmap(sfh_params, mah_params, tarr, lgt0, fb) + sfh, smh, dmgas_dt, mgash = _res + return GalHistory(sfh, smh, dmgas_dt, mgash) + + @jjit def calc_mgas_galpop(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): """Calculate the Diffstar SFH and Mgas for a single galaxy diff --git a/diffstar/tests/test_mgas_model.py b/diffstar/tests/test_mgas_model.py index 96f9f66..88acb51 100644 --- a/diffstar/tests/test_mgas_model.py +++ b/diffstar/tests/test_mgas_model.py @@ -16,7 +16,12 @@ QUParams, get_bounded_diffstar_params, ) -from ..mgas_model import calc_mgas_galpop, calc_mgas_singlegal, calc_mgas_singlegal2 +from ..mgas_model import ( + calc_mgas_galpop, + calc_mgas_singlegal, + calc_mgas_singlegal2, + calc_mgas_singlegal3, +) def _get_all_default_params(): @@ -106,5 +111,35 @@ def test_calc_mgas_singlegal2(): res2 = calc_mgas_singlegal2(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB) t_min_compare = 0.5 # Gyr + mgas_min_compare = 1e-2 * res.mgash.max() # Ignore very tiny values msk_t_min = tarr > t_min_compare - assert np.allclose(res.mgash[msk_t_min], res2.mgash[msk_t_min], rtol=0.02) + msk_mgas_below_peak = res.mgash > mgas_min_compare + msk_compare = msk_t_min & msk_mgas_below_peak + assert np.allclose(res.mgash[msk_compare], res2.mgash[msk_compare], rtol=0.02) + + +def test_calc_mgas_singlegal3(): + lgt0, mah_params, u_ms_params_init, u_q_params_init = _get_all_default_u_params() + u_sfh_init = np.array((*u_ms_params_init, *u_q_params_init)) + + n_t = 2_000 + tarr = np.linspace(0.2, 10**lgt0, n_t) + + ran_key = jran.PRNGKey(0) + ntests = 20 + ran_keys = jran.split(ran_key, ntests) + for test_key in ran_keys: + + u_ms_params = jran.normal(test_key, shape=(8,)) + np.array(u_sfh_init) + sfh_u_params = DiffstarUParams(*u_ms_params) + sfh_params = get_bounded_diffstar_params(sfh_u_params) + res = calc_mgas_singlegal(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB) + res2 = calc_mgas_singlegal3(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB) + + t_min_compare = 0.5 # Gyr + mgas_min_compare = 1e-2 * res.mgash.max() # Ignore very tiny values + msk_t_min = tarr > t_min_compare + msk_mgas_below_peak = res.mgash > mgas_min_compare + msk_compare = msk_t_min & msk_mgas_below_peak + + assert np.allclose(res.mgash[msk_compare], res2.mgash[msk_compare], rtol=0.1) From dd4863fa1c35ceea67c28612e8fcaa4520f71392 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Mon, 13 Oct 2025 18:14:15 -0500 Subject: [PATCH 3/4] Add wrapper function calc_mgas_galpop3 --- diffstar/mgas_model.py | 11 +++++++++++ diffstar/tests/test_mgas_model.py | 20 ++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/diffstar/mgas_model.py b/diffstar/mgas_model.py index af5d4d5..066961d 100644 --- a/diffstar/mgas_model.py +++ b/diffstar/mgas_model.py @@ -152,6 +152,17 @@ def calc_mgas_singlegal3(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): return GalHistory(sfh, smh, dmgas_dt, mgash) +_GPOP = (0, 0, None, None, None) +_calc_mgas_and_dmgas_dt_galpop = jjit(vmap(_calc_mgas_and_dmgas_dt_vmap, in_axes=_GPOP)) + + +@jjit +def calc_mgas_galpop3(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): + _res = _calc_mgas_and_dmgas_dt_galpop(sfh_params, mah_params, tarr, lgt0, fb) + sfh, smh, dmgas_dt, mgash = _res + return GalHistory(sfh, smh, dmgas_dt, mgash) + + @jjit def calc_mgas_galpop(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): """Calculate the Diffstar SFH and Mgas for a single galaxy diff --git a/diffstar/tests/test_mgas_model.py b/diffstar/tests/test_mgas_model.py index 88acb51..caf77de 100644 --- a/diffstar/tests/test_mgas_model.py +++ b/diffstar/tests/test_mgas_model.py @@ -18,6 +18,7 @@ ) from ..mgas_model import ( calc_mgas_galpop, + calc_mgas_galpop3, calc_mgas_singlegal, calc_mgas_singlegal2, calc_mgas_singlegal3, @@ -143,3 +144,22 @@ def test_calc_mgas_singlegal3(): msk_compare = msk_t_min & msk_mgas_below_peak assert np.allclose(res.mgash[msk_compare], res2.mgash[msk_compare], rtol=0.1) + + +def test_calc_mgas_galpop3_evaluates(): + n_gals = 50 + ZZ = np.zeros(n_gals) + ms_params = DEFAULT_MS_PARAMS._make([ZZ + x for x in DEFAULT_MS_PARAMS]) + q_params = DEFAULT_Q_PARAMS._make([ZZ + x for x in DEFAULT_Q_PARAMS]) + sfh_params = DiffstarUParams(*ms_params, *q_params) + mah_params = DEFAULT_MAH_PARAMS._make([ZZ + x for x in DEFAULT_MAH_PARAMS]) + tarr = np.linspace(0.1, 13.8, 30) + gal_history3 = calc_mgas_galpop3(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB) + + assert gal_history3._fields == ("sfh", "smh", "dmgash", "mgash") + for x in gal_history3: + assert np.all(np.isfinite(x)) + + assert np.all(gal_history3.sfh > 0) + assert np.all(gal_history3.smh > 0) + assert np.all(gal_history3.mgash > 0) From 245e78dd2aef1e7f155ae8fa271dff5005e23103 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Tue, 14 Oct 2025 11:31:59 -0500 Subject: [PATCH 4/4] Set N_INT_STEPS to 20 instead of 50 --- diffstar/mgas_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffstar/mgas_model.py b/diffstar/mgas_model.py index 066961d..9aa46bc 100644 --- a/diffstar/mgas_model.py +++ b/diffstar/mgas_model.py @@ -15,7 +15,7 @@ GalHistory = namedtuple("GalHistory", ("sfh", "smh", "dmgash", "mgash")) -N_INT_STEPS = 50 +N_INT_STEPS = 20 @jjit