Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions glass/_array_api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Any
import typing

import array_api_compat

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from collections.abc import Callable, Sequence
from types import ModuleType
from typing import Any

import numpy as np

Expand Down
4 changes: 2 additions & 2 deletions glass/_rng.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

from __future__ import annotations

from typing import TYPE_CHECKING
import typing

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from types import ModuleType

from glass._types import DTypeLike, FloatArray, IntArray, UnifiedGenerator
Expand Down
13 changes: 7 additions & 6 deletions glass/_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import TYPE_CHECKING, Any
import typing
from typing import Any

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from collections.abc import Sequence
from typing import ParamSpec, TypeAlias, TypeVar
from typing import TypeAlias

import jaxtyping
import numpy as np
Expand All @@ -13,9 +14,9 @@
import glass.jax
from glass import _rng

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
P = typing.ParamSpec("P")
R = typing.TypeVar("R")
T = typing.TypeVar("T")

AnyArray: TypeAlias = np.typing.NDArray[Any] | jaxtyping.Array | Array
ComplexArray: TypeAlias = np.typing.NDArray[np.complex128] | jaxtyping.Array | Array
Expand Down
4 changes: 2 additions & 2 deletions glass/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from __future__ import annotations

import typing
import warnings
from typing import TYPE_CHECKING

import array_api_compat
import array_api_extra as xpx

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from glass._types import FloatArray


Expand Down
13 changes: 8 additions & 5 deletions glass/arraytools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from __future__ import annotations

from typing import TYPE_CHECKING
import typing

import array_api_compat
import array_api_extra as xpx

from glass._array_api_utils import xp_additions as uxpx

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from types import ModuleType

from glass._types import AnyArray, FloatArray, IntArray
Expand Down Expand Up @@ -175,9 +175,12 @@ def trapezoid_product(
x: FloatArray
x, _ = f
for x_, _ in ff:
x = xpx.union1d( # ty: ignore[invalid-assignment]
x[(x >= x_[0]) & (x <= x_[-1])],
x_[(x_ >= x[0]) & (x_ <= x[-1])],
x = typing.cast(
"FloatArray",
xpx.union1d(
x[(x >= x_[0]) & (x <= x_[-1])],
x_[(x_ >= x[0]) & (x_ <= x[-1])],
),
)
y = uxpx.interp(x, *f)
for f_ in ff:
Expand Down
4 changes: 2 additions & 2 deletions glass/cosmology.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Module for cosmology.api utilities."""

from typing import Protocol
import typing

import cosmology.api

Expand All @@ -18,6 +18,6 @@ class Cosmology(
cosmology.api.HasOmegaM0[AnyArray], # ty: ignore[invalid-type-arguments]
cosmology.api.HasOmegaM[AnyArray, AnyArray], # ty: ignore[invalid-type-arguments]
cosmology.api.HasTransverseComovingDistance[AnyArray, AnyArray], # ty: ignore[invalid-type-arguments]
Protocol,
typing.Protocol,
):
"""Cosmology protocol for GLASS."""
12 changes: 6 additions & 6 deletions glass/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import itertools
import math
import sys
import typing
import warnings
from collections.abc import Sequence
from typing import TYPE_CHECKING

import numpy as np
import transformcl
Expand All @@ -23,7 +23,7 @@
from glass import _rng
from glass._array_api_utils import xp_additions as uxpx

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
from types import ModuleType
from typing import Literal
Expand All @@ -41,7 +41,7 @@
if sys.version_info >= (3, 13):
from warnings import deprecated
else:
if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from glass._types import P, R

def deprecated(msg: str, /) -> Callable[[Callable[P, R]], Callable[P, R]]:
Expand Down Expand Up @@ -226,7 +226,7 @@ def cls2cov(
cov = xpx.at(cov)[:n, i].set(cl)
cov = xpx.at(cov)[n:, i].set(0.0)
cov /= 2
yield cov # ty: ignore[invalid-yield]
yield typing.cast("FloatArray", cov)


def discretized_cls(
Expand Down Expand Up @@ -551,7 +551,7 @@ def getcl(
cl = cl[: lmax + 1]
else:
cl = xpx.pad(cl, (0, lmax + 1 - cl.shape[0]))
return cl # ty: ignore[invalid-return-type]
return typing.cast("FloatArray", cl)


def enumerate_spectra(
Expand Down Expand Up @@ -1023,7 +1023,7 @@ def cov_from_spectra(
cov = xpx.at(cov)[:size, i, j].set(cl_flat[:size])
cov = xpx.at(cov)[:size, j, i].set(cl_flat[:size])

return cov # ty: ignore[invalid-return-type]
return typing.cast("AnyArray", cov)


def check_posdef_spectra(spectra: AngularPowerSpectra) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions glass/galaxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from __future__ import annotations

import math
import typing
import warnings
from typing import TYPE_CHECKING

import array_api_compat
import array_api_extra as xpx
Expand All @@ -32,7 +32,7 @@
from glass import _rng
from glass._array_api_utils import xp_additions as uxpx

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from types import ModuleType

from glass._types import FloatArray, UnifiedGenerator
Expand Down
6 changes: 3 additions & 3 deletions glass/grf/_core.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Protocol
import typing

import transformcl

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from types import NotImplementedType

from glass._types import AnyArray


class Transformation(Protocol):
class Transformation(typing.Protocol):
"""Protocol for transformations of Gaussian random fields."""

def __call__(self, x: AnyArray, var: float, /) -> AnyArray:
Expand Down
4 changes: 2 additions & 2 deletions glass/grf/_solver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import typing

import numpy as np
from transformcl import cltocorr, corrtocl

import glass.grf

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from glass._types import AnyArray


Expand Down
4 changes: 2 additions & 2 deletions glass/grf/_transformations.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import typing
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from types import NotImplementedType

from glass._types import AnyArray
Expand Down
4 changes: 2 additions & 2 deletions glass/harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from __future__ import annotations

from typing import TYPE_CHECKING
import typing

import array_api_compat

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from glass._types import ComplexArray, FloatArray


Expand Down
4 changes: 2 additions & 2 deletions glass/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from __future__ import annotations

import typing
from collections.abc import Sequence
from typing import TYPE_CHECKING

import healpix
import healpy
Expand All @@ -14,7 +14,7 @@
import glass._array_api_utils as _utils
from glass import _rng

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from types import ModuleType

from glass._types import ComplexArray, DTypeLike, FloatArray, IntArray
Expand Down
4 changes: 2 additions & 2 deletions glass/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

import math
import threading
from typing import TYPE_CHECKING
import typing

import jax.dtypes
import jax.numpy as jnp
import jax.random
import jax.scipy
import jax.typing

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from typing import Self

from jaxtyping import PRNGKeyArray
Expand Down
Loading
Loading