Whatever code ends up calling `jax.interpreters.xla.pytype_aval_mappings` breaks with jax >= 0.7, where it was removed. (I ran into it while trying to import `linear_gaussian_ssm`.) It would be nice to update `pyproject.toml` to pin jax to a version before that, or better yet, to update the code so it's compatible with later jax versions.
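As a rough illustration of the "compatible with later jax versions" option, the affected code could feature-detect the attribute instead of accessing it unconditionally. The sketch below is only that: `legacy_pytype_aval_mappings` is a hypothetical helper, not part of jax or of this project, and it does not pick a replacement API for newer jax. It just returns `None` so the caller can branch rather than crash on import.

```python
import importlib
from types import ModuleType
from typing import Any, Optional


def legacy_pytype_aval_mappings() -> Optional[Any]:
    """Return jax.interpreters.xla.pytype_aval_mappings if it still exists.

    Hypothetical helper (not a jax API). Returns None when jax is not
    installed, or when the attribute is gone (reported for jax >= 0.7),
    so callers can take a fallback path instead of failing at import time.
    """
    try:
        xla: ModuleType = importlib.import_module("jax.interpreters.xla")
    except ImportError:
        # jax (or this submodule) is not importable at all.
        return None
    # getattr with a default avoids the AttributeError newer jax would raise.
    return getattr(xla, "pytype_aval_mappings", None)


if __name__ == "__main__":
    mappings = legacy_pytype_aval_mappings()
    if mappings is None:
        print("legacy attribute unavailable; use a newer-jax code path")
    else:
        print("legacy attribute available")
```

Feature detection keeps working regardless of which exact release removed the symbol, whereas a `pyproject.toml` pin only avoids the problem by holding jax back.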