-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
58 lines (46 loc) · 1.68 KB
/
Copy pathmain.py
File metadata and controls
58 lines (46 loc) · 1.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import argparse
import logging
import torch
from config.configure import VALID_MODEL_NAMES, DEFAULT_TARGETS, DEFAULT_THRESHOLD
from inference.segmenter import Segmenter
# Root logging configuration for this CLI: INFO and above, each record
# rendered as "<time> - <logger name> - <level> - <message>".
# NOTE: logging.basicConfig is a no-op if the root logger already has handlers.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger for this module; named after the module itself.
logger = logging.getLogger(__name__)
def main():
parser = argparse.ArgumentParser(description="Segmentation Inference")
parser.add_argument("--image", type=str, required=True, help="Input image path")
parser.add_argument(
"--model",
type=str,
default="google/deeplabv3_mobilenet_v2_1.0_513",
choices=VALID_MODEL_NAMES,
help="model name for segmentation",
)
parser.add_argument(
"--threshold",
type=float,
default=DEFAULT_THRESHOLD,
help="the confidence threshold for mask",
)
parser.add_argument(
"--target_ids",
type=int, # Convert each argument to an integer
nargs="+", # Accept one or more values, which will be collected into a list
default=DEFAULT_TARGETS,
help="List of target IDs for segmentation (e.g., 1 2 3)",
)
parser.add_argument(
"--device", type=str, default=None, help="Device for inference (cuda or cpu)"
)
args = parser.parse_args()
if not args.device:
device = "cuda" if torch.cuda.is_available() else "cpu"
else:
device = args.device
segmenter = Segmenter(device=device, model_name=args.model)
logger.info(f"Processing {args.image}")
results = segmenter.segment(args.image, target_class_ids=args.target_ids, threshold=args.threshold)
logger.info(f"{results = }")
if __name__ == "__main__":
    # Script entry point only; importing this module does not start the CLI.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())