diff --git a/train.py b/train.py index 6763d00..13001a8 100644 --- a/train.py +++ b/train.py @@ -13,9 +13,8 @@ from plot import plot_cs -def main(): - # arg parser - data_config, encoder_config, decoder_config = parse_args() + +def main(data_config, encoder_config, decoder_config): # load the simulation data and create a dataloader dataloader = senseiver_dataloader(data_config, num_workers=4) @@ -75,5 +74,13 @@ def main(): output_im = model.test(dataloader, num_pix=2048, split_time=10) torch.save(output_im, f'{path}/res.torch') -if __name__=='__main__': - main() + +if __name__ == "__main__": + # arg parser + data_config, encoder_config, decoder_config = parse_args() + + main( + data_config=data_config, + encoder_config=encoder_config, + decoder_config=decoder_config, + )