main.py
import argparse
import logging
import torch
from threading import Thread, Lock
from urllib.request import urlopen
import time
from transformers import CLIPProcessor, CLIPModel
from waggle.plugin import Plugin
from waggle.data.vision import Camera


class TextPromptWatcher:
    """
    TextPromptWatcher manages a list of text prompts and can watch a remote URL to update the list.
    """

    def __init__(self, text_prompts, poll_url_interval):
        self.text_prompts = text_prompts
        self.lock = Lock()
        self.poll_url_interval = poll_url_interval

    def get_text_prompts(self):
        with self.lock:
            return self.text_prompts

    def watch_url(self, url):
        while True:
            try:
                with urlopen(url) as f:
                    content = f.read()
                text_prompts = content.decode().splitlines()
                with self.lock:
                    self.text_prompts = text_prompts
            except Exception:
                logging.exception("failed to update text prompts")
            time.sleep(self.poll_url_interval)
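
# Note on the watched prompts file: watch_url() splits the fetched body with
# splitlines(), so the remote text file is expected to hold one prompt per
# line, for example (illustrative prompts, not taken from the repository):
#
#   a photo of a car
#   a photo of smoke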


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true", help="enable debug logging")
    parser.add_argument("--cpu", action="store_true", help="use cpu instead of accelerator")
    parser.add_argument("--input", default=0, help="input source")
    parser.add_argument("--threshold-type", default="similarity", choices=["similarity", "softmax"], help="which type of value to check threshold on")
    parser.add_argument("--threshold", type=float, help="threshold for publishing a detection")
    parser.add_argument("--watch-text-url", help="url of text file to watch for new prompts")
    parser.add_argument("--watch-text-interval", type=int, default=10, help="interval to poll url of text file")
    parser.add_argument("text_prompts", nargs="+", help="list of text descriptions to match")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO, format="%(asctime)s %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S")

    if args.threshold_type == "similarity" and args.threshold is None:
        args.threshold = 28.0
    elif args.threshold_type == "softmax" and args.threshold is None:
        args.threshold = 0.90

    device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
    logging.info("using device %s", device)

    logging.info("loading models...")
    processor = CLIPProcessor.from_pretrained("./openai-clip-vit-base-patch32/")
    model = CLIPModel.from_pretrained("./openai-clip-vit-base-patch32/").to(device)
    logging.info("done loading models")

    text_prompt_watcher = TextPromptWatcher(args.text_prompts, args.watch_text_interval)

    if args.watch_text_url is not None:
        Thread(target=text_prompt_watcher.watch_url, args=(args.watch_text_url,), daemon=True).start()

    with Plugin() as plugin, Camera(args.input) as camera:
        logging.info("processing stream...")

        for snapshot in camera.stream():
            # get latest cached text prompts
            text_prompts = text_prompt_watcher.get_text_prompts()

            logging.info("running inference...")
            inputs = processor(text=text_prompts, images=snapshot.data, return_tensors="pt", padding=True).to(device)

            with torch.no_grad():
                outputs = model(**inputs)

            logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
            probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities

            results = []

            for prob, logits, description in sorted(zip(probs.view(-1), logits_per_image.view(-1), text_prompts)):
                # TODO prefer similarity score to softmax prob for thresholding. softmax prob can give unintuitive
                # results - for example, when a single text is provided, that will always be published.
                matched = (
                    (args.threshold_type == "similarity" and logits > args.threshold) or
                    (args.threshold_type == "softmax" and prob > args.threshold)
                )

                if matched:
                    plugin.publish("image.clip.prediction", f"{description},{logits:0.3f},{prob:0.3f}", timestamp=snapshot.timestamp)

                marker = "*" if matched else " "
                results.append(f"{logits:0.3f} {prob:0.3f} {marker} {description}")

            logging.info("inference results are\n\n%s\n", "\n".join(results))


if __name__ == "__main__":
    main()
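
The TODO in the inner loop notes that softmax probabilities can be misleading for thresholding: the probabilities are normalized over the current prompt list, so with a single prompt the softmax is always 1.0 and a --threshold-type softmax run would publish every frame. Below is a minimal sketch, separate from main.py, that reproduces this with the same CLIPProcessor/CLIPModel calls; the local checkpoint path matches the one in main.py, while the image path is a placeholder.

import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Same local checkpoint path as main.py; "snapshot.jpg" is a placeholder image.
processor = CLIPProcessor.from_pretrained("./openai-clip-vit-base-patch32/")
model = CLIPModel.from_pretrained("./openai-clip-vit-base-patch32/")

image = Image.open("snapshot.jpg")
inputs = processor(text=["a photo of a car"], images=image, return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits_per_image  # raw image-text similarity, shape (1, 1)
probs = logits.softmax(dim=1)      # always tensor([[1.0]]) with a single prompt

# The raw similarity still discriminates (main.py's default similarity
# threshold is 28.0), while the single-prompt softmax probability does not.
print(f"similarity={logits.item():0.3f} softmax={probs.item():0.3f}")

For reference, an invocation assembled from the argparse definitions above might look like python3 main.py --threshold-type similarity --watch-text-url http://example.com/prompts.txt "a photo of a car" "a photo of smoke", where the URL and prompts are placeholders, not values from the repository.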