-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patheval.py
More file actions
66 lines (47 loc) · 1.48 KB
/
Copy patheval.py
File metadata and controls
66 lines (47 loc) · 1.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
from utils import create_loader
from models import student_model, check_accuracy
import argparse
# Command-line evaluation of a saved student-model checkpoint: parse options,
# pick a device, restore the weights and report accuracy via check_accuracy.
parser = argparse.ArgumentParser(description="Evaluate a trained student model checkpoint.")
parser.add_argument("--batchsize", type=int,
                    help="evaluation batch size (default: 8)")
parser.add_argument("--workers", type=int,
                    help="number of data-loader workers (default: 0)")
parser.add_argument("--device", choices=["cpu", "cuda", "mps"],
                    help="compute device; falls back to cpu if the requested "
                         "accelerator is unavailable (default: cpu)")
parser.add_argument("--test-size", type=float,
                    help="fraction of the data to randomly sample for evaluation, "
                         "passed to create_loader (default: 0.1)")
parser.add_argument("--model-path", type=str,
                    help="checkpoint file containing 'model_state_dict', 'epoch' "
                         "and 'loss' entries (required)")
args = parser.parse_args()

# Hyperparameters: CLI values override these defaults. As before, a falsy CLI
# value (e.g. --batchsize 0) falls back to the default.
batch_size = args.batchsize or 8
num_workers = args.workers or 0
test_size = args.test_size or 0.1

# Use the requested accelerator only when this machine actually provides it;
# otherwise run on the CPU.
device = "cpu"
if args.device == "cuda" and torch.cuda.is_available():
    device = "cuda"
elif args.device == "mps" and torch.backends.mps.is_available():
    device = "mps"

# parser.error prints usage plus the message to stderr and exits with status 2,
# so a missing path is reported as a failure (a bare exit() returned status 0).
if not args.model_path:
    parser.error("Please specify model path")
model_path = args.model_path

print(args)
print("Using device: {} ".format(device))
# NOTE(review): the "k samples" figure assumes the full dataset holds ~500k
# samples; confirm against the dataset behind create_loader.
print("Evaluating on random {}k samples".format(int(test_size*500)))

# Load model.
model = student_model
# Deserialize onto the CPU first so a checkpoint saved on CUDA/MPS can still be
# read when that device is unavailable (the CPU fallback above); the model is
# moved to `device` afterwards.
checkpoint = torch.load(model_path, map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print("Model loaded from epoch {} with loss {}".format(epoch, loss))
model = model.eval()
model.to(device)

test_loader = create_loader(batch_size=batch_size, num_workers=num_workers, fraction=test_size)
check_accuracy(model, test_loader, device)