Incompatible with jax>=0.7 #448

@kjohnsen

Description

Whatever code ends up calling `jax.interpreters.xla.pytype_aval_mappings` breaks with jax >= 0.7, where that attribute was removed. (I ran into it when trying to import `linear_gaussian_ssm`.) It would be nice to update `pyproject.toml` to pin jax below 0.7, or better yet, to update the code so it's compatible with later jax versions.
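Until the code is ported, a clearer failure than an `AttributeError` deep inside an import might help. Below is a minimal sketch of an import-time guard, assuming the `<0.7` upper bound described above. The names `MAX_EXCLUSIVE`, `parse_major_minor`, `is_supported` and `check_jax_version` are hypothetical and are not part of jax or this project. The sketch uses only the standard library (`importlib.metadata`, `re`).

```python
import re
from importlib import metadata

# Hypothetical bound: jax 0.7 removed jax.interpreters.xla.pytype_aval_mappings.
MAX_EXCLUSIVE = (0, 7)


def parse_major_minor(version: str) -> tuple[int, int]:
    """Return (major, minor) from strings like '0.6.2' or '0.7.0rc1'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def is_supported(version: str, max_exclusive: tuple[int, int] = MAX_EXCLUSIVE) -> bool:
    """True if the (major, minor) version is strictly below max_exclusive."""
    return parse_major_minor(version) < max_exclusive


def check_jax_version() -> None:
    """Raise a readable ImportError when the installed jax is too new."""
    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        return  # jax is not installed; nothing to check
    if not is_supported(installed):
        raise ImportError(
            f"jax {installed} is installed, but this package currently needs jax<0.7 "
            "(jax.interpreters.xla.pytype_aval_mappings was removed in 0.7)."
        )
```

The matching pin would be a `jax<0.7` specifier in the `dependencies` list of `pyproject.toml`. Both the pin and the guard are stopgaps; the durable fix is to stop relying on `pytype_aval_mappings` from `jax.interpreters.xla`.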
