diff --git a/dsps/utils.py b/dsps/utils.py index 7293b08..640dd79 100644 --- a/dsps/utils.py +++ b/dsps/utils.py @@ -1,5 +1,4 @@ -""" -""" +""" """ from jax import jit as jjit from jax import lax, nn @@ -291,10 +290,7 @@ def trapz(xarr, yarr): result : float """ - res_init = xarr[0], yarr[0], 0.0 - scan_data = xarr, yarr - cumtrapz = scan(_cumtrapz_scan_func, res_init, scan_data)[1] - return cumtrapz[-1] + return jnp.trapezoid(yarr, xarr) @jjit