diff --git a/README.md b/README.md index 57c3617..02f58c7 100644 --- a/README.md +++ b/README.md @@ -78,19 +78,26 @@ cd src ./train.sh ``` * We implement our method by PyTorch and conduct experiments on 2 NVIDIA 2080Ti GPUs. -* We adopt pre-trained [ResNet-18](https://download.pytorch.org/models/resnet18-5c106cde.pth) and [Swin-B-224](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth) as backbone networks, which are saved in PRE folder. +* We adopt pre-trained [ResNet-18](https://download.pytorch.org/models/resnet18-5c106cde.pth) and [Swin-B-224](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth) as backbone networks, which are saved in **pre** folder. * We train our method on 3 settings : DUTS-TR, DUTS-TR+HRSOD and UHRSD_TR+HRSOD_TR. -* After training, the trained models will be saved in MODEL folder. +* After training, the trained models will be saved in **model** folder. ### Test -The trained model can be download here: [Google Drive](https://drive.google.com/drive/folders/1hXwCvrdmvkaRePXWPTw5tjFXmrrzHPtt?usp=sharing) +The trained model can be downloaded here: [Google Drive](https://drive.google.com/drive/folders/1hXwCvrdmvkaRePXWPTw5tjFXmrrzHPtt?usp=sharing) +Rename the downloaded file to *model-31* and save it in **model** folder. +To test on the datasets, change working directory to **src** and run *test.py* as follows: ``` cd src python test.py ``` -* After testing, saliency maps will be saved in RESULT folder +To run inference on custom images in a folder, change working directory to **src** and run *test_images.py* as follows: +``` +cd src +python test_images.py /path/to/folder +``` +* After testing, saliency maps will be saved in **result** folder. 
diff --git a/model/PGNet_DUT+HR/.gitignore b/model/PGNet_DUT+HR/.gitignore new file mode 100644 index 0000000..86d0cb2 --- /dev/null +++ b/model/PGNet_DUT+HR/.gitignore @@ -0,0 +1,4 @@ +# Ignore everything in this directory +* +# Except this file +!.gitignore \ No newline at end of file diff --git a/model/README.md b/model/README.md new file mode 100644 index 0000000..14fc576 --- /dev/null +++ b/model/README.md @@ -0,0 +1,4 @@ +The PGNet_DUT+HR trained model can be downloaded here: [Google Drive](https://drive.google.com/drive/folders/1hXwCvrdmvkaRePXWPTw5tjFXmrrzHPtt?usp=sharing) +1. Download the trained model file. +2. Move it into the PGNet_DUT+HR folder within this folder. +3. Rename the downloaded file to *model-31*. \ No newline at end of file diff --git a/pre/.gitignore b/pre/.gitignore new file mode 100644 index 0000000..9705cf9 --- /dev/null +++ b/pre/.gitignore @@ -0,0 +1,5 @@ +# Ignore everything in this directory +* +# Except these files +!.gitignore +!README.md \ No newline at end of file diff --git a/pre/README.md b/pre/README.md new file mode 100644 index 0000000..b96ca51 --- /dev/null +++ b/pre/README.md @@ -0,0 +1,4 @@ +We adopt pre-trained [ResNet-18](https://download.pytorch.org/models/resnet18-5c106cde.pth) and [Swin-B-224](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth) as backbone networks. +1. Download both pre-trained models and move them to this folder. +2. Rename downloaded ResNet-18 model file to *resnet18.pth*. +3. Rename downloaded Swin-B-224 model file to *swin224.pth*. 
\ No newline at end of file diff --git a/result/.gitignore b/result/.gitignore new file mode 100644 index 0000000..fbabbb6 --- /dev/null +++ b/result/.gitignore @@ -0,0 +1,5 @@ +# Ignore everything in this directory +* +# Except these files +!.gitignore +!README.md \ No newline at end of file diff --git a/src/dataset.py b/src/dataset.py old mode 100644 new mode 100755 index c196ee5..b25b735 --- a/src/dataset.py +++ b/src/dataset.py @@ -92,6 +92,7 @@ def __init__(self, cfg): img_name = each.split("/")[-1] img_name = img_name.split(".")[0] self.samples.append(img_name) + def __getitem__(self, idx): name = self.samples[idx] tig='.jpg' @@ -129,6 +130,23 @@ def __len__(self): return len(self.samples) +class DataImage(Data): + def __init__(self, cfg): + super().__init__(cfg) + self.samples = [os.path.join(self.cfg.datapath, i) for i in os.listdir(self.cfg.datapath)] + + def __getitem__(self, idx): + name = self.samples[idx] + image = cv2.imread(name).astype(np.float32) + image = image[:,:,::-1].copy() + + mask = image[:,:,0] + shape = mask.shape # + image, mask = self.normalize(image, mask) + image, mask = self.resize(image, mask) + image, mask = self.totensor(image, mask) + return image, mask, shape, name + ########################### Testing Script ########################### if __name__=='__main__': import matplotlib.pyplot as plt diff --git a/src/test.py b/src/test.py index 6ddfb60..d19dec4 100644 --- a/src/test.py +++ b/src/test.py @@ -41,6 +41,14 @@ def save(self): if __name__=='__main__': for path in ['../data/DAVIS-S','../data/UHRSD_TE','../data/HRSOD_TE','../data/DUT-OMRON','../data/HKU-IS','../data/ECSSD','../data/DUTS-TE','../data/PASCAL-S']: - for model in ['model-27','model-28','model-29','model-30','model-31','model-32']: - t = Test(dataset,PGNet, path,'./PGNet_DUT+HR/'+model) + if not os.path.isdir(path): + print(f'Skipping dataset. Directory does not exist: {path}') + continue + for model in ['model-27','model-28','model-29','model-30','model-31','model-32']: + model_path = os.path.join('..', 'model', 'PGNet_DUT+HR', model) + if os.path.isfile(model_path): + print(f'Testing model {model_path}') + t = Test(dataset,PGNet, path, model_path) t.save() + else: + print(f'Skipping {model_path} because file does not exist.') diff --git a/src/test_bigbird.py b/src/test_bigbird.py new file mode 100755 index 0000000..cbfeaaa --- /dev/null +++ b/src/test_bigbird.py @@ -0,0 +1,81 @@ +import os +import sys +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from dataset import Data, Config, DataImage +from PGNet import PGNet + + +class Test(object): + def __init__(self, network, path, model): + ## dataset + self.model = model + self.cfg = Config(datapath=path, snapshot=model, mode='test') + self.data = DataImage(self.cfg) + self.loader = DataLoader(self.data, batch_size=1, shuffle=False, num_workers=2) + ## network + self.net = network(self.cfg) + self.net.train(False) + self.net.cuda() + + def save(self, path, forwards=1): + print(f'Saving results in {path}') + os.makedirs(path, exist_ok=True) + + with torch.no_grad(): + for image, mask, shape, name in self.loader: + image = image.cuda().float() + mask = mask.cuda().float() + + # Successive iteration of forwards on previous results + for i in range(forwards): + p = self.net(image, shape=None) + # Replicate 1 channel mask into 3 channels + image = image.expand(-1, 3,-1,-1) + + # Resize and save + out_resize = F.interpolate(p[0],size=shape, mode='bilinear') + pred = torch.sigmoid(out_resize[0,0]) + pred = (pred*255).cpu().numpy() + name = os.path.basename(name[0]) + out = os.path.join(path, name) + cv2.imwrite(out, np.round(pred)) + + +if __name__=='__main__': + import argparse + parser = argparse.ArgumentParser(description = 'Saliency detection on cropped BigBird images') + 
parser.add_argument('--model', help='path to model') + parser.add_argument('--root', help='root folder containing cropped bigbird object instances', required=True) + parser.add_argument('--objects', help='only crop specific objects', nargs='*', default=None) + parser.add_argument('--in-folder', help='name of folder where cropped images are stored', required=True) + parser.add_argument('--out', help='output root path (default=root)', default=None) + parser.add_argument('--out-folder', help='name of folder where output saliency maps are to be stored', required=True) + parser.add_argument('--forwards', help='number of forward passes (default=1)', type=int, default=1) + args = parser.parse_args() + + for k,v in args.__dict__.items(): + print(f'{k:->20} : {v}') + + # Load object names + objects = args.objects + if objects is None: + objects = [i for i in os.listdir(args.root) if os.path.isdir(os.path.join(args.root, i))] + print(f'Inferencing on {len(objects)} objects') + + # Get output root + out = args.out + if out is None: + out = args.root + + # Iterate over each object + for obj_idx, obj in enumerate(objects, 1): + print(f'\nInferencing on \t[{obj_idx:>3d}/{len(objects)}] : \t{obj}') + obj_img_path = os.path.join(args.root, obj, args.in_folder) + obj_out_path = os.path.join(out, obj, args.out_folder) + t = Test(PGNet, obj_img_path, args.model) + t.save(obj_out_path, args.forwards) + print('-'*70) diff --git a/src/test_images.py b/src/test_images.py new file mode 100755 index 0000000..5453376 --- /dev/null +++ b/src/test_images.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +#coding=utf-8 + +import os +import sys +sys.path.insert(0, '../') +sys.dont_write_bytecode = True +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from dataset import Data, Config, DataImage +from PGNet import PGNet + +class Test(object): + def __init__(self, Network, path, model): + ## dataset + self.model = model + self.cfg = Config(datapath=path, snapshot=model, mode='test') + self.data = DataImage(self.cfg) + self.loader = DataLoader(self.data, batch_size=1, shuffle=False, num_workers=2) + ## network + self.net = Network(self.cfg) + self.net.train(False) + self.net.cuda() + + def save(self): + head = os.path.join('../result', self.model[3:], self.cfg.datapath.split(os.sep)[-1]) + if not os.path.exists(head): + os.makedirs(head) + print(f'Saving results at {head}') + + with torch.no_grad(): + for image, mask, shape, name in self.loader: + + image = image.cuda().float() + mask = mask.cuda().float() + p = self.net(image, shape=None) + out_resize = F.interpolate(p[0],size=shape, mode='bilinear') + pred = torch.sigmoid(out_resize[0,0]) + pred = (pred*255).cpu().numpy() + + name = os.path.basename(name[0]) + out = os.path.join(head, name.split('.')[0]+'_mask.png') + print(out) + cv2.imwrite(out, np.round(pred)) + +if __name__=='__main__': + img_root = sys.argv[1] + for model in ['model-31']: + t = Test(PGNet, img_root, os.path.join('../model', 'PGNet_DUT+HR', model)) + t.save()