diff --git a/graph_net/torch/sample_pass/backward_graph_extractor.py b/graph_net/torch/sample_pass/backward_graph_extractor.py index 7afa21364..f27b1da9b 100644 --- a/graph_net/torch/sample_pass/backward_graph_extractor.py +++ b/graph_net/torch/sample_pass/backward_graph_extractor.py @@ -1,4 +1,5 @@ import os +import shutil import inspect from pathlib import Path @@ -27,6 +28,10 @@ def __call__(self): ) module.train() + original_name_dir = os.path.join(self.output_dir, self.model_name) + if not os.path.exists(original_name_dir): + shutil.copytree(self.model_path, original_name_dir) + forward_inputs = self.set_requires_grad_for_forward_inputs( self.model_path, module, forward_inputs )