"""what.examples.mobilenet_ssd_demo: interactive MobileNet SSD camera detection demo."""

import os

import cv2
import torch

from what.cli.model import *
from what.models.detection.ssd.mobilenet_v1_ssd import MobileNetV1SSD
from what.models.detection.ssd.mobilenet_v2_ssd_lite import MobileNetV2SSDLite
from what.models.detection.utils.box_utils import draw_bounding_boxes
from what.utils.file import get_file

# Inference device: the first CUDA GPU when PyTorch reports one, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# SSD entries of the shared model registry (presumably provided by the star
# import from what.cli.model).
# NOTE(review): relies on the SSD models sitting at positions 6 and 7 of
# what_model_list; revisit if that registry is ever reordered.
what_ssd_model_list = what_model_list[6:8]

def mobilenet_ssd_inference_demo():
    """Run an interactive MobileNet SSD detection demo on a camera stream.

    Steps:
      1. List the entries of ``what_ssd_model_list``. ``[x]`` marks entries
         whose weight file already exists under ``WHAT_MODEL_PATH``.
      2. Prompt for a 1-based model index and download the weights with
         ``get_file`` if they are missing.
      3. Build ``MobileNetV1SSD`` (first entry) or ``MobileNetV2SSDLite``
         (second entry) on ``device``.
      4. Prompt for an OpenCV capture device number. Then repeatedly read a
         frame, run ``model.predict`` and show the frame with the detections
         drawn by ``draw_bounding_boxes``. Press ``q`` in the window to quit.

    Errors raised while streaming are printed, not re-raised. The capture
    device and OpenCV windows are released on every exit path.
    """
    max_len = max(len(entry[WHAT_MODEL_NAME_INDEX]) for entry in what_ssd_model_list)
    for i, entry in enumerate(what_ssd_model_list, start=1):
        entry_path = os.path.join(WHAT_MODEL_PATH, entry[WHAT_MODEL_FILE_INDEX])
        downloaded = 'x' if os.path.isfile(entry_path) else ' '
        print('[{}] {} : {:<{w}s}\t{}\t{}'.format(
            downloaded, i,
            entry[WHAT_MODEL_NAME_INDEX],
            entry[WHAT_MODEL_TYPE_INDEX],
            entry[WHAT_MODEL_DESC_INDEX],
            w=max_len))

    # Only 1..len(what_ssd_model_list) is valid. Accepting "0" would give
    # index -1, which selects no model. isdecimal() (not isdigit()) guarantees
    # that int() will accept the string.
    index = input("Please input the model index: ")
    while not index.isdecimal() or not 1 <= int(index) <= len(what_ssd_model_list):
        index = input(f"Model [{index}] does not exist. Please try again: ")
    index = int(index) - 1

    selected = what_ssd_model_list[index]
    model_file = os.path.join(WHAT_MODEL_PATH, selected[WHAT_MODEL_FILE_INDEX])

    # Download the model first if it does not exist.
    # Check what_model_list for all available models.
    if not os.path.isfile(model_file):
        get_file(selected[WHAT_MODEL_FILE_INDEX],
                 WHAT_MODEL_PATH,
                 selected[WHAT_MODEL_URL_INDEX],
                 selected[WHAT_MODEL_HASH_INDEX])

    # Initialize the model from the same file that was just checked/downloaded.
    # Position in what_ssd_model_list selects the architecture.
    model_classes = (MobileNetV1SSD, MobileNetV2SSDLite)
    model = model_classes[index](model_file, is_test=True, device=device)

    video = input("Please input the OpenCV capture device (e.g. 0, 1, 2): ")
    while not video.isdecimal():
        video = input("Please input the OpenCV capture device (e.g. 0, 1, 2): ")

    # Capture from camera.
    # To request e.g. 1920x1080, call cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
    # and cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080) after opening.
    cap = cv2.VideoCapture(int(video))
    try:
        # Without this check an unavailable device makes the loop below spin
        # forever on empty reads.
        if not cap.isOpened():
            print(f"Could not open capture device {video}.")
            return

        while True:
            _, orig_image = cap.read()
            if orig_image is None:
                # No frame this time; keep polling the device.
                continue

            # Image preprocessing: OpenCV delivers BGR, the model is given RGB.
            rgb_image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)

            # Run inference.
            # NOTE(review): 10 and 0.4 are passed positionally to predict;
            # presumably top-k and probability threshold. Confirm against its signature.
            _, boxes, labels, probs = model.predict(rgb_image, 10, 0.4)

            # Back to BGR for OpenCV display, then draw the detections.
            bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
            output = draw_bounding_boxes(bgr_image, boxes, labels, model.class_names, probs)

            cv2.imshow('MobileNet SSD Demo', output)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    except Exception as e:
        # Demo boundary: report the actual error instead of crashing.
        print(f"MobileNet SSD demo stopped: {e!r}")
    finally:
        # Release the camera and close windows on every exit path.
        cap.release()
        cv2.destroyAllWindows()
def mobilenet_ssd_inference_demo():
    """Run an interactive MobileNet SSD detection demo on a camera stream.

    Steps:
      1. List the entries of ``what_ssd_model_list``. ``[x]`` marks entries
         whose weight file already exists under ``WHAT_MODEL_PATH``.
      2. Prompt for a 1-based model index and download the weights with
         ``get_file`` if they are missing.
      3. Build ``MobileNetV1SSD`` (first entry) or ``MobileNetV2SSDLite``
         (second entry) on ``device``.
      4. Prompt for an OpenCV capture device number. Then repeatedly read a
         frame, run ``model.predict`` and show the frame with the detections
         drawn by ``draw_bounding_boxes``. Press ``q`` in the window to quit.

    Errors raised while streaming are printed, not re-raised. The capture
    device and OpenCV windows are released on every exit path.
    """
    max_len = max(len(entry[WHAT_MODEL_NAME_INDEX]) for entry in what_ssd_model_list)
    for i, entry in enumerate(what_ssd_model_list, start=1):
        entry_path = os.path.join(WHAT_MODEL_PATH, entry[WHAT_MODEL_FILE_INDEX])
        downloaded = 'x' if os.path.isfile(entry_path) else ' '
        print('[{}] {} : {:<{w}s}\t{}\t{}'.format(
            downloaded, i,
            entry[WHAT_MODEL_NAME_INDEX],
            entry[WHAT_MODEL_TYPE_INDEX],
            entry[WHAT_MODEL_DESC_INDEX],
            w=max_len))

    # Only 1..len(what_ssd_model_list) is valid. Accepting "0" would give
    # index -1, which selects no model. isdecimal() (not isdigit()) guarantees
    # that int() will accept the string.
    index = input("Please input the model index: ")
    while not index.isdecimal() or not 1 <= int(index) <= len(what_ssd_model_list):
        index = input(f"Model [{index}] does not exist. Please try again: ")
    index = int(index) - 1

    selected = what_ssd_model_list[index]
    model_file = os.path.join(WHAT_MODEL_PATH, selected[WHAT_MODEL_FILE_INDEX])

    # Download the model first if it does not exist.
    # Check what_model_list for all available models.
    if not os.path.isfile(model_file):
        get_file(selected[WHAT_MODEL_FILE_INDEX],
                 WHAT_MODEL_PATH,
                 selected[WHAT_MODEL_URL_INDEX],
                 selected[WHAT_MODEL_HASH_INDEX])

    # Initialize the model from the same file that was just checked/downloaded.
    # Position in what_ssd_model_list selects the architecture.
    model_classes = (MobileNetV1SSD, MobileNetV2SSDLite)
    model = model_classes[index](model_file, is_test=True, device=device)

    video = input("Please input the OpenCV capture device (e.g. 0, 1, 2): ")
    while not video.isdecimal():
        video = input("Please input the OpenCV capture device (e.g. 0, 1, 2): ")

    # Capture from camera.
    # To request e.g. 1920x1080, call cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
    # and cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080) after opening.
    cap = cv2.VideoCapture(int(video))
    try:
        # Without this check an unavailable device makes the loop below spin
        # forever on empty reads.
        if not cap.isOpened():
            print(f"Could not open capture device {video}.")
            return

        while True:
            _, orig_image = cap.read()
            if orig_image is None:
                # No frame this time; keep polling the device.
                continue

            # Image preprocessing: OpenCV delivers BGR, the model is given RGB.
            rgb_image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)

            # Run inference.
            # NOTE(review): 10 and 0.4 are passed positionally to predict;
            # presumably top-k and probability threshold. Confirm against its signature.
            _, boxes, labels, probs = model.predict(rgb_image, 10, 0.4)

            # Back to BGR for OpenCV display, then draw the detections.
            bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
            output = draw_bounding_boxes(bgr_image, boxes, labels, model.class_names, probs)

            cv2.imshow('MobileNet SSD Demo', output)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    except Exception as e:
        # Demo boundary: report the actual error instead of crashing.
        print(f"MobileNet SSD demo stopped: {e!r}")
    finally:
        # Release the camera and close windows on every exit path.
        cap.release()
        cv2.destroyAllWindows()