
Spatially varying PSF #269

@ConnorStoneAstro

Description


Is your feature request related to a problem? Please describe.

PSFs are known to vary across an image. It would be very useful to have a PSF model that naturally depends on position.

Describe the solution you'd like

I have heard that one can express the PSF in a basis set, convolve the image with each basis element, and then combine the results using position-dependent weights to obtain an effectively spatially varying PSF convolution.
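Written out, and assuming (as in PSFEx) that the PSF at pixel position $\mathbf{x}$ is a linear combination of $K$ fixed basis images $B_k$ whose weights $c_k$ are monomials in scaled coordinates, the spatially varying convolution splits into $K$ ordinary convolutions:

```math
\begin{aligned}
\mathrm{PSF}(\mathbf{u};\mathbf{x}) &= \sum_{k=1}^{K} c_k(\mathbf{x})\, B_k(\mathbf{u}),
\qquad c_k(\mathbf{x}) = \tilde{x}^{\,n_{x,k}}\, \tilde{y}^{\,n_{y,k}},
\qquad \tilde{x} = \frac{x - x_0}{x_s},\quad \tilde{y} = \frac{y - y_0}{y_s}, \\
O(\mathbf{x}) &= \sum_{\mathbf{u}} I(\mathbf{x}-\mathbf{u})\, \mathrm{PSF}(\mathbf{u};\mathbf{x})
= \sum_{k=1}^{K} c_k(\mathbf{x})\, \big(I * B_k\big)(\mathbf{x}).
\end{aligned}
```

Because $c_k(\mathbf{x})$ does not depend on the summation offset $\mathbf{u}$, it factors out of the sum, so the cost is $K$ stationary convolutions plus a per-pixel weighted sum. Note that this form evaluates the PSF at the output pixel $\mathbf{x}$; evaluating it at the source pixel instead gives $O = \sum_k B_k * (c_k I)$, i.e. the weights are applied before convolving rather than after.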

Describe alternatives you've considered

Using caskade's functional parameter relations, it is possible to make a PSF depend on position, but this requires a new PSF model object for each position.

Additional context

The following code, written by Nicolas Payot, achieves this, though I haven't examined it in detail.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import galsim
import galsim.des


def _poly_pairs(order):
    """Exponents (nx, ny) of the PSFEx polynomial terms, in basis-image order.

    Assumes the ordering used by GalSim's DES_PSFEx (x exponent varying fastest):
    (0,0), (1,0), ..., (order,0), (0,1), (1,1), ..., (0,order).
    """
    return [(nx, ny) for ny in range(order + 1) for nx in range(order + 1 - ny)]


def _np_to_tensor(arr, dtype, device):
    """Convert a (possibly big-endian FITS) array to a tensor.

    torch cannot wrap non-native byte orders, so convert to native order first.
    """
    arr = np.asarray(arr)
    arr = arr.astype(arr.dtype.newbyteorder("="))
    return torch.as_tensor(np.ascontiguousarray(arr), dtype=dtype, device=device)


class PSFex(nn.Module):
    """
    Convolve an image with a spatially varying PSFEx model. The PSF at pixel
    (x, y) is sum_k c_k(x, y) * basis_k, where each weight c_k is a monomial
    xt**nx * yt**ny in the scaled coordinates xt, yt.

    Instantiate once, then call many times:

        out = psf_layer(img_crop, x0_pix, y0_pix)

    Parameters
    ----------
    psf_path  : str
    device    : 'cuda' | 'cpu'
    dtype     : torch.dtype   (float32 or float64)

    Notes
    -----
    Assumes the PSFEx basis is sampled at the image pixel scale (PSF_SAMP = 1);
    otherwise the basis images would need to be resampled first.
    """

    def __init__(self, psf_path, device='cuda', dtype=torch.float32):
        super().__init__()
        dev = torch.device(device)

        self.des_psfex = galsim.des.DES_PSFEx(psf_path)

        # 1. polynomial metadata
        self.fit_order = int(self.des_psfex.fit_order)
        self.x0, self.y0 = float(self.des_psfex.x_zero), float(self.des_psfex.y_zero)
        self.xs, self.ys = float(self.des_psfex.x_scale), float(self.des_psfex.y_scale)

        # 2. exponent list  (K == #basis images)
        pairs = _poly_pairs(self.fit_order)
        if len(pairs) != len(self.des_psfex.basis):
            raise ValueError(
                f"PSFEx file has {len(self.des_psfex.basis)} basis images but "
                f"{len(pairs)} polynomial terms for fit_order={self.fit_order}."
            )
        self.register_buffer('poly_pairs',
                             _np_to_tensor(pairs, dtype=torch.long, device=dev))  # (K,2)

        # 3. basis cube, flipped because F.conv2d computes a cross-correlation
        basis = _np_to_tensor(self.des_psfex.basis, dtype=dtype, device=dev)  # (K,pH,pW)
        pH, pW = basis.shape[-2:]
        if pH % 2 == 0 or pW % 2 == 0:
            raise ValueError(f"'same' padding needs odd basis sizes, got {pH}x{pW}.")
        self.register_buffer('basis',
                             basis.flip(-1, -2).unsqueeze(1))  # (K,1,pH,pW)
        self.pad = (pH // 2, pW // 2)  # F.conv2d expects (padH, padW)

    # ------------------------------------------------------------------
    # forward
    # ------------------------------------------------------------------
    def forward(self, image, x0_pix: int, y0_pix: int):
        """
        Parameters
        ----------
        image   : (H,W)  or  (B,1,H,W) tensor
        x0_pix  : column index (in full frame) of image[0,0]
        y0_pix  : row    index (in full frame) of image[0,0]

        x0_pix/y0_pix must follow the same pixel convention as the PSFEx fit
        (PSFEx positions are 1-indexed FITS pixels).
        """
        squeeze_back = image.ndim == 2
        if squeeze_back:
            image = image.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
        elif not (image.ndim == 4 and image.shape[1] == 1):
            raise ValueError("image must be (H,W) or (B,1,H,W)")

        B, _, H, W = image.shape
        dtype, dev = image.dtype, image.device

        # ---- 1. scaled coordinate grids for this crop
        yy, xx = torch.meshgrid(
            torch.arange(H, device=dev, dtype=dtype) + y0_pix,
            torch.arange(W, device=dev, dtype=dtype) + x0_pix,
            indexing='ij'
        )
        xt = (xx - self.x0) / self.xs
        yt = (yy - self.y0) / self.ys

        # ---- 2. pre-compute powers  xt^k , yt^k  up to fit_order
        xt_p = [torch.ones_like(xt)]
        yt_p = [torch.ones_like(yt)]
        for _ in range(self.fit_order):
            xt_p.append(xt_p[-1] * xt)
            yt_p.append(yt_p[-1] * yt)

        xt_stack = torch.stack(xt_p)  # (d+1, H, W)
        yt_stack = torch.stack(yt_p)  # (d+1, H, W)

        # ---- 3. coefficient maps  c_k(x,y)  =  xt^nx * yt^ny
        nx = self.poly_pairs[:, 0]  # (K,)
        ny = self.poly_pairs[:, 1]  # (K,)
        coef_maps = xt_stack[nx] * yt_stack[ny]  # (K, H, W)

        # ---- 4. stationary convolutions and weighted sum
        basis = self.basis.to(dtype=dtype, device=dev)  # match the input image
        filtered = F.conv2d(image, basis, padding=self.pad)  # (B,K,H,W)
        out = torch.sum(filtered * coef_maps.unsqueeze(0), dim=1, keepdim=True)  # (B,1,H,W)

        return out[0, 0] if squeeze_back else out
```
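One way to sanity-check the basis-sum idea (including the kernel flip and "same" padding) without GalSim or a PSFEx file is to compare it against a brute-force loop that builds the full PSF at every output pixel. The sketch below uses NumPy only; the helper names, the random 5x5 basis, `order = 2`, and the scaling constants are all hypothetical choices for illustration, not values from a real PSFEx fit.

```python
import numpy as np


def poly_pairs(order):
    """Exponent pairs (nx, ny) with nx + ny <= order (x exponent varies fastest)."""
    return [(nx, ny) for ny in range(order + 1) for nx in range(order + 1 - ny)]


def coef_maps(pairs, H, W, x0_pix, y0_pix, x_zero, y_zero, x_scale, y_scale):
    """c_k(x, y) = xt**nx * yt**ny on the crop's pixel grid, shape (K, H, W)."""
    yy, xx = np.meshgrid(np.arange(H) + y0_pix, np.arange(W) + x0_pix, indexing="ij")
    xt, yt = (xx - x_zero) / x_scale, (yy - y_zero) / y_scale
    return np.stack([xt**nx * yt**ny for nx, ny in pairs])


def conv_same(image, kernel):
    """True 2-D convolution, zero padding, output the same size as image (odd kernel)."""
    kh, kw = kernel.shape
    padded = np.pad(image, ((kh // 2, kh // 2), (kw // 2, kw // 2)))
    flipped = kernel[::-1, ::-1]  # flip turns the window sum into a convolution
    out = np.empty_like(image)
    for i in range(image.shape[0]):
        for j in range(image.shape[1]):
            out[i, j] = np.sum(padded[i:i + kh, j:j + kw] * flipped)
    return out


def brute_force(image, basis, coefs):
    """Reference: build the full PSF at each output pixel and apply it there."""
    _, kh, kw = basis.shape
    padded = np.pad(image, ((kh // 2, kh // 2), (kw // 2, kw // 2)))
    out = np.empty_like(image)
    for i in range(image.shape[0]):
        for j in range(image.shape[1]):
            psf = np.tensordot(coefs[:, i, j], basis, axes=1)  # (kh, kw) PSF at (i, j)
            out[i, j] = np.sum(padded[i:i + kh, j:j + kw] * psf[::-1, ::-1])
    return out


def basis_sum(image, basis, coefs):
    """Shortcut: K stationary convolutions, weighted by the coefficient maps."""
    return sum(c * conv_same(image, b) for c, b in zip(coefs, basis))


rng = np.random.default_rng(0)
pairs = poly_pairs(2)                          # K = 6 terms for order 2
basis = rng.normal(size=(len(pairs), 5, 5))    # hypothetical 5x5 basis images
image = rng.normal(size=(12, 16))
coefs = coef_maps(pairs, *image.shape, x0_pix=100, y0_pix=40,
                  x_zero=500.0, y_zero=500.0, x_scale=1000.0, y_scale=1000.0)

fast = basis_sum(image, basis, coefs)
slow = brute_force(image, basis, coefs)
print("max abs difference:", np.max(np.abs(fast - slow)))  # round-off level
```

The brute-force loop costs a full PSF evaluation per pixel, while the shortcut needs only K convolutions, which is what makes the GPU `F.conv2d` approach attractive for large images.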
