feat(modconv): add conditional modulated convolutions for MRI reconstruction#309
feat(modconv): add conditional modulated convolutions for MRI reconstruction#309georgeyiasemis wants to merge 18 commits into
Conversation
- Add ModConv2d/3d and ModConvTranspose2d/3d with multiple modulation types - Add AdaIN2d/AdaIN3d adaptive instance normalization modules - Integrate modulated convolutions into UNet2d/3d and VSharpNet/3D - Add IntOrTuple type alias - Fix SyntaxWarnings from invalid escape sequences in docstrings - Fix DeprecationWarning in engine.py (numpy/torch interop) - Fix UserWarning in gradloss_test.py (tensor from list of ndarrays) Co-authored-by: Cursor <cursoragent@cursor.com>
Wire conditional modulation through VarNet, KIKINet, JointICNet, IterDualNet, and LPD, plus Conv2d, DIDN, and MWCNN backbones. Add modulated conv unit tests, publication references in docstrings, and modulation MLP layout fix. Apply black/isort formatting across the codebase. Co-authored-by: Cursor <cursoragent@cursor.com>
Introduce triangular acceleration sampling for training configs, return sampled acceleration metadata from BaseMaskFunc, and fix integer center line counts and ACS-only edge cases in MagicMaskFunc. Co-authored-by: Cursor <cursoragent@cursor.com>
Reorganize modulated convolution layers into a subpackage and add a central auxiliary-data registry used by conditional reconstruction models. Co-authored-by: Cursor <cursoragent@cursor.com>
Expose sampled acceleration and center fraction in batches, attach auxiliary tensors in MRIModelEngine, and register field strength in FastMRI datasets for configurable conditioning features. Co-authored-by: Cursor <cursoragent@cursor.com>
…odels Enable auxiliary conditioning in VarNet, vSHARP, XPDNet, KIKINet, LPD, JointICNet, IterDualNet, and Unet2d engines and align model configs with the shared conv_modulation settings. Co-authored-by: Cursor <cursoragent@cursor.com>
Add a project README and knee/prostate training configs demonstrating feature-based modulated convolutions with triangular acceleration. Co-authored-by: Cursor <cursoragent@cursor.com>
Embed architecture and result figures from the MIDL paper, map config fields to Eq. 7 and Section 3.1, and link to the OpenReview PDF. Co-authored-by: Cursor <cursoragent@cursor.com>
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #309 +/- ##
==========================================
- Coverage 85.73% 84.71% -1.02%
==========================================
Files 103 110 +7
Lines 9041 10133 +1092
==========================================
+ Hits 7751 8584 +833
- Misses 1290 1549 +259 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Add typed modulated-conv factories, normalize optional modulation types, and apply formatting fixes so type checking and tests pass on the modulated-convolution branch. Co-authored-by: Cursor <cursoragent@cursor.com>
Add .prospector.yml so Codacy skips conflicting pydocstyle rules (D203/D213 etc.) that Ruff does not enforce; fix unused loop variable in DIDN ReconBlock. Co-authored-by: Cursor <cursoragent@cursor.com>
Disable pydocstyle and noisy pylint limits that conflict with numpydoc-style docstrings and typical model __init__ signatures, and fix an unused loop variable. Co-authored-by: Cursor <cursoragent@cursor.com>
Assign acceleration fields in the same branch where they are unpacked, and wrap long DIDN bibliography lines to satisfy line-length checks. Co-authored-by: Cursor <cursoragent@cursor.com>
Add factory, modulation-type, validation, UNet/DIDN, auxiliary-data, and CreateSamplingMask acceleration tests to improve coverage on new modconv code. Co-authored-by: Cursor <cursoragent@cursor.com>
Restore 50 files that differed from main only due to Black reformatting, keeping the branch focused on modulated-convolution changes. Co-authored-by: Cursor <cursoragent@cursor.com>
Build conditioning vectors when image_unet_norm_type is ADAIN even if conv_modulation is NONE, so AdaIN denoisers receive the acquisition vector. Co-authored-by: Cursor <cursoragent@cursor.com>
Add fixed training configs under projects/modulated_convolution/configs/vsharp/ with generator and smoke-test scripts; reorganize README and replace flat yaml paths. Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
These were development utilities and should not live under temp/ in git. Co-authored-by: Cursor <cursoragent@cursor.com>
|
@jonasteuwen this is ready for review! |
| if uniform_range and linear_range: | ||
| raise ValueError("uniform_range and linear_range are mutually exclusive.") | ||
|
|
There was a problem hiding this comment.
Why not a range type? range.UNIFORM and range.LINEAR
| @@ -139,36 +145,64 @@ def __init__( | |||
| along the fourth last dimension. Similarly for MaskFuncMode.MULTISLICE, the mask will be created for each | |||
| slice along the fourth last dimension. Default: MaskFuncMode.STATIC. | |||
| """ | |||
There was a problem hiding this comment.
I guess we can start with GoogleDoc here, is it possible to mix formats with sphinx? I think so, so we can translate it slowly.
| uniform_range: bool = False | ||
| linear_range: bool = False |
| sample["field_strength"] = np.array([3.0]) if "30T" in str(sample.get("filename", "")) else np.array([1.5]) | ||
|
|
There was a problem hiding this comment.
This seems very specific to a dataset, shouldn't that be in a subclass of this? Also, these are not the only field strengths we have
There was a problem hiding this comment.
Ah - it's FastMRI, nevermind
| @@ -0,0 +1,231 @@ | |||
| # Copyright 2025 AI for Oncology Research Group. All Rights Reserved. | |||
| # | |||
There was a problem hiding this comment.
New files should be in 2026?
| # Copyright 2025 AI for Oncology Research Group. All Rights Reserved. | ||
| # |
| ) -> ModConv2d: | ||
| """Construct :class:`ModConv2d` with typed modulation arguments.""" |
There was a problem hiding this comment.
Refer to the base class for the docstring. Also, isn't typically the functional interface defined first and then the class type?
| fc_groups: int = 1 | ||
| fc_activation: str = "sigmoid" | ||
| num_weights: Optional[int] = None |
There was a problem hiding this comment.
Why not ActivationType? Surely you have defined that somewhere already
| fc_groups: int = 1 | ||
| fc_activation: str = "sigmoid" | ||
| num_weights: Optional[int] = None |
Description
This PR ports conditional learned reconstruction into DIRECT by introducing modulated convolutions — convolutional layers whose weights are adapted by small MLPs from acquisition metadata (acceleration factor, ACS fraction, field strength).
The implementation follows Conditional Learned Reconstruction for Medical Imaging (Moriakov et al., MIDL 2026; OpenReview PDF). A single trained backbone can be conditioned at inference time on the actual undersampling pattern instead of training separate models per acceleration.
What's New
Modulated convolution package (
direct/nn/conv/modulated/)ModConv2d/ModConv3dand transposed variants with modulation types:NONE,FEATURES,FULL,PARTIAL_IN,PARTIAL_OUT,SUMauxiliary_data.py— registry-based auxiliary feature pipeline:prepare_auxiliary_data(data, cfg)builds(batch, aux_in_features)conditioning vectorsregister_auxiliary_feature()for custom conditioning channelsacceleration,center_fraction,field_strengthdirect/nn/adain/)Auxiliary conditioning pipeline
CreateSamplingMaskrequestsreturn_accelerationfromBaseMaskFuncmasksacceleration,center_fractionMRIModelEngine._attach_auxiliary_data()attachesauxiliary_dataeach iteration (supervised, SSL, JSSL)field_strength(1.5 T / 3.0 T from filename)Triangular acceleration sampling
linear_range: truein masking config samples acceleration from a triangular distribution (paper Section 4.3.2)direct/utils/distributions.py—triangular_distribution()BaseMaskFuncreturns sampled acceleration metadata viareturn_accelerationModel support
Modulated convolutions wired end-to-end through conv-based models:
auxiliary_dataconv_modulationconv_modulationconv_modulationconv_modulationconv_modulationconv_modulationconv_modulationconv_modulationmodulation(layer-level)Shared config fields:
conv_modulation,aux_in_features,auxiliary_features,log_aux,fc_hidden_features,fc_activation,fc_groups,num_weightsDocumentation & configs
projects/modulated_convolution/README.rst— tutorial with paper figures and section referencesWhat's Changed
Masking (
direct/common/subsample.py)_draw_acceleration_value()/_draw_acceleration_pair()foruniform_rangeandlinear_rangereturn_accelerationinBaseMaskFunc.__call__high <= 0guard)CartesianMagicMaskFuncinteger center line counts no longer capped bymin(1/accel, cf); ACS-only edge case handledlen(center_fractions) != len(accelerations)validationU-Net backbone
Unet2d/Unet3d/NormUnet*swapConv2d→ModConv2dwhenconv_modulation != NONEPaper Mapping
ModConv2dwithFEATUREStypez = log([R, 100·r_acs])log_aux: true+acceleration/center_fractionbatch keysR ∈ [4,16]linear_range: truein masking config[32,8]/[32,16]/[32,32]fc_hidden_featuresin model configRelated
projects/modulated_convolution/README.rst