-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
53 lines (49 loc) · 2.27 KB
/
Copy pathtest.py
File metadata and controls
53 lines (49 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import argparse
import os
import torch
import imageio
import numpy as np
import torch.nn.functional as F
from SAM3UNet import SAM3UNet
from dataset import TestDataset
# ---------------------------------------------------------------------------
# Command-line interface.
# Four options are mandatory (trained weights, test images, test masks and
# the output directory); the pretrained SAM3 backbone checkpoint is optional.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# Weights of the trained SAM3-UNet model that is being evaluated.
parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,
    help="path to the checkpoint of SAM3-UNet",
)

# Optional backbone checkpoint; None is passed through to the model as-is.
parser.add_argument(
    "--sam3_checkpoint",
    type=str,
    default=None,
    help="path to the pretrained SAM3 checkpoint (.pt); "
         "defaults to ./sam3/sam3.pt or ./sam3/checkpoints/sam3.pt if present",
)

# Locations of the test inputs and their ground-truth masks.
parser.add_argument(
    "--test_image_path",
    type=str,
    required=True,
    help="path to the image files for testing",
)
parser.add_argument(
    "--test_gt_path",
    type=str,
    required=True,
    help="path to the mask files for testing",
)

# Directory that receives the predicted masks.
parser.add_argument(
    "--save_path",
    type=str,
    required=True,
    help="path to save the predicted masks",
)

args = parser.parse_args()
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The third argument (1008) is forwarded to TestDataset unchanged.
# NOTE(review): presumably the square input resolution -- confirm in dataset.py.
test_loader = TestDataset(args.test_image_path, args.test_gt_path, 1008)

# Build the network and move it to the target device once; loading the state
# dict below copies weights into the already-placed parameters, so no second
# .to(device) call is needed afterwards.
model = SAM3UNet(checkpoint_path=args.sam3_checkpoint, device=str(device)).to(device)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source (or pass weights_only=True
# on PyTorch versions that support it).
state_dict = torch.load(args.checkpoint, map_location=device)
# strict=True: fail loudly if the checkpoint keys do not match the model.
model.load_state_dict(state_dict, strict=True)
# Inference mode for layers such as dropout / batch norm.
model.eval()

# Make sure the output directory exists before any mask is written.
os.makedirs(args.save_path, exist_ok=True)
# Run inference on every test sample and save one 8-bit PNG mask per image.
# Gradients are never needed here, so the whole loop runs under no_grad.
with torch.no_grad():
    for _ in range(test_loader.size):
        image, gt, name = test_loader.load_data()

        # The ground truth is only used for its shape: the prediction is
        # resized to match it before saving.
        gt = np.asarray(gt, np.float32)
        image = image.to(device)

        # The model returns three outputs; only the first one is used here.
        res, _, _ = model(image)

        # Resize first, then apply the sigmoid exactly once (an earlier extra
        # sigmoid before the resize was removed as a duplicate).
        # F.interpolate replaces the deprecated F.upsample.
        res = F.interpolate(res, size=gt.shape, mode="bilinear", align_corners=False)
        res = res.sigmoid().detach().cpu().numpy().squeeze()

        # Min-max normalise to [0, 1]; the epsilon avoids division by zero for
        # a constant prediction map. Then scale to 8-bit grey levels.
        res = (res - res.min()) / (res.max() - res.min() + 1e-8)
        res = (res * 255).astype(np.uint8)

        # If you want to binarize the prediction results, please uncomment the
        # following three lines. Note that this action will affect the
        # calculation of evaluation metrics.
        # threshold = 0.5
        # res[res >= int(255 * threshold)] = 255
        # res[res < int(255 * threshold)] = 0

        print("Saving " + name)
        # splitext handles extensions of any length (.jpg, .jpeg, .tiff, ...),
        # unlike slicing off a fixed four characters.
        stem = os.path.splitext(name)[0]
        imageio.imsave(os.path.join(args.save_path, stem + ".png"), res)