diff --git a/diffstar/mgas_model.py b/diffstar/mgas_model.py index 67b3e29..9aa46bc 100644 --- a/diffstar/mgas_model.py +++ b/diffstar/mgas_model.py @@ -1,10 +1,12 @@ from collections import namedtuple -from diffmah.diffmah_kernels import mah_halopop, mah_singlehalo +from diffmah.diffmah_kernels import _log_mah_kern, mah_halopop, mah_singlehalo +from jax import grad from jax import jit as jjit +from jax import numpy as jnp from jax import vmap -from .defaults import FB, LGT0 +from .defaults import FB, LGT0, T_TABLE_MIN from .kernels.history_kernel_builders import _sfh_galpop_kern, _sfh_singlegal_kern from .utils import _jax_get_dt_array, cumulative_mstar_formed @@ -13,6 +15,8 @@ GalHistory = namedtuple("GalHistory", ("sfh", "smh", "dmgash", "mgash")) +N_INT_STEPS = 20 + @jjit def calc_mgas_singlegal(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): @@ -77,6 +81,88 @@ def calc_mgas_singlegal(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): return GalHistory(sfh, smh, dmgas_dt, mgas) +@jjit +def calc_mgas_singlegal2(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): + log_mah = _log_mah_kern(mah_params, tarr, lgt0) + mgas_inst = fb * 10**log_mah + + ms_params, q_params = sfh_params[:4], sfh_params[4:] + sfh = _sfh_singlegal_kern(tarr, mah_params, ms_params, q_params, lgt0, fb) + smh = cumulative_mstar_formed(tarr, sfh) + + mgas = mgas_inst - smh + + dt = _jax_get_dt_array(tarr) + dmgas_dt = _jax_get_dt_array(mgas) / dt / 1e9 + + return GalHistory(sfh, smh, dmgas_dt, mgas) + + +@jjit +def _calc_mgas_kern(sfh_params, mah_params, t_obs, lgt0, fb): + + t_table = jnp.linspace(T_TABLE_MIN, t_obs, N_INT_STEPS) + log_mah_table = _log_mah_kern(mah_params, t_table, lgt0) + mgas_inst_table = fb * 10**log_mah_table + + ms_params, q_params = sfh_params[:4], sfh_params[4:] + sfh_table = _sfh_singlegal_kern(t_table, mah_params, ms_params, q_params, lgt0, fb) + smh_table = cumulative_mstar_formed(t_table, sfh_table) + + mgas_table = mgas_inst_table - smh_table + + mgas_obs = mgas_table[-1] + sfr_obs = sfh_table[-1] + mstar_obs = smh_table[-1] + + return mgas_obs, mstar_obs, sfr_obs + + +@jjit +def _calc_dmgas_dt_kern_wrapper(sfh_params, mah_params, t_obs, lgt0, fb): + mgas_obs = _calc_mgas_kern(sfh_params, mah_params, t_obs, lgt0, fb)[0] + return mgas_obs + + +_calc_dmgas_dt_kern_nonorm = jjit(grad(_calc_dmgas_dt_kern_wrapper, argnums=2)) + + +@jjit +def _calc_mgas_and_dmgas_dt_kern(sfh_params, mah_params, t_obs, lgt0, fb): + mgas_obs, mstar_obs, sfr_obs = _calc_mgas_kern( + sfh_params, mah_params, t_obs, lgt0, fb + ) + dmgas_dt = _calc_dmgas_dt_kern_nonorm(sfh_params, mah_params, t_obs, lgt0, fb) / 1e9 + return sfr_obs, mstar_obs, dmgas_dt, mgas_obs + + +@jjit +def _calc_dmgas_dt_kern(sfh_params, mah_params, t_obs, lgt0, fb): + return _calc_dmgas_dt_kern_nonorm(sfh_params, mah_params, t_obs, lgt0, fb) / 1e9 + + +_TARR = (None, None, 0, None, None) +_calc_mgas_and_dmgas_dt_vmap = jjit(vmap(_calc_mgas_and_dmgas_dt_kern, in_axes=_TARR)) + + +@jjit +def calc_mgas_singlegal3(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): + _res = _calc_mgas_and_dmgas_dt_vmap(sfh_params, mah_params, tarr, lgt0, fb) + sfh, smh, dmgas_dt, mgash = _res + return GalHistory(sfh, smh, dmgas_dt, mgash) + + +_GPOP = (0, 0, None, None, None) +_calc_mgas_and_dmgas_dt_galpop = jjit(vmap(_calc_mgas_and_dmgas_dt_vmap, in_axes=_GPOP)) + + +@jjit +def calc_mgas_galpop3(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): + _res = _calc_mgas_and_dmgas_dt_galpop(sfh_params, mah_params, tarr, lgt0, fb) + sfh, smh, dmgas_dt, mgash = _res + return GalHistory(sfh, smh, dmgas_dt, mgash) + + @jjit def calc_mgas_galpop(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB): """Calculate the Diffstar SFH and Mgas for a single galaxy diff --git a/diffstar/tests/test_mgas_model.py b/diffstar/tests/test_mgas_model.py index a2f69e2..caf77de 100644 --- a/diffstar/tests/test_mgas_model.py +++ b/diffstar/tests/test_mgas_model.py @@ -16,7 +16,13 @@ QUParams, get_bounded_diffstar_params, ) -from ..mgas_model import calc_mgas_galpop, calc_mgas_singlegal +from ..mgas_model import ( + calc_mgas_galpop, + calc_mgas_galpop3, + calc_mgas_singlegal, + calc_mgas_singlegal2, + calc_mgas_singlegal3, +) def _get_all_default_params(): @@ -85,3 +91,75 @@ def test_calc_mgas_galpop_evaluates(): assert np.all(gal_history.sfh > 0) assert np.all(gal_history.smh > 0) assert np.all(gal_history.mgash > 0) + + +def test_calc_mgas_singlegal2(): + lgt0, mah_params, u_ms_params_init, u_q_params_init = _get_all_default_u_params() + u_sfh_init = np.array((*u_ms_params_init, *u_q_params_init)) + + n_t = 2_000 + tarr = np.linspace(0.1, 10**lgt0, n_t) + + ran_key = jran.PRNGKey(0) + ntests = 20 + ran_keys = jran.split(ran_key, ntests) + for test_key in ran_keys: + + u_ms_params = jran.normal(test_key, shape=(8,)) + np.array(u_sfh_init) + sfh_u_params = DiffstarUParams(*u_ms_params) + sfh_params = get_bounded_diffstar_params(sfh_u_params) + res = calc_mgas_singlegal(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB) + res2 = calc_mgas_singlegal2(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB) + + t_min_compare = 0.5 # Gyr + mgas_min_compare = 1e-2 * res.mgash.max() # Ignore very tiny values + msk_t_min = tarr > t_min_compare + msk_mgas_below_peak = res.mgash > mgas_min_compare + msk_compare = msk_t_min & msk_mgas_below_peak + assert np.allclose(res.mgash[msk_compare], res2.mgash[msk_compare], rtol=0.02) + + +def test_calc_mgas_singlegal3(): + lgt0, mah_params, u_ms_params_init, u_q_params_init = _get_all_default_u_params() + u_sfh_init = np.array((*u_ms_params_init, *u_q_params_init)) + + n_t = 2_000 + tarr = np.linspace(0.2, 10**lgt0, n_t) + + ran_key = jran.PRNGKey(0) + ntests = 20 + ran_keys = jran.split(ran_key, ntests) + for test_key in ran_keys: + + u_ms_params = jran.normal(test_key, shape=(8,)) + np.array(u_sfh_init) + sfh_u_params = DiffstarUParams(*u_ms_params) + sfh_params = get_bounded_diffstar_params(sfh_u_params) + res = calc_mgas_singlegal(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB) + res2 = calc_mgas_singlegal3(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB) + + t_min_compare = 0.5 # Gyr + mgas_min_compare = 1e-2 * res.mgash.max() # Ignore very tiny values + msk_t_min = tarr > t_min_compare + msk_mgas_below_peak = res.mgash > mgas_min_compare + msk_compare = msk_t_min & msk_mgas_below_peak + + assert np.allclose(res.mgash[msk_compare], res2.mgash[msk_compare], rtol=0.1) + + +def test_calc_mgas_galpop3_evaluates(): + n_gals = 50 + ZZ = np.zeros(n_gals) + ms_params = DEFAULT_MS_PARAMS._make([ZZ + x for x in DEFAULT_MS_PARAMS]) + q_params = DEFAULT_Q_PARAMS._make([ZZ + x for x in DEFAULT_Q_PARAMS]) + sfh_params = DiffstarUParams(*ms_params, *q_params) + mah_params = DEFAULT_MAH_PARAMS._make([ZZ + x for x in DEFAULT_MAH_PARAMS]) + tarr = np.linspace(0.1, 13.8, 30) + gal_history3 = calc_mgas_galpop3(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB) + + assert gal_history3._fields == ("sfh", "smh", "dmgash", "mgash") + for x in gal_history3: + assert np.all(np.isfinite(x)) + + assert np.all(gal_history3.sfh > 0) + assert np.all(gal_history3.smh > 0) + assert np.all(gal_history3.mgash > 0)