-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodels.py
More file actions
153 lines (114 loc) · 5.62 KB
/
Copy pathmodels.py
File metadata and controls
153 lines (114 loc) · 5.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.parallel
import torchvision.transforms as transforms
import pandas as pd
from utils import mtcnn
import os
# print(f"PyTorch version: {torch.__version__}")
# device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Using device: {device}")
# Dataset image directory (relative path).
root_dir = "CASIA_dataset/Images"

# Teacher network: InceptionResnetV1 loaded with the 'casia-webface'
# pretrained weights and switched to evaluation mode.
from facenet_pytorch import InceptionResnetV1

teacher_model = InceptionResnetV1(pretrained='casia-webface').eval()
model = teacher_model  # alias: both names refer to the same module object
from torchvision.models import mobilenetv2
class CustomModel(mobilenetv2.MobileNetV2):
    """MobileNetV2 variant whose forward pass also returns intermediate features.

    The parent constructor is run with its default arguments, after which the
    ``features`` and ``classifier`` sub-modules are replaced by the ones taken
    from ``base_model`` (they are shared with it, not copied).
    """

    def __init__(self, base_model):
        super().__init__()
        # Reuse the backbone and the classification head of the supplied model.
        self.features = base_model.features
        self.classifier = base_model.classifier

    def forward(self, x):
        """Return ``(logits, flat_features, pooled_features)`` for input ``x``.

        * ``logits`` -- output of the classifier head.
        * ``flat_features`` -- globally average-pooled backbone output,
          flattened from dimension 1 onwards.
        * ``pooled_features`` -- ``flat_features`` reduced further with a 1-D
          average pool (window 258, stride 2).
        """
        feature_maps = self.features(x)
        globally_pooled = nn.functional.adaptive_avg_pool2d(feature_maps, (1, 1))
        flat_features = torch.flatten(globally_pooled, 1)
        logits = self.classifier(flat_features)
        # NOTE(review): window 258 / stride 2 presumably shrinks a 1280-long
        # feature vector to 512 values to match the teacher embedding size --
        # confirm against the backbone and teacher actually in use.
        pooled_features = nn.functional.avg_pool1d(
            flat_features, kernel_size=258, stride=2
        )
        return logits, flat_features, pooled_features
# Student network: a MobileNetV2 with a 10575-way classifier, wrapped in
# CustomModel so that its forward pass also exposes the intermediate features.
student_model_base = mobilenetv2.MobileNetV2(num_classes=10575)
student_model = CustomModel(base_model=student_model_base)
# TRAINING LOOP (cosine hidden-representation loss + cross-entropy label loss)
def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device, log_file="loss_log.txt", model_save_dir="output_logs"):
    """Distil ``teacher`` into ``student`` with a cosine + cross-entropy loss.

    For every batch the loss is::

        hidden_rep_loss_weight * cosine_loss(student_hidden, teacher_hidden)
        + ce_loss_weight * cross_entropy(student_logits, labels)

    Args:
        teacher: Frozen reference model. It is called as ``teacher(inputs)``
            with its ``classify`` attribute set to ``False``; the output is
            used as the target hidden representation.
        student: Model being trained. ``student(inputs)`` must return a
            3-tuple ``(logits, _, hidden_representation)``.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate for the Adam optimizer.
        hidden_rep_loss_weight: Weight of the cosine (hidden representation) loss.
        ce_loss_weight: Weight of the cross-entropy (label) loss.
        device: Device the models and batches are moved to.
        log_file: File name (inside ``model_save_dir``) that receives one
            average epoch loss per line.
        model_save_dir: Directory for checkpoints and the loss log; created
            if it does not exist.

    Side effects:
        Writes ``current_model.pt`` after every epoch and ``model{epoch}.pt``
        every fifth epoch; each checkpoint holds the epoch number, the student
        and optimizer state dicts, and the summed epoch loss. Writes the loss
        log once training finishes.
    """
    ce_loss = nn.CrossEntropyLoss()
    cosine_loss = nn.CosineEmbeddingLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()   # Teacher set to evaluation mode
    student.train()  # Student to train mode

    # Keep handles on the bare modules: attributes must be set on, and state
    # dicts taken from, the wrapped module rather than a DataParallel wrapper.
    teacher_core = teacher.module if isinstance(teacher, nn.DataParallel) else teacher
    student_core = student.module if isinstance(student, nn.DataParallel) else student

    # Loop-invariant: ask the teacher for its hidden representation, not logits.
    teacher_core.classify = False

    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        if not isinstance(teacher, nn.DataParallel):
            teacher = nn.DataParallel(teacher)
        if not isinstance(student, nn.DataParallel):
            student = nn.DataParallel(student)

    # torch.save / open do not create missing directories.
    if model_save_dir:
        os.makedirs(model_save_dir, exist_ok=True)

    loss_values = []  # Average loss of each finished epoch

    for epoch in range(1, epochs + 1):
        running_loss = 0.0
        num_batches = 0

        # Use tqdm to display a progress bar
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Teacher forward pass: no gradients, hidden representation only.
            with torch.no_grad():
                teacher_hidden_representation = teacher(inputs)

            # Student forward pass.
            student_logits, _, student_hidden_representation = student(inputs)

            # A target of +1 for every sample makes CosineEmbeddingLoss equal
            # to 1 - cos(student, teacher), so minimising it increases the
            # cosine similarity between the two representations.
            target = torch.ones(inputs.size(0), device=device)
            hidden_rep_loss = cosine_loss(
                student_hidden_representation,
                teacher_hidden_representation,
                target=target,
            )

            # True-label loss.
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses.
            loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            progress_bar.set_postfix({"Loss": running_loss / num_batches})  # Update the progress bar

        # Guard against an empty loader instead of dividing by zero.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        loss_values.append(epoch_loss)

        # Checkpoint the model that was actually trained (the ``student``
        # argument), with keys free of any DataParallel "module." prefix.
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': student_core.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': running_loss,
        }

        # Save the model after each epoch
        torch.save(checkpoint, os.path.join(model_save_dir, "current_model.pt"))

        # Keep a numbered snapshot every fifth epoch.
        if epoch % 5 == 0:
            torch.save(checkpoint, os.path.join(model_save_dir, f"model{epoch}.pt"))

        print(f"Epoch {epoch}/{epochs}, Loss: {epoch_loss}")

    # Save loss values to a file, one epoch per line.
    log_file_path = os.path.join(model_save_dir, log_file)
    with open(log_file_path, 'w') as f:
        f.writelines(f"{loss_value}\n" for loss_value in loss_values)
# EVALUATE A MODEL (student by default, teacher when isTeacher=True)
def check_accuracy(model, test_loader, device, isTeacher=False):
    """Compute, print and return the top-1 accuracy of ``model`` in percent.

    Args:
        model: Model to evaluate. It is switched to evaluation mode. It is not
            moved to ``device`` here; the caller is expected to have done so.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        isTeacher: When true, ``model.classify`` is set to ``True`` and
            ``model(inputs)`` is used directly as the class scores. Otherwise
            the model is expected to return a 3-tuple whose first element
            holds the class scores.

    Returns:
        Accuracy as a float percentage in ``[0, 100]``; ``0.0`` when the
        loader yields no samples.

    Note:
        ``model.classify`` is left set to ``True`` after a teacher evaluation.
    """
    model.eval()  # Set model to evaluation mode
    if isTeacher:
        model.classify = True

    correct = 0
    total = 0

    with torch.no_grad():
        progress_bar = tqdm(test_loader, leave=False)
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)

            if isTeacher:
                outputs = model(inputs)
            else:
                outputs, _, _ = model(inputs)

            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            # .item() keeps the counter a plain int, so the ratios below are
            # Python floats rather than tensors.
            correct += (predicted == labels).sum().item()

            progress_bar.set_postfix({"Current Accuracy": correct / max(total, 1)})  # Update the progress bar

    # Avoid dividing by zero when no samples were seen.
    accuracy = 100 * correct / total if total else 0.0
    print(f"Accuracy = {accuracy}")
    return accuracy