-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patheval.py
More file actions
74 lines (59 loc) · 2.1 KB
/
eval.py
File metadata and controls
74 lines (59 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# -*- coding: utf-8 -*-
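"""Evaluate Deep JSCC checkpoints on the CIFAR-10 test set over an AWGN channel.

For each training SNR, load the matching checkpoint, measure test-set PSNR at
several test SNRs, and save the resulting PSNR-vs-SNR_test curves to results/.

Exported from the Colab notebook "Untitled10.ipynb":
https://colab.research.google.com/drive/1jP0hfA5O4j1AounpfAvq69V5YAC7RT8Z
"""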
# eval.py
import os
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from models import DeepJSCC
from metrics import psnr
@torch.no_grad()
def eval_psnr(model, loader, device, snr_db=10):
    """Return the sample-weighted mean PSNR of ``model`` over ``loader``.

    The model is switched to eval mode (and left there) and run without
    gradient tracking. Each batch ``x`` is reconstructed as
    ``model(x, snr_db)`` and scored against the input with ``psnr``.

    Args:
        model: Module called as ``model(x, snr_db)``.
        loader: Iterable of ``(x, label)`` batches; labels are ignored.
        device: Device each input batch is moved to before the forward pass.
        snr_db: Channel SNR in dB forwarded to the model (default 10).

    Returns:
        Mean PSNR in dB over every sample the loader yielded, as a float.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total = 0.0
    n_seen = 0
    for x, _ in loader:
        x = x.to(device)
        xhat = model(x, snr_db)
        batch = x.size(0)
        # NOTE(review): psnr() returns a single-element tensor (.item()); this
        # assumes it is the batch average, so weighting by the batch size
        # turns the running sum into a per-sample mean.
        total += psnr(x, xhat).item() * batch
        n_seen += batch
    # Divide by the samples actually processed rather than len(loader.dataset):
    # the two differ when the loader drops its last batch or samples a subset.
    if n_seen == 0:
        raise ValueError("eval_psnr: loader yielded no samples")
    return total / n_seen
def _make_test_loader(device):
    """Build the CIFAR-10 test-set loader (downloaded to ./data if absent)."""
    tfm = transforms.Compose([transforms.ToTensor()])
    testset = datasets.CIFAR10(root="./data", train=False, download=True, transform=tfm)
    # Pinned host memory only benefits host-to-GPU copies.
    return DataLoader(testset, batch_size=256, shuffle=False, num_workers=0,
                      pin_memory=(device == "cuda"))


def _load_model(snr_tr, latent_ch, device):
    """Load the DeepJSCC checkpoint trained at ``snr_tr`` dB onto ``device``."""
    ckpt_path = f"checkpoints/deepjscc_snrtrain_{snr_tr}dB.pth"
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device)
    model = DeepJSCC(latent_ch=latent_ch).to(device)
    model.load_state_dict(ckpt["state_dict"])
    return model


def _plot_curves(snr_test_list, results):
    """Plot one PSNR-vs-SNR_test curve per training SNR and save it to results/.

    ``results`` maps each training SNR to its list of PSNR values, aligned
    with ``snr_test_list``; curves are drawn in the dict's insertion order.
    """
    fig = plt.figure()
    for snr_tr, curve in results.items():
        plt.plot(snr_test_list, curve, marker="o",
                 label=f"Deep JSCC (SNR_train={snr_tr}dB)")
    plt.xlabel("SNR_test (dB)")
    plt.ylabel("PSNR (dB)")
    # NOTE(review): the bandwidth ratio is hard-coded; confirm it matches the
    # DeepJSCC configuration (latent_ch=8) actually being evaluated.
    plt.title("AWGN channel (k/n=1/6)")
    plt.grid(True)
    plt.legend()
    os.makedirs("results", exist_ok=True)
    outpath = "results/awgn_psnr_curves.png"
    plt.savefig(outpath, dpi=200, bbox_inches="tight")
    # Report the saved file before show(), which may block until the window closes.
    print("[saved]", outpath)
    plt.show()
    plt.close(fig)


def main():
    """Evaluate each SNR-trained checkpoint across test SNRs and plot the curves."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("[device]", device, flush=True)
    train_snr_list = [1, 4, 7, 13, 19]
    snr_test_list = [1, 4, 7, 13, 19]
    latent_ch = 8
    test_loader = _make_test_loader(device)
    # NOTE(review): if DeepJSCC samples channel noise internally, these PSNR
    # values vary from run to run unless a seed (torch.manual_seed) is fixed.
    results = {}
    for snr_tr in train_snr_list:
        model = _load_model(snr_tr, latent_ch, device)
        curve = []
        for snr_te in snr_test_list:
            value = eval_psnr(model, test_loader, device, snr_db=snr_te)
            print(f"[eval] train={snr_tr}dB test={snr_te}dB PSNR={value:.2f}dB",
                  flush=True)
            curve.append(value)
        results[snr_tr] = curve
    _plot_curves(snr_test_list, results)


if __name__ == "__main__":
    main()