-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
179 lines (157 loc) · 8.77 KB
/
main.py
File metadata and controls
179 lines (157 loc) · 8.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import argparse
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch import nn, optim
import copy
import importlib
import logging
from model_loader import load_model
from compression.pruning import Pruning
from compression.quantization import Quantization
from compression.decomposition import Decomposition
from compression.knowledge_distillation import KnowledgeDistillation
from evaluation import Evaluation
# Setup basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class ModelCompressionFramework:
def __init__(self, model=None, model_name=None, custom_model_path=None,
dataset_name=None, custom_dataset_module=None, batch_size=128,
custom_evaluation_fn=None, device=None, accuracy_tolerance=5.0,
num_epochs=150, log_frequency=100, checkpoint_path=None,
quantization_mode="default", pruning_mode="default",
decomposition_mode="default", kd_mode="default"):
self.device = torch.device(device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu"))
logging.info(f"Using device: {self.device}")
# Load model: if a custom model is provided, load that; otherwise load by model_name.
self.model = load_model(model=model, model_name=model_name, custom_model_path=custom_model_path, device=self.device)
self.custom_dataset_module = custom_dataset_module
# Load dataset: either via custom module or standard dataset.
if custom_dataset_module:
logging.info("Loading custom dataset.")
custom_mod = importlib.import_module(custom_dataset_module)
if hasattr(custom_mod, "get_custom_dataset"):
train_dataset, val_dataset = custom_mod.get_custom_dataset()
else:
raise ValueError("The custom dataset module must implement a 'get_custom_dataset' function.")
elif dataset_name:
logging.info(f"Loading standard dataset: {dataset_name}")
train_dataset, val_dataset = self.load_standard_dataset(dataset_name, batch_size)
else:
raise ValueError("Either a dataset name or a custom dataset module must be provided.")
self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
self.val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
self.custom_evaluation_fn = custom_evaluation_fn or Evaluation(self.device).default_evaluation
self.accuracy_tolerance = accuracy_tolerance
self.num_epochs = num_epochs
self.log_frequency = log_frequency
self.checkpoint_path = checkpoint_path
self.quantization_mode = quantization_mode
self.pruning_mode = pruning_mode
self.decomposition_mode = decomposition_mode
self.kd_mode = kd_mode
self.compressed_models = {}
self.original_metrics = {}
self.compressed_metrics = {}
# Define transformation used for evaluation or preprocessing if required.
self.transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
def load_standard_dataset(self, dataset_name, batch_size):
# Example implementation to load a standard dataset.
if dataset_name.lower() == "cifar10":
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
else:
raise ValueError(f"Unsupported standard dataset: {dataset_name}")
return train_dataset, val_dataset
def train(self):
# Placeholder for a training procedure.
logging.info("Training started...")
# Insert your training logic here
logging.info("Training completed.")
def compress_model(self):
# Apply each compression technique based on the selected mode.
logging.info("Applying model compression techniques.")
# Quantization
logging.info(f"Starting Quantization with mode: {self.quantization_mode}")
quantized_model = Quantization(self.model, self.device, mode=self.quantization_mode).apply()
self.compressed_models['quantization'] = quantized_model
# Pruning
logging.info(f"Starting Pruning with mode: {self.pruning_mode}")
pruned_model = Pruning(self.model, self.device, mode=self.pruning_mode).apply()
self.compressed_models['pruning'] = pruned_model
# Decomposition
logging.info(f"Starting Decomposition with mode: {self.decomposition_mode}")
decomposed_model = Decomposition(self.model, self.device, mode=self.decomposition_mode).apply()
self.compressed_models['decomposition'] = decomposed_model
# Knowledge Distillation
logging.info(f"Starting Knowledge Distillation with mode: {self.kd_mode}")
kd_model = KnowledgeDistillation(self.model, self.device, mode=self.kd_mode).apply()
self.compressed_models['knowledge_distillation'] = kd_model
def evaluate_models(self):
# Evaluate the original and compressed models.
logging.info("Evaluating the original model.")
self.original_metrics = self.custom_evaluation_fn(self.model, self.val_loader, self.device)
logging.info(f"Original model metrics: {self.original_metrics}")
for key, comp_model in self.compressed_models.items():
logging.info(f"Evaluating model compressed with {key}.")
self.compressed_metrics[key] = self.custom_evaluation_fn(comp_model, self.val_loader, self.device)
logging.info(f"Metrics for {key}: {self.compressed_metrics[key]}")
def run(self):
# Execute the full model compression pipeline.
try:
self.train()
self.compress_model()
self.evaluate_models()
except Exception as e:
logging.error("An error occurred during the compression pipeline.", exc_info=True)
raise e
def parse_args():
parser = argparse.ArgumentParser(description='Model Compression Framework')
parser.add_argument('--model_name', type=str, default=None, help='Name of the model to load')
parser.add_argument('--custom_model_path', type=str, default=None, help='Path to a custom model file')
parser.add_argument('--dataset_name', type=str, default=None, help='Name of the standard dataset to use (e.g., cifar10)')
parser.add_argument('--custom_dataset_module', type=str, default=None, help='Custom dataset module name')
parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training and evaluation')
parser.add_argument('--device', type=str, default=None, help='Device to use for computations')
parser.add_argument('--num_epochs', type=int, default=150, help='Number of training epochs')
parser.add_argument('--log_frequency', type=int, default=100, help='Logging frequency during training')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Path to save checkpoints')
parser.add_argument('--accuracy_tolerance', type=float, default=5.0, help='Accuracy tolerance for compression')
parser.add_argument('--quantization_mode', type=str, default="default", help='Mode to use for quantization')
parser.add_argument('--pruning_mode', type=str, default="default", help='Mode to use for pruning')
parser.add_argument('--decomposition_mode', type=str, default="default", help='Mode to use for decomposition')
parser.add_argument('--kd_mode', type=str, default="default", help='Mode to use for knowledge distillation')
return parser.parse_args()
def main():
args = parse_args()
framework = ModelCompressionFramework(
model_name=args.model_name,
custom_model_path=args.custom_model_path,
dataset_name=args.dataset_name,
custom_dataset_module=args.custom_dataset_module,
batch_size=args.batch_size,
device=args.device,
num_epochs=args.num_epochs,
log_frequency=args.log_frequency,
checkpoint_path=args.checkpoint_path,
accuracy_tolerance=args.accuracy_tolerance,
quantization_mode=args.quantization_mode,
pruning_mode=args.pruning_mode,
decomposition_mode=args.decomposition_mode,
kd_mode=args.kd_mode
)
framework.run()
if __name__ == '__main__':
main()