From 2464ea3b7f7e3dc6956f34b10f16a17924ad284f Mon Sep 17 00:00:00 2001 From: Bilal Kabas Date: Sat, 3 Aug 2024 14:09:03 +0300 Subject: [PATCH 1/3] Add corrected ifft2c and fft2c functions for k-space calculations --- data/transforms.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/data/transforms.py b/data/transforms.py index 412b47d..f591b02 100644 --- a/data/transforms.py +++ b/data/transforms.py @@ -70,6 +70,42 @@ def apply_mask(data, mask_func, seed=None): return data * mask, mask +def ifft2c(x, dim=(-2, -1)): + """ Centered 2D Inverse Fast Fourier Transform + + Args: + x (torch.Tensor): Complex valued input data containing at least 3 + dimensions: dimensions -2 & -1 are spatial dimensions. All other + dimensions are assumed to be batch dimensions. + dim (tuple): Dimensions to apply the IFFT along. Default is (-2, -1) + + Returns: + torch.Tensor: The IFFT of the input. + """ + x = torch.fft.ifftshift(x, dim=dim) + x = torch.fft.ifft2(x, dim=dim) + return torch.fft.fftshift(x, dim=dim) + + +def fft2(data, normalized=True): + """ + Apply centered 2 dimensional Fast Fourier Transform. + + Args: + data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions + -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are + assumed to be batch dimensions. + + Returns: + torch.Tensor: The FFT of the input. + """ + assert data.size(-1) == 2 + data = ifftshift(data, dim=(-3, -2)) + data = torch.fft.fft(data, 2, normalized=normalized) + data = fftshift(data, dim=(-3, -2)) + return data + + def fft2(data, normalized=True): """ Apply centered 2 dimensional Fast Fourier Transform. From 735dd61733304ccf6629e4fa1a7deaed4fb900a4 Mon Sep 17 00:00:00 2001 From: Bilal Kabas Date: Sat, 3 Aug 2024 14:10:32 +0300 Subject: [PATCH 2/3] Update Recurrent_Transformer.py --- models/Recurrent_Transformer.py | 47 ++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/models/Recurrent_Transformer.py b/models/Recurrent_Transformer.py index 9d87819..780f9ef 100644 --- a/models/Recurrent_Transformer.py +++ b/models/Recurrent_Transformer.py @@ -10,22 +10,17 @@ from torch.nn import functional as F import numpy as np -class DataConsistencyInKspace(nn.Module): - """ Create data consistency operator - - Warning: note that FFT2 (by the default of torch.fft) is applied to the last 2 axes of the input. - This method detects if the input tensor is 4-dim (2D data) or 5-dim (3D data) - and applies FFT2 to the (nx, ny) axis. - """ +class DataConsistencyInKspace(nn.Module): + """ Data consistency layer in k-space. """ def __init__(self): super(DataConsistencyInKspace, self).__init__() def forward(self, *input, **kwargs): return self.perform(*input) - - def data_consistency(self,k, k0, mask): + + def data_consistency(self, k, k0, mask): """ k - input in k-space k0 - initially sampled elements in k-space @@ -36,23 +31,33 @@ def data_consistency(self,k, k0, mask): return out def perform(self, x, k0, mask): + """ Forward pass to enforce data consistency in k-space. + + Args: + x (torch.Tensor): Input image in spatial domain (batch_size, 2, height, width). + k0 (torch.Tensor): Measured k-space data (batch_size, 2, height, width). + mask (torch.Tensor): Binary mask indicating sampled k-space locations (batch_size, 1, height, width). + + Returns: + torch.Tensor: Corrected image with the same shape as input. """ - x - input in image domain, of shape (n, 2, nx, ny[, nt]) - k0 - initially sampled elements in k-space - mask - corresponding nonzero location - """ - x = x.permute(0, 2, 3, 1) - k0 = k0.permute(0, 2, 3, 1) - mask = mask.permute(0, 2, 3, 1) + x_cx = torch.complex(x[:, 0], x[:, 1]).unsqueeze(1) + k0_cx = torch.complex(k0[:, 0], k0[:, 1]).unsqueeze(1) - k = transforms.fft2(x) + # Fourier transform + x_kspace = transforms.fft2c(x_cx) - out = self.data_consistency(k, k0, mask) - x_res = transforms.ifft2(out) + # Fill in k-space + x_kspace = self.data_consistency(x_kspace, k0_cx, mask) - x_res = x_res.permute(0, 3, 1, 2) + # Inverse Fourier transform + out = transforms.ifft2c(x_kspace) + + # Stack real and imaginary parts + out = torch.cat((out.real, out.imag), dim=1) + + return out - return x_res class RFB(nn.Module): """ From 7025fcffc82578d4042ee9ff7d7269295bba3ecb Mon Sep 17 00:00:00 2001 From: Bilal Kabas Date: Sat, 3 Aug 2024 14:26:59 +0300 Subject: [PATCH 3/3] Update transforms.py --- data/transforms.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/data/transforms.py b/data/transforms.py index f591b02..bb22bb6 100644 --- a/data/transforms.py +++ b/data/transforms.py @@ -70,40 +70,39 @@ def apply_mask(data, mask_func, seed=None): return data * mask, mask -def ifft2c(x, dim=(-2, -1)): - """ Centered 2D Inverse Fast Fourier Transform - +def fft2c(x, dim=(-2, -1)): + """ Centered 2D Fast Fourier Transform + Args: x (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions -2 & -1 are spatial dimensions. All other dimensions are assumed to be batch dimensions. - dim (tuple): Dimensions to apply the IFFT along. Default is (-2, -1) + + dim (tuple): Dimensions to apply the FFT along. Default is (-2, -1) Returns: - torch.Tensor: The IFFT of the input. + torch.Tensor: The FFT of the input. """ x = torch.fft.ifftshift(x, dim=dim) - x = torch.fft.ifft2(x, dim=dim) + x = torch.fft.fft2(x, dim=dim) return torch.fft.fftshift(x, dim=dim) -def fft2(data, normalized=True): - """ - Apply centered 2 dimensional Fast Fourier Transform. +def ifft2c(x, dim=(-2, -1)): + """ Centered 2D Inverse Fast Fourier Transform Args: - data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions - -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are - assumed to be batch dimensions. + x (torch.Tensor): Complex valued input data containing at least 3 + dimensions: dimensions -2 & -1 are spatial dimensions. All other + dimensions are assumed to be batch dimensions. + dim (tuple): Dimensions to apply the IFFT along. Default is (-2, -1) Returns: - torch.Tensor: The FFT of the input. + torch.Tensor: The IFFT of the input. """ - assert data.size(-1) == 2 - data = ifftshift(data, dim=(-3, -2)) - data = torch.fft.fft(data, 2, normalized=normalized) - data = fftshift(data, dim=(-3, -2)) - return data + x = torch.fft.ifftshift(x, dim=dim) + x = torch.fft.ifft2(x, dim=dim) + return torch.fft.fftshift(x, dim=dim) def fft2(data, normalized=True):