Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions src/caustics/lenses/func/pixelated_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,17 @@
raise ValueError(f"Invalid convolution mode: {convolution_mode}")
# Scale is distance from center of image to center of pixel on the edge
scale = fov / 2
_x_view_scale = (x - x0).view(-1) / scale
_y_view_scale = (y - y0).view(-1) / scale
deflection_angle_x = interp2d(
deflection_angle_maps[0], _x_view_scale, _y_view_scale
).reshape(x.shape)
deflection_angle_y = interp2d(
deflection_angle_maps[1], _x_view_scale, _y_view_scale
).reshape(x.shape)
return deflection_angle_x, deflection_angle_y
x = (x - x0) / scale
y = (y - y0) / scale
deflection_angle = interp2d(
deflection_angle_maps,
x,
y,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
).reshape(2, *x.shape)
return deflection_angle[0], deflection_angle[1]


def potential_pixelated_convergence(
Expand Down Expand Up @@ -322,12 +324,20 @@
potential_map = _unpad_fft(potential, n_pix)
elif convolution_mode == "conv2d":
convergence_map_flipped = convergence_map.flip((-1, -2))[None, None]
potential_map = F.conv2d(

Check warning on line 327 in src/caustics/lenses/func/pixelated_convergence.py

View workflow job for this annotation

GitHub Actions / build (3.10, ubuntu-latest)

Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at /pytorch/aten/src/ATen/native/Convolution.cpp:1036.)

Check warning on line 327 in src/caustics/lenses/func/pixelated_convergence.py

View workflow job for this annotation

GitHub Actions / build (3.10, windows-latest)

Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\native\Convolution.cpp:1037.)

Check warning on line 327 in src/caustics/lenses/func/pixelated_convergence.py

View workflow job for this annotation

GitHub Actions / build (3.10, macOS-latest)

Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/Convolution.cpp:1037.)
potential_kernel[None, None], convergence_map_flipped, padding="same"
).squeeze() * (pixelscale**2 / torch.pi)
else:
raise ValueError(f"Invalid convolution mode: {convolution_mode}")
scale = fov / 2
return interp2d(
potential_map, (x - x0).view(-1) / scale, (y - y0).view(-1) / scale
).reshape(x.shape)
x = (x - x0) / scale
y = (y - y0) / scale
potential = interp2d(
potential_map.unsqueeze(0),
x,
y,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
).squeeze(0)
return potential
10 changes: 0 additions & 10 deletions src/caustics/lenses/multiplane.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
self.lenses = tuple(lenses)
for lens in self.lenses:
if lens.z_s.static:
warn(

Check warning on line 47 in src/caustics/lenses/multiplane.py

View workflow job for this annotation

GitHub Actions / build (3.10, ubuntu-latest)

Lens plane sie_2 has a static source redshift. This is now overwritten by the Multiplane (multiplane) source redshift. To prevent this warning, set the source redshift of the lens plane to be dynamic before adding to multiplane system.

Check warning on line 47 in src/caustics/lenses/multiplane.py

View workflow job for this annotation

GitHub Actions / build (3.10, ubuntu-latest)

Lens plane sie_1 has a static source redshift. This is now overwritten by the Multiplane (multiplane) source redshift. To prevent this warning, set the source redshift of the lens plane to be dynamic before adding to multiplane system.

Check warning on line 47 in src/caustics/lenses/multiplane.py

View workflow job for this annotation

GitHub Actions / build (3.10, ubuntu-latest)

Lens plane sie_0 has a static source redshift. This is now overwritten by the Multiplane (multiplane) source redshift. To prevent this warning, set the source redshift of the lens plane to be dynamic before adding to multiplane system.

Check warning on line 47 in src/caustics/lenses/multiplane.py

View workflow job for this annotation

GitHub Actions / build (3.10, windows-latest)

Lens plane sie_2 has a static source redshift. This is now overwritten by the Multiplane (multiplane) source redshift. To prevent this warning, set the source redshift of the lens plane to be dynamic before adding to multiplane system.

Check warning on line 47 in src/caustics/lenses/multiplane.py

View workflow job for this annotation

GitHub Actions / build (3.10, windows-latest)

Lens plane sie_1 has a static source redshift. This is now overwritten by the Multiplane (multiplane) source redshift. To prevent this warning, set the source redshift of the lens plane to be dynamic before adding to multiplane system.

Check warning on line 47 in src/caustics/lenses/multiplane.py

View workflow job for this annotation

GitHub Actions / build (3.10, windows-latest)

Lens plane sie_0 has a static source redshift. This is now overwritten by the Multiplane (multiplane) source redshift. To prevent this warning, set the source redshift of the lens plane to be dynamic before adding to multiplane system.

Check warning on line 47 in src/caustics/lenses/multiplane.py

View workflow job for this annotation

GitHub Actions / build (3.10, macOS-latest)

Lens plane sie_2 has a static source redshift. This is now overwritten by the Multiplane (multiplane) source redshift. To prevent this warning, set the source redshift of the lens plane to be dynamic before adding to multiplane system.

Check warning on line 47 in src/caustics/lenses/multiplane.py

View workflow job for this annotation

GitHub Actions / build (3.10, macOS-latest)

Lens plane sie_1 has a static source redshift. This is now overwritten by the Multiplane (multiplane) source redshift. To prevent this warning, set the source redshift of the lens plane to be dynamic before adding to multiplane system.

Check warning on line 47 in src/caustics/lenses/multiplane.py

View workflow job for this annotation

GitHub Actions / build (3.10, macOS-latest)

Lens plane sie_0 has a static source redshift. This is now overwritten by the Multiplane (multiplane) source redshift. To prevent this warning, set the source redshift of the lens plane to be dynamic before adding to multiplane system.
f"Lens plane {lens.name} has a static source redshift. This is now overwritten by the Multiplane ({self.name}) source redshift. To prevent this warning, set the source redshift of the lens plane to be dynamic before adding to multiplane system."
)
lens.z_s = self.z_s
Expand Down Expand Up @@ -81,7 +81,6 @@
# Compute physical position on first lens plane
D = self.cosmology.transverse_comoving_distance(z_ls[lens_planes[0]])
X, Y = x * arcsec_to_rad * D, y * arcsec_to_rad * D # fmt: skip

# Initial angles are observation angles
theta_x, theta_y = x, y

Expand Down Expand Up @@ -200,15 +199,6 @@
ray_coords=True,
)

@forward
def effective_reduced_deflection_angle(
self,
x: Tensor,
y: Tensor,
) -> tuple[Tensor, Tensor]:
bx, by = self.raytrace(x, y)
return x - bx, y - by

@forward
def surface_density(
self,
Expand Down
16 changes: 10 additions & 6 deletions src/caustics/lenses/pixelated_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import torch
from torch import Tensor
from torch.nn.functional import grid_sample
import numpy as np
from caskade import forward, Param

from ..utils import interp2d
from .base import ThinLens, CosmologyType, NameType, ZType
from . import func

Expand All @@ -17,7 +17,7 @@
_null_params = {
"x0": 0.0,
"y0": 0.0,
"convergence_map": np.logspace(0, 1, 100, dtype=np.float32).reshape(10, 10),
"convergence_map": np.random.normal(size=(10, 10)).astype(np.float32),
}

def __init__(
Expand Down Expand Up @@ -391,8 +391,12 @@
"""
fov_x = convergence_map.shape[1] * self.pixelscale
fov_y = convergence_map.shape[0] * self.pixelscale
return interp2d(
convergence_map * scale,
(x - x0).view(-1) / fov_x * 2,
(y - y0).view(-1) / fov_y * 2,
return grid_sample(

Check warning on line 394 in src/caustics/lenses/pixelated_convergence.py

View check run for this annotation

Codecov / codecov/patch

src/caustics/lenses/pixelated_convergence.py#L394

Added line #L394 was not covered by tests
convergence_map.reshape(1, 1, *convergence_map.shape) * scale,
torch.stack(
((x - x0).view(-1) / fov_x * 2, (y - y0).view(-1) / fov_y * 2), dim=1
).reshape(1, 1, -1, 2),
mode="bilinear",
padding_mode="zeros",
align_corners=False,
).reshape(x.shape)
18 changes: 11 additions & 7 deletions src/caustics/lenses/pixelated_deflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import numpy as np
from caskade import forward, Param

from ..utils import interp2d
from .base import ThinLens, CosmologyType, NameType, ZType
from ..utils import interp2d

__all__ = ("PixelatedDeflection",)

Expand Down Expand Up @@ -141,13 +141,17 @@ def reduced_deflection_angle(
"""
fov_x = deflection_map.shape[2] * pixelscale
fov_y = deflection_map.shape[1] * pixelscale
shape = x.shape
x = (x - x0).view(-1) / fov_x * 2
y = (y - y0).view(-1) / fov_y * 2
return (
interp2d(deflection_map[0], x, y).reshape(shape),
interp2d(deflection_map[1], x, y).reshape(shape),
x = (x - x0) * (2 / fov_x)
y = (y - y0) * (2 / fov_y)
deflection_angle = interp2d(
deflection_map,
x,
y,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
)
return deflection_angle[0], deflection_angle[1]

@forward
def potential(self, x, y, **kwargs):
Expand Down
17 changes: 11 additions & 6 deletions src/caustics/light/pixelated.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# mypy: disable-error-code="union-attr"
from typing import Optional, Union, Annotated

import torch
from torch import Tensor
from torch.nn.functional import grid_sample
from caskade import forward, Param

from ..utils import interp2d
from .base import Source, NameType

__all__ = ("Pixelated",)
Expand Down Expand Up @@ -173,9 +174,13 @@ def brightness(
"""
fov_x = pixelscale * image.shape[1]
fov_y = pixelscale * image.shape[0]
return interp2d(
image * scale,
(x - x0).view(-1) / fov_x * 2,
(y - y0).view(-1) / fov_y * 2, # make coordinates bounds at half the fov
shape = x.shape
x = (x - x0).view(-1) / fov_x * 2
y = (y - y0).view(-1) / fov_y * 2
return grid_sample(
image.reshape(1, 1, *image.shape) * scale,
torch.stack((x, y), dim=1).reshape(1, 1, -1, 2),
mode="bilinear",
padding_mode=padding_mode,
).reshape(x.shape)
align_corners=False,
).reshape(shape)
20 changes: 13 additions & 7 deletions src/caustics/light/pixelated_time.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# mypy: disable-error-code="operator,union-attr"
from typing import Optional, Union, Annotated

import torch
from torch import Tensor
from torch.nn.functional import grid_sample
from caskade import forward, Param

from ..utils import interp3d
from .base import Source, NameType

__all__ = ("PixelatedTime",)
Expand Down Expand Up @@ -188,9 +189,14 @@ def brightness(
"""
fov_x = self.pixelscale * cube.shape[2]
fov_y = self.pixelscale * cube.shape[1]
return interp3d(
cube * scale,
(x - x0).view(-1) / fov_x * 2,
(y - y0).view(-1) / fov_y * 2,
(t - self.t_end / 2).view(-1) / self.t_end * 2,
).reshape(x.shape)
shape = x.shape
x = (x - x0).view(-1) / fov_x * 2
y = (y - y0).view(-1) / fov_y * 2
t = (t - self.t_end / 2).view(-1) / self.t_end * 2
return grid_sample(
cube.reshape(1, 1, *cube.shape),
torch.stack((x, y, t), dim=1).reshape(1, 1, 1, -1, 3),
mode="bilinear",
padding_mode="zeros",
align_corners=False,
).reshape(shape)
139 changes: 82 additions & 57 deletions src/caustics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from torch import Tensor
from torch.func import jacfwd
from torch.nn.functional import grid_sample
from scipy.special import roots_legendre

from .constants import rad_to_deg, deg_to_rad
Expand Down Expand Up @@ -873,83 +874,110 @@
im: Tensor,
x: Tensor,
y: Tensor,
method: Literal["linear", "nearest"] = "linear",
padding_mode: str = "zeros",
mode: Literal["bilinear", "nearest"] = "bilinear",
padding_mode: Literal["zeros", "border"] = "zeros",
align_corners: bool = False,
) -> Tensor:
"""
Interpolates a 2D image at specified coordinates. Similar to
`torch.nn.functional.grid_sample` with `align_corners=False`.
Sample a 2-D image at arbitrary normalized coordinates.

Parameters
----------
im: Tensor
A 2D tensor representing the image.
x: Tensor
A 0D or 1D tensor of x coordinates at which to interpolate.
y: Tensor
A 0D or 1D tensor of y coordinates at which to interpolate.
method: (str, optional)
Interpolation method. Either 'nearest' or 'linear'. Defaults to
'linear'.
padding_mode: (str, optional)
Defines the padding mode when out-of-bound indices are encountered.
Either 'zeros', 'clamp', or 'extrapolate'. Defaults to 'zeros' which
fills padded coordinates with zeros. The 'clamp' mode clamps the
coordinates to the image boundaries (essentially taking the border
values out to infinity). The 'extrapolate' mode extrapolates the outer
linear interpolation beyond the last pixel boundary.
im : Tensor
Input image of shape ``(C, H, W)`` where ``C`` is the number of channels,
``H`` the height and ``W`` the width. The tensor must be 3-dimensional.
x : Tensor
0-D or 1-D tensor containing the *x* coordinates in the normalized device
coordinate (NDC) system, i.e. ``−1`` maps to the center of the leftmost
pixel and ``+1`` to the center of the rightmost pixel.
y : Tensor
0-D or 1-D tensor with the *y* coordinates in NDC. Must have the same shape
as ``x``.
mode : {"bilinear", "nearest"}, default "bilinear"
Interpolation algorithm. Behaves exactly like the ``mode`` argument of
``grid_sample``.
padding_mode : {"zeros", "border"}, default "zeros"
How coordinates that fall outside the image are handled:
• ``"zeros"`` — out-of-range samples are filled with zeros.
• ``"border"`` — coordinates are clamped to the valid range (border values
are repeated to infinity).
align_corners : bool, default ``False``
Forwarded to ``grid_sample``. When ``True``, the extrema of the NDC system
(−1 and +1) are mapped exactly to the image corners; otherwise they map to
the centers of the corner pixels.

Raises
------
ValueError
If `im` is not a 2D tensor.
If ``im`` is not 3-D or if its shape is not ``(C, H, W)``.
ValueError
If `x` is not a 0D or 1D tensor.
ValueError
If `y` is not a 0D or 1D tensor.
ValueError
If `padding_mode` is not 'extrapolate' or 'zeros'.
ValueError
If `method` is not 'nearest' or 'linear'.
If ``mode`` or ``padding_mode`` is not one of the supported options.

Returns
-------
Tensor
Tensor with the same shape as `x` and `y` containing the interpolated
values.
"""
if im.ndim != 2:
raise ValueError(f"im must be 2D (received {im.ndim}D tensor)")
if x.ndim > 1:
raise ValueError(f"x must be 0 or 1D (received {x.ndim}D tensor)")
if y.ndim > 1:
raise ValueError(f"y must be 0 or 1D (received {y.ndim}D tensor)")
if padding_mode not in ["extrapolate", "clamp", "zeros"]:

if im.ndim != 3:
raise ValueError(f"im must be 3D (received {im.ndim}D tensor)")

Check warning on line 924 in src/caustics/utils.py

View check run for this annotation

Codecov / codecov/patch

src/caustics/utils.py#L924

Added line #L924 was not covered by tests
if padding_mode not in ["border", "zeros"]:
raise ValueError(f"{padding_mode} is not a valid padding mode")

if padding_mode == "clamp":
x = x.clamp(-1, 1)
y = y.clamp(-1, 1)
else:
idxs_out_of_bounds = (y < -1) | (y > 1) | (x < -1) | (x > 1)
shape = x.shape
x = x.flatten()
y = y.flatten()
if (
not (x.requires_grad or y.requires_grad)
and torch.autograd.forward_ad._current_level == -1
):
return grid_sample(
im.unsqueeze(0),
torch.stack((x, y), dim=1).unsqueeze(0).unsqueeze(0),
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
).reshape(im.shape[0], *shape)

# Convert coordinates to pixel indices
h, w = im.shape
x = 0.5 * ((x + 1) * w - 1)
y = 0.5 * ((y + 1) * h - 1)
_, h, w = im.shape
if align_corners:
x = 0.5 * ((x + 1) * (w - 1))
y = 0.5 * ((y + 1) * (h - 1))

Check warning on line 947 in src/caustics/utils.py

View check run for this annotation

Codecov / codecov/patch

src/caustics/utils.py#L946-L947

Added lines #L946 - L947 were not covered by tests
else:
x = 0.5 * ((x + 1) * w - 1)
y = 0.5 * ((y + 1) * h - 1)

if method == "nearest":
result = im[y.round().long().clamp(0, h - 1), x.round().long().clamp(0, w - 1)]
elif method == "linear":
x0 = x.floor().long().clamp(0, w - 2)
y0 = y.floor().long().clamp(0, h - 2)
if mode == "nearest":
result = im[

Check warning on line 953 in src/caustics/utils.py

View check run for this annotation

Codecov / codecov/patch

src/caustics/utils.py#L953

Added line #L953 was not covered by tests
..., y.round().long().clamp(0, h - 1), x.round().long().clamp(0, w - 1)
]
if padding_mode == "zeros":
valid = ((x.abs() <= 1) & (y.abs() <= 1)).float()
result = result * valid

Check warning on line 958 in src/caustics/utils.py

View check run for this annotation

Codecov / codecov/patch

src/caustics/utils.py#L956-L958

Added lines #L956 - L958 were not covered by tests
elif mode == "bilinear":
x = x.clamp(-1, w)
y = y.clamp(-1, h)
x0 = x.floor().long()
y0 = y.floor().long()
x1 = x0 + 1
y1 = y0 + 1

fa = im[y0, x0]
fb = im[y1, x0]
fc = im[y0, x1]
fd = im[y1, x1]
def get_val(ix, iy):
ix_clip = ix.clamp(0, w - 1)
iy_clip = iy.clamp(0, h - 1)
val = im[..., iy_clip, ix_clip]
if padding_mode == "zeros":
valid = (ix >= 0) & (ix < w) & (iy >= 0) & (iy < h)
return val * valid.float()
elif padding_mode == "border":
return val

Check warning on line 975 in src/caustics/utils.py

View check run for this annotation

Codecov / codecov/patch

src/caustics/utils.py#L974-L975

Added lines #L974 - L975 were not covered by tests

fa = get_val(x0, y0)
fb = get_val(x0, y1)
fc = get_val(x1, y0)
fd = get_val(x1, y1)

dx1 = x1 - x
dx0 = x - x0
Expand All @@ -958,12 +986,9 @@

result = fa * dx1 * dy1 + fb * dx1 * dy0 + fc * dx0 * dy1 + fd * dx0 * dy0 # fmt: skip
else:
raise ValueError(f"{method} is not a valid interpolation method")
raise ValueError(f"{mode} is not a valid interpolation method")

Check warning on line 989 in src/caustics/utils.py

View check run for this annotation

Codecov / codecov/patch

src/caustics/utils.py#L989

Added line #L989 was not covered by tests

if padding_mode == "zeros": # else padding_mode == "extrapolate"
result = torch.where(idxs_out_of_bounds, torch.zeros_like(result), result)

return result
return result.reshape(im.shape[0], *shape)


def interp3d(
Expand Down
Loading
Loading