From 751c8a010c539393cff81ffe23dac1f88de5cac3 Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Wed, 11 Mar 2026 03:23:19 +0000 Subject: [PATCH 1/2] [Feature] Copy forward graph to original name in backward graph extraction Add shutil import and copy forward graph to original model name directory after extraction, so output contains model_name/, model_name_forward/, and model_name_backward/. Co-Authored-By: Claude Opus 4.6 --- graph_net/torch/sample_pass/backward_graph_extractor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/graph_net/torch/sample_pass/backward_graph_extractor.py b/graph_net/torch/sample_pass/backward_graph_extractor.py index 7afa21364..aee3331ae 100644 --- a/graph_net/torch/sample_pass/backward_graph_extractor.py +++ b/graph_net/torch/sample_pass/backward_graph_extractor.py @@ -1,4 +1,5 @@ import os +import shutil import inspect from pathlib import Path @@ -34,6 +35,11 @@ def __call__(self): self.get_extractor("forward")(gm_holder["forward_gm"], forward_inputs) self.get_extractor("backward")(gm_holder["backward_gm"], backward_inputs) + forward_dir = os.path.join(self.output_dir, f"{self.model_name}_forward") + original_name_dir = os.path.join(self.output_dir, self.model_name) + if os.path.exists(forward_dir) and not os.path.exists(original_name_dir): + shutil.copytree(forward_dir, original_name_dir) + def get_extractor(self, suffix): return BuiltinGraphExtractor( name=f"{self.model_name}_{suffix}", From 60fb2c69f701576ac00126daaf2512b3fbecc5fd Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Wed, 11 Mar 2026 03:40:05 +0000 Subject: [PATCH 2/2] [Feature] Copy original graph before backward graph extraction Copy the original input graph to output directory before extracting forward and backward graphs, so output contains: - model_name/ (original graph) - model_name_forward/ (AOT captured forward graph) - model_name_backward/ (backward graph) Co-Authored-By: Claude Opus 4.6 --- graph_net/torch/sample_pass/backward_graph_extractor.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/graph_net/torch/sample_pass/backward_graph_extractor.py b/graph_net/torch/sample_pass/backward_graph_extractor.py index aee3331ae..f27b1da9b 100644 --- a/graph_net/torch/sample_pass/backward_graph_extractor.py +++ b/graph_net/torch/sample_pass/backward_graph_extractor.py @@ -28,6 +28,10 @@ def __call__(self): ) module.train() + original_name_dir = os.path.join(self.output_dir, self.model_name) + if not os.path.exists(original_name_dir): + shutil.copytree(self.model_path, original_name_dir) + forward_inputs = self.set_requires_grad_for_forward_inputs( self.model_path, module, forward_inputs ) @@ -35,11 +39,6 @@ def __call__(self): self.get_extractor("forward")(gm_holder["forward_gm"], forward_inputs) self.get_extractor("backward")(gm_holder["backward_gm"], backward_inputs) - forward_dir = os.path.join(self.output_dir, f"{self.model_name}_forward") - original_name_dir = os.path.join(self.output_dir, self.model_name) - if os.path.exists(forward_dir) and not os.path.exists(original_name_dir): - shutil.copytree(forward_dir, original_name_dir) - def get_extractor(self, suffix): return BuiltinGraphExtractor( name=f"{self.model_name}_{suffix}",