Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 88 additions & 2 deletions diffstar/mgas_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections import namedtuple

from diffmah.diffmah_kernels import mah_halopop, mah_singlehalo
from diffmah.diffmah_kernels import _log_mah_kern, mah_halopop, mah_singlehalo
from jax import grad
from jax import jit as jjit
from jax import numpy as jnp
from jax import vmap

from .defaults import FB, LGT0
from .defaults import FB, LGT0, T_TABLE_MIN
from .kernels.history_kernel_builders import _sfh_galpop_kern, _sfh_singlegal_kern
from .utils import _jax_get_dt_array, cumulative_mstar_formed

Expand All @@ -13,6 +15,8 @@

GalHistory = namedtuple("GalHistory", ("sfh", "smh", "dmgash", "mgash"))

N_INT_STEPS = 20


@jjit
def calc_mgas_singlegal(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB):
Expand Down Expand Up @@ -77,6 +81,88 @@ def calc_mgas_singlegal(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB):
return GalHistory(sfh, smh, dmgas_dt, mgas)


@jjit
def calc_mgas_singlegal2(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB):
log_mah = _log_mah_kern(mah_params, tarr, lgt0)
mgas_inst = fb * 10**log_mah

ms_params, q_params = sfh_params[:4], sfh_params[4:]
sfh = _sfh_singlegal_kern(tarr, mah_params, ms_params, q_params, lgt0, fb)
smh = cumulative_mstar_formed(tarr, sfh)

mgas = mgas_inst - smh

dt = _jax_get_dt_array(tarr)
dmgas_dt = _jax_get_dt_array(mgas) / dt / 1e9

return GalHistory(sfh, smh, dmgas_dt, mgas)


@jjit
def _calc_mgas_kern(sfh_params, mah_params, t_obs, lgt0, fb):

t_table = jnp.linspace(T_TABLE_MIN, t_obs, N_INT_STEPS)
log_mah_table = _log_mah_kern(mah_params, t_table, lgt0)
mgas_inst_table = fb * 10**log_mah_table

ms_params, q_params = sfh_params[:4], sfh_params[4:]
sfh_table = _sfh_singlegal_kern(t_table, mah_params, ms_params, q_params, lgt0, fb)
smh_table = cumulative_mstar_formed(t_table, sfh_table)

mgas_table = mgas_inst_table - smh_table

mgas_obs = mgas_table[-1]
sfr_obs = sfh_table[-1]
mstar_obs = smh_table[-1]

return mgas_obs, mstar_obs, sfr_obs


@jjit
def _calc_dmgas_dt_kern_wrapper(sfh_params, mah_params, t_obs, lgt0, fb):
mgas_obs = _calc_mgas_kern(sfh_params, mah_params, t_obs, lgt0, fb)[0]
return mgas_obs


_calc_dmgas_dt_kern_nonorm = jjit(grad(_calc_dmgas_dt_kern_wrapper, argnums=2))


@jjit
def _calc_mgas_and_dmgas_dt_kern(sfh_params, mah_params, t_obs, lgt0, fb):
mgas_obs, mstar_obs, sfr_obs = _calc_mgas_kern(
sfh_params, mah_params, t_obs, lgt0, fb
)
dmgas_dt = _calc_dmgas_dt_kern_nonorm(sfh_params, mah_params, t_obs, lgt0, fb) / 1e9
return sfr_obs, mstar_obs, dmgas_dt, mgas_obs


@jjit
def _calc_dmgas_dt_kern(sfh_params, mah_params, t_obs, lgt0, fb):
return _calc_dmgas_dt_kern_nonorm(sfh_params, mah_params, t_obs, lgt0, fb) / 1e9


_TARR = (None, None, 0, None, None)
_calc_mgas_and_dmgas_dt_vmap = jjit(vmap(_calc_mgas_and_dmgas_dt_kern, in_axes=_TARR))


@jjit
def calc_mgas_singlegal3(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB):
_res = _calc_mgas_and_dmgas_dt_vmap(sfh_params, mah_params, tarr, lgt0, fb)
sfh, smh, dmgas_dt, mgash = _res
return GalHistory(sfh, smh, dmgas_dt, mgash)


_GPOP = (0, 0, None, None, None)
_calc_mgas_and_dmgas_dt_galpop = jjit(vmap(_calc_mgas_and_dmgas_dt_vmap, in_axes=_GPOP))


@jjit
def calc_mgas_galpop3(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB):
_res = _calc_mgas_and_dmgas_dt_galpop(sfh_params, mah_params, tarr, lgt0, fb)
sfh, smh, dmgas_dt, mgash = _res
return GalHistory(sfh, smh, dmgas_dt, mgash)


@jjit
def calc_mgas_galpop(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB):
"""Calculate the Diffstar SFH and Mgas for a single galaxy
Expand Down
80 changes: 79 additions & 1 deletion diffstar/tests/test_mgas_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
QUParams,
get_bounded_diffstar_params,
)
from ..mgas_model import calc_mgas_galpop, calc_mgas_singlegal
from ..mgas_model import (
calc_mgas_galpop,
calc_mgas_galpop3,
calc_mgas_singlegal,
calc_mgas_singlegal2,
calc_mgas_singlegal3,
)


def _get_all_default_params():
Expand Down Expand Up @@ -85,3 +91,75 @@ def test_calc_mgas_galpop_evaluates():
assert np.all(gal_history.sfh > 0)
assert np.all(gal_history.smh > 0)
assert np.all(gal_history.mgash > 0)


def test_calc_mgas_singlegal2():
lgt0, mah_params, u_ms_params_init, u_q_params_init = _get_all_default_u_params()
u_sfh_init = np.array((*u_ms_params_init, *u_q_params_init))

n_t = 2_000
tarr = np.linspace(0.1, 10**lgt0, n_t)

ran_key = jran.PRNGKey(0)
ntests = 20
ran_keys = jran.split(ran_key, ntests)
for test_key in ran_keys:

u_ms_params = jran.normal(test_key, shape=(8,)) + np.array(u_sfh_init)
sfh_u_params = DiffstarUParams(*u_ms_params)
sfh_params = get_bounded_diffstar_params(sfh_u_params)
res = calc_mgas_singlegal(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB)
res2 = calc_mgas_singlegal2(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB)

t_min_compare = 0.5 # Gyr
mgas_min_compare = 1e-2 * res.mgash.max() # Ignore very tiny values
msk_t_min = tarr > t_min_compare
msk_mgas_below_peak = res.mgash > mgas_min_compare
msk_compare = msk_t_min & msk_mgas_below_peak
assert np.allclose(res.mgash[msk_compare], res2.mgash[msk_compare], rtol=0.02)


def test_calc_mgas_singlegal3():
lgt0, mah_params, u_ms_params_init, u_q_params_init = _get_all_default_u_params()
u_sfh_init = np.array((*u_ms_params_init, *u_q_params_init))

n_t = 2_000
tarr = np.linspace(0.2, 10**lgt0, n_t)

ran_key = jran.PRNGKey(0)
ntests = 20
ran_keys = jran.split(ran_key, ntests)
for test_key in ran_keys:

u_ms_params = jran.normal(test_key, shape=(8,)) + np.array(u_sfh_init)
sfh_u_params = DiffstarUParams(*u_ms_params)
sfh_params = get_bounded_diffstar_params(sfh_u_params)
res = calc_mgas_singlegal(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB)
res2 = calc_mgas_singlegal3(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB)

t_min_compare = 0.5 # Gyr
mgas_min_compare = 1e-2 * res.mgash.max() # Ignore very tiny values
msk_t_min = tarr > t_min_compare
msk_mgas_below_peak = res.mgash > mgas_min_compare
msk_compare = msk_t_min & msk_mgas_below_peak

assert np.allclose(res.mgash[msk_compare], res2.mgash[msk_compare], rtol=0.1)


def test_calc_mgas_galpop3_evaluates():
n_gals = 50
ZZ = np.zeros(n_gals)
ms_params = DEFAULT_MS_PARAMS._make([ZZ + x for x in DEFAULT_MS_PARAMS])
q_params = DEFAULT_Q_PARAMS._make([ZZ + x for x in DEFAULT_Q_PARAMS])
sfh_params = DiffstarUParams(*ms_params, *q_params)
mah_params = DEFAULT_MAH_PARAMS._make([ZZ + x for x in DEFAULT_MAH_PARAMS])
tarr = np.linspace(0.1, 13.8, 30)
gal_history3 = calc_mgas_galpop3(sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB)

assert gal_history3._fields == ("sfh", "smh", "dmgash", "mgash")
for x in gal_history3:
assert np.all(np.isfinite(x))

assert np.all(gal_history3.sfh > 0)
assert np.all(gal_history3.smh > 0)
assert np.all(gal_history3.mgash > 0)
Loading