Skip to content

feat(modconv): add conditional modulated convolutions for MRI reconstruction#309

Open
georgeyiasemis wants to merge 18 commits into
mainfrom
modulated-convolution
Open

feat(modconv): add conditional modulated convolutions for MRI reconstruction#309
georgeyiasemis wants to merge 18 commits into
mainfrom
modulated-convolution

Conversation

@georgeyiasemis

Copy link
Copy Markdown
Contributor

Description

This PR ports conditional learned reconstruction into DIRECT by introducing modulated convolutions — convolutional layers whose weights are adapted by small MLPs from acquisition metadata (acceleration factor, ACS fraction, field strength).

The implementation follows Conditional Learned Reconstruction for Medical Imaging (Moriakov et al., MIDL 2026; OpenReview PDF). A single trained backbone can be conditioned at inference time on the actual undersampling pattern instead of training separate models per acceleration.

What's New

Modulated convolution package (direct/nn/conv/modulated/)

  • ModConv2d / ModConv3d and transposed variants with modulation types: NONE, FEATURES, FULL, PARTIAL_IN, PARTIAL_OUT, SUM
  • auxiliary_data.py — registry-based auxiliary feature pipeline:
    • prepare_auxiliary_data(data, cfg) builds (batch, aux_in_features) conditioning vectors
    • register_auxiliary_feature() for custom conditioning channels
    • Default features: acceleration, center_fraction, field_strength
  • AdaIN2d / AdaIN3d adaptive instance normalization modules (direct/nn/adain/)

Auxiliary conditioning pipeline

  • CreateSamplingMask requests return_acceleration from BaseMaskFunc masks
  • Batch keys: acceleration, center_fraction
  • MRIModelEngine._attach_auxiliary_data() attaches auxiliary_data each iteration (supervised, SSL, JSSL)
  • FastMRI datasets expose field_strength (1.5 T / 3.0 T from filename)

Triangular acceleration sampling

  • linear_range: true in masking config samples acceleration from a triangular distribution (paper Section 4.3.2)
  • direct/utils/distributions.pytriangular_distribution()
  • BaseMaskFunc returns sampled acceleration metadata via return_acceleration

Model support

Modulated convolutions wired end-to-end through conv-based models:

Model Config field Engine passes auxiliary_data
vSHARP (2D/3D) conv_modulation Yes
VarNet conv_modulation Yes
XPDNet conv_modulation Yes
KIKINet conv_modulation Yes
LPD conv_modulation Yes
JointICNet conv_modulation Yes
IterDualNet conv_modulation Yes
Unet2d conv_modulation Yes
Conv2d, DIDN, MWCNN modulation (layer-level) Via parent model

Shared config fields: conv_modulation, aux_in_features, auxiliary_features, log_aux, fc_hidden_features, fc_activation, fc_groups, num_weights

Documentation & configs

  • projects/modulated_convolution/README.rst — tutorial with paper figures and section references
  • Example configs for vSHARP (knee/prostate) and VarNet (prostate) with triangular acceleration sampling

What's Changed

Masking (direct/common/subsample.py)

  • Unified _draw_acceleration_value() / _draw_acceleration_pair() for uniform_range and linear_range
  • Centralized return_acceleration in BaseMaskFunc.__call__
  • Bug fix: equispaced masks early-return when ACS already covers target acceleration (high <= 0 guard)
  • Bug fix: CartesianMagicMaskFunc integer center line counts no longer capped by min(1/accel, cf); ACS-only edge case handled
  • Bug fix: corrected len(center_fractions) != len(accelerations) validation

U-Net backbone

  • Unet2d / Unet3d / NormUnet* swap Conv2dModConv2d when conv_modulation != NONE
  • Modulated transposed convolutions in decoder blocks

Paper Mapping

Paper (Sec. / Eq.) DIRECT
Eq. 6 — modulated convolution ModConv2d with FEATURES type
Eq. 7 — z = log([R, 100·r_acs]) log_aux: true + acceleration / center_fraction batch keys
Sec. 4.3.2 — triangular R ∈ [4,16] linear_range: true in masking config
MOD S/M/L — MLP [32,8] / [32,16] / [32,32] fc_hidden_features in model config

Related

georgeyiasemis and others added 8 commits June 15, 2026 16:23
- Add ModConv2d/3d and ModConvTranspose2d/3d with multiple modulation types
- Add AdaIN2d/AdaIN3d adaptive instance normalization modules
- Integrate modulated convolutions into UNet2d/3d and VSharpNet/3D
- Add IntOrTuple type alias
- Fix SyntaxWarnings from invalid escape sequences in docstrings
- Fix DeprecationWarning in engine.py (numpy/torch interop)
- Fix UserWarning in gradloss_test.py (tensor from list of ndarrays)

Co-authored-by: Cursor <cursoragent@cursor.com>
Wire conditional modulation through VarNet, KIKINet, JointICNet, IterDualNet,
and LPD, plus Conv2d, DIDN, and MWCNN backbones. Add modulated conv unit tests,
publication references in docstrings, and modulation MLP layout fix. Apply
black/isort formatting across the codebase.

Co-authored-by: Cursor <cursoragent@cursor.com>
Introduce triangular acceleration sampling for training configs, return
sampled acceleration metadata from BaseMaskFunc, and fix integer center
line counts and ACS-only edge cases in MagicMaskFunc.

Co-authored-by: Cursor <cursoragent@cursor.com>
Reorganize modulated convolution layers into a subpackage and add a
central auxiliary-data registry used by conditional reconstruction models.

Co-authored-by: Cursor <cursoragent@cursor.com>
Expose sampled acceleration and center fraction in batches, attach
auxiliary tensors in MRIModelEngine, and register field strength in
FastMRI datasets for configurable conditioning features.

Co-authored-by: Cursor <cursoragent@cursor.com>
…odels

Enable auxiliary conditioning in VarNet, vSHARP, XPDNet, KIKINet, LPD,
JointICNet, IterDualNet, and Unet2d engines and align model configs
with the shared conv_modulation settings.

Co-authored-by: Cursor <cursoragent@cursor.com>
Add a project README and knee/prostate training configs demonstrating
feature-based modulated convolutions with triangular acceleration.

Co-authored-by: Cursor <cursoragent@cursor.com>
Embed architecture and result figures from the MIDL paper, map config
fields to Eq. 7 and Section 3.1, and link to the OpenReview PDF.

Co-authored-by: Cursor <cursoragent@cursor.com>
@codecov

codecov Bot commented Jun 19, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 79.71768% with 273 lines in your changes missing coverage. Please review.
✅ Project coverage is 84.71%. Comparing base (f83b7bb) to head (2efab9a).

Files with missing lines Patch % Lines
direct/nn/conv/modulated/modulated_conv.py 75.11% 106 Missing ⚠️
direct/nn/adain/adain.py 18.42% 62 Missing ⚠️
direct/nn/unet/unet_3d.py 71.87% 27 Missing ⚠️
direct/nn/unet/unet_2d.py 81.65% 20 Missing ⚠️
direct/nn/vsharp/vsharp.py 76.56% 15 Missing ⚠️
direct/common/subsample.py 84.21% 12 Missing ⚠️
direct/nn/crossdomain/multicoil.py 60.00% 4 Missing ⚠️
direct/nn/didn/didn.py 96.03% 4 Missing ⚠️
direct/nn/iterdualnet/iterdualnet.py 83.33% 3 Missing ⚠️
direct/nn/jointicnet/jointicnet.py 81.25% 3 Missing ⚠️
... and 10 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #309      +/-   ##
==========================================
- Coverage   85.73%   84.71%   -1.02%     
==========================================
  Files         103      110       +7     
  Lines        9041    10133    +1092     
==========================================
+ Hits         7751     8584     +833     
- Misses       1290     1549     +259     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

georgeyiasemis and others added 10 commits June 19, 2026 17:45
Add typed modulated-conv factories, normalize optional modulation types, and apply formatting fixes so type checking and tests pass on the modulated-convolution branch.

Co-authored-by: Cursor <cursoragent@cursor.com>
Add .prospector.yml so Codacy skips conflicting pydocstyle rules (D203/D213 etc.) that Ruff does not enforce; fix unused loop variable in DIDN ReconBlock.

Co-authored-by: Cursor <cursoragent@cursor.com>
Disable pydocstyle and noisy pylint limits that conflict with numpydoc-style
docstrings and typical model __init__ signatures, and fix an unused loop variable.

Co-authored-by: Cursor <cursoragent@cursor.com>
Assign acceleration fields in the same branch where they are unpacked, and
wrap long DIDN bibliography lines to satisfy line-length checks.

Co-authored-by: Cursor <cursoragent@cursor.com>
Add factory, modulation-type, validation, UNet/DIDN, auxiliary-data, and
CreateSamplingMask acceleration tests to improve coverage on new modconv code.

Co-authored-by: Cursor <cursoragent@cursor.com>
Restore 50 files that differed from main only due to Black reformatting,
keeping the branch focused on modulated-convolution changes.

Co-authored-by: Cursor <cursoragent@cursor.com>
Build conditioning vectors when image_unet_norm_type is ADAIN even if
conv_modulation is NONE, so AdaIN denoisers receive the acquisition vector.

Co-authored-by: Cursor <cursoragent@cursor.com>
Add fixed training configs under projects/modulated_convolution/configs/vsharp/
with generator and smoke-test scripts; reorganize README and replace flat yaml paths.

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
These were development utilities and should not live under temp/ in git.

Co-authored-by: Cursor <cursoragent@cursor.com>
@georgeyiasemis

Copy link
Copy Markdown
Contributor Author

@jonasteuwen this is ready for review!

@jonasteuwen jonasteuwen left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some remarks

Comment on lines +148 to +150
if uniform_range and linear_range:
raise ValueError("uniform_range and linear_range are mutually exclusive.")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not a range type? range.UNIFORM and range.LINEAR

Comment on lines 134 to 147
@@ -139,36 +145,64 @@ def __init__(
along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each
slice along the fourth last dimension. Default: MaskFuncMode.STATIC.
"""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can start with GoogleDoc here, is it possible to mix formats with sphinx? I think so, so we can translate it slowly.

Comment on lines 30 to +31
uniform_range: bool = False
linear_range: bool = False

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not a RangeMode

Comment thread direct/data/datasets.py
Comment on lines +414 to +415
sample["field_strength"] = np.array([3.0]) if "30T" in str(sample.get("filename", "")) else np.array([1.5])

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems very specific to a dataset, shouldn't that be in a subclass of this? Also, these are not the only field strengths we have

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah - it's FastMRI, nevermind

@@ -0,0 +1,231 @@
# Copyright 2025 AI for Oncology Research Group. All Rights Reserved.
#

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New files should be in 2026?

Comment on lines +1 to +2
# Copyright 2025 AI for Oncology Research Group. All Rights Reserved.
#

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

Comment on lines +64 to +65
) -> ModConv2d:
"""Construct :class:`ModConv2d` with typed modulation arguments."""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refer to the base class for the docstring. Also, isn't typically the functional interface defined first and then the class type?

Comment on lines +38 to +40
fc_groups: int = 1
fc_activation: str = "sigmoid"
num_weights: Optional[int] = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not ActivationType? Surely you have defined that somewhere already

Comment on lines +47 to +49
fc_groups: int = 1
fc_activation: str = "sigmoid"
num_weights: Optional[int] = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants