From 7863c8bbfc35501eccb76c968381370d9d0cc17b Mon Sep 17 00:00:00 2001 From: Eric Dill Date: Wed, 17 Jul 2024 09:13:35 -0400 Subject: [PATCH] minimal refactor --- train.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index 6763d00..13001a8 100644 --- a/train.py +++ b/train.py @@ -13,9 +13,8 @@ from plot import plot_cs -def main(): - # arg parser - data_config, encoder_config, decoder_config = parse_args() + +def main(data_config, encoder_config, decoder_config): # load the simulation data and create a dataloader dataloader = senseiver_dataloader(data_config, num_workers=4) @@ -75,5 +74,13 @@ def main(): output_im = model.test(dataloader, num_pix=2048, split_time=10) torch.save(output_im, f'{path}/res.torch') -if __name__=='__main__': - main() + +if __name__ == "__main__": + # arg parser + data_config, encoder_config, decoder_config = parse_args() + + main( + data_config=data_config, + encoder_config=encoder_config, + decoder_config=decoder_config, + )