diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..c937bf0
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,6 @@
+__pycache__
+*.py[cod]
+
+*.pth
+*.pb
+*.pkl
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c937bf0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+__pycache__
+*.py[cod]
+
+*.pth
+*.pb
+*.pkl
diff --git a/AUTHORS.md b/AUTHORS.md
new file mode 100644
index 0000000..917cc68
--- /dev/null
+++ b/AUTHORS.md
@@ -0,0 +1,2 @@
+Daniel J. Hofmann
+Harsimrat Sandhawalia
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..a826f7a
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,17 @@
+FROM ubuntu:18.04
+
+WORKDIR /usr/src/app
+
+ENV LANG="C.UTF-8" LC_ALL="C.UTF-8" PATH="/opt/venv/bin:$PATH" PIP_NO_CACHE_DIR="false"
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    python3 python3-pip python3-venv libglib2.0-0 && \
+    rm -rf /var/lib/apt/lists/*
+
+COPY requirements.txt .
+
+RUN python3 -m venv /opt/venv && \
+    python3 -m pip install pip==19.2.3 pip-tools==4.0.0 && \
+    python3 -m piptools sync
+
+COPY . .
diff --git a/LICENSE b/LICENSE.md
similarity index 100%
rename from LICENSE
rename to LICENSE.md
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..d5a228f
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,35 @@
+dockerimage ?= das/vmz
+dockerfile ?= Dockerfile
+srcdir ?= $(shell pwd)
+datadir ?= $(shell pwd)
+
+install:
+	@docker build -t $(dockerimage) -f $(dockerfile) .
+
+i: install
+
+
+update:
+	@docker build -t $(dockerimage) -f $(dockerfile) . --pull --no-cache
+
+u: update
+
+
+run:
+	@docker run -it --rm -v $(srcdir):/usr/src/app/ \
+	                     -v $(datadir):/data \
+	                     --entrypoint=/bin/bash $(dockerimage)
+
+r: run
+
+
+webcam:
+	@docker run -it --rm -v $(srcdir):/usr/src/app/ \
+	                     -v $(datadir):/data \
+	                     --device=/dev/video0 \
+	                     --entrypoint=/bin/bash $(dockerimage)
+
+w: webcam
+
+
+.PHONY: install i run r update u webcam w
diff --git a/README.md b/README.md
index 19b3867..a09c5f4 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,37 @@
-# video-resnet
-ResNet 3D Conv Video models
+# IG65-M PyTorch
+
+Unofficial PyTorch (and ONNX) models and weights for IG65-M pre-trained 3D video architectures.
+
+The official research Caffe2 models and weights are available at: https://github.com/facebookresearch/vmz
+
+
+## Models
+
+| Model      | Weights           | Input Size | pth                                             | onnx                                            |
+| ---------- | ----------------- | ---------- | ----------------------------------------------- | ----------------------------------------------- |
+| r(2+1)d 34 | IG65-M            | 8x112x112  | *r2plus1d_34_clip8_ig65m_from_scratch.pth*      | *r2plus1d_34_clip8_ig65m_from_scratch.pb*       |
+| r(2+1)d 34 | IG65-M + Kinetics | 8x112x112  | *r2plus1d_34_clip8_ft_kinetics_from_ig65m.pth*  | *r2plus1d_34_clip8_ft_kinetics_from_ig65m.pb*   |
+| r(2+1)d 34 | IG65-M            | 32x112x112 | NA                                              | NA                                              |
+| r(2+1)d 34 | IG65-M + Kinetics | 32x112x112 | *r2plus1d_34_clip32_ft_kinetics_from_ig65m.pth* | *r2plus1d_34_clip32_ft_kinetics_from_ig65m.pb*  |
+
+
+## Usage
+
+See
+- `convert.py` for model conversion
+- `extract.py` for feature extraction
+
+We provide converted `.pth` PyTorch weights as artifacts in our GitHub releases.
+
+
+## References
+- [VMZ: Model Zoo for Video Modeling](https://github.com/facebookresearch/vmz)
+- [Kinetics](https://arxiv.org/abs/1705.06950)
+- [IG65-M](https://arxiv.org/abs/1905.00561)
+
+
+## License
+
+Copyright © 2019 MoabitCoin
+
+Distributed under the MIT License (MIT).
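As a quick smoke test of the converted weights — a minimal sketch, assuming the `r2plus1d_34_clip8_ft_kinetics_from_ig65m.pth` checkpoint from the table above has been downloaded next to the scripts, and reusing the `r2plus1d_34` builder that `convert.py` and `extract.py` below both define:

```python
import torch

from extract import r2plus1d_34  # same model builder as in convert.py

# 400 classes for the Kinetics fine-tuned checkpoints, 487 for plain IG65-M
model = r2plus1d_34(num_classes=400)
model.load_state_dict(torch.load("r2plus1d_34_clip8_ft_kinetics_from_ig65m.pth",
                                 map_location="cpu"))
model.eval()

clip = torch.rand(1, 3, 8, 112, 112)  # NxCxTxHxW, random stand-in for a real clip

with torch.no_grad():
    probs = torch.nn.functional.softmax(model(clip), dim=1)

print(probs.argmax(dim=1).item())
```

For meaningful predictions a real clip has to go through the same resize, normalize, and center-crop pipeline that `extract.py` applies.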
diff --git a/convert.py b/convert.py new file mode 100755 index 0000000..381e9af --- /dev/null +++ b/convert.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 + +import pickle +import argparse +from pathlib import Path + +import torch +import torch.nn as nn + +from torchvision.models.video.resnet import VideoResNet, BasicBlock, R2Plus1dStem, Conv2Plus1D + + +def r2plus1d_34(num_classes, pretrained=False, progress=False, **kwargs): + model = VideoResNet(block=BasicBlock, + conv_makers=[Conv2Plus1D] * 4, + layers=[3, 4, 6, 3], + stem=R2Plus1dStem) + + model.fc = nn.Linear(model.fc.in_features, out_features=num_classes) + + # Fix difference in PyTorch vs Caffe2 architecture + # https://github.com/facebookresearch/VMZ/issues/89 + model.layer2[0].conv2[0] = Conv2Plus1D(128, 128, 288) + model.layer3[0].conv2[0] = Conv2Plus1D(256, 256, 576) + model.layer4[0].conv2[0] = Conv2Plus1D(512, 512, 1152) + + # We need exact Caffe2 momentum for BatchNorm scaling + for m in model.modules(): + if isinstance(m, nn.BatchNorm3d): + m.eps = 1e-3 + m.momentum = 0.9 + + return model + + +def blobs_from_pkl(path): + with path.open(mode="rb") as f: + pkl = pickle.load(f, encoding="latin1") + return pkl["blobs"] + + +def copy_tensor(data, blobs, name): + tensor = torch.from_numpy(blobs[name]) + + del blobs[name] # enforce: use at most once + + assert data.size() == tensor.size() + assert data.dtype == tensor.dtype + + data.copy_(tensor) + + +def copy_conv(module, blobs, prefix): + assert isinstance(module, nn.Conv3d) + assert module.bias is None + copy_tensor(module.weight.data, blobs, prefix + "_w") + + +def copy_bn(module, blobs, prefix): + assert isinstance(module, nn.BatchNorm3d) + copy_tensor(module.weight.data, blobs, prefix + "_s") + copy_tensor(module.running_mean.data, blobs, prefix + "_rm") + copy_tensor(module.running_var.data, blobs, prefix + "_riv") + copy_tensor(module.bias.data, blobs, prefix + "_b") + + +def copy_fc(module, blobs): + assert isinstance(module, nn.Linear) + n = module.out_features + copy_tensor(module.bias.data, blobs, "last_out_L" + str(n) + "_b") + copy_tensor(module.weight.data, blobs, "last_out_L" + str(n) + "_w") + + +# https://github.com/pytorch/vision/blob/v0.4.0/torchvision/models/video/resnet.py#L174-L188 +# https://github.com/facebookresearch/VMZ/blob/6c925c47b7d6545b64094a083f111258b37cbeca/lib/models/r3d_model.py#L233-L275 +def copy_stem(module, blobs): + assert isinstance(module, R2Plus1dStem) + assert len(module) == 6 + copy_conv(module[0], blobs, "conv1_middle") + copy_bn(module[1], blobs, "conv1_middle_spatbn_relu") + assert isinstance(module[2], nn.ReLU) + copy_conv(module[3], blobs, "conv1") + copy_bn(module[4], blobs, "conv1_spatbn_relu") + assert isinstance(module[5], nn.ReLU) + + +# https://github.com/pytorch/vision/blob/v0.4.0/torchvision/models/video/resnet.py#L82-L114 +def copy_conv2plus1d(module, blobs, i, j): + assert isinstance(module, Conv2Plus1D) + assert len(module) == 4 + copy_conv(module[0], blobs, "comp_" + str(i) + "_conv_" + str(j) + "_middle") + copy_bn(module[1], blobs, "comp_" + str(i) + "_spatbn_" + str(j) + "_middle") + assert isinstance(module[2], nn.ReLU) + copy_conv(module[3], blobs, "comp_" + str(i) + "_conv_" + str(j)) + + +# https://github.com/pytorch/vision/blob/v0.4.0/torchvision/models/video/resnet.py#L82-L114 +def copy_basicblock(module, blobs, i): + assert isinstance(module, BasicBlock) + + assert len(module.conv1) == 3 + assert isinstance(module.conv1[0], Conv2Plus1D) + copy_conv2plus1d(module.conv1[0], blobs, i, 1) + assert 
isinstance(module.conv1[1], nn.BatchNorm3d) + copy_bn(module.conv1[1], blobs, "comp_" + str(i) + "_spatbn_" + str(1)) + assert isinstance(module.conv1[2], nn.ReLU) + + assert len(module.conv2) == 2 + assert isinstance(module.conv2[0], Conv2Plus1D) + copy_conv2plus1d(module.conv2[0], blobs, i, 2) + assert isinstance(module.conv2[1], nn.BatchNorm3d) + copy_bn(module.conv2[1], blobs, "comp_" + str(i) + "_spatbn_" + str(2)) + + if module.downsample is not None: + assert i in [3, 7, 13] + assert len(module.downsample) == 2 + assert isinstance(module.downsample[0], nn.Conv3d) + assert isinstance(module.downsample[1], nn.BatchNorm3d) + copy_conv(module.downsample[0], blobs, "shortcut_projection_" + str(i)) + copy_bn(module.downsample[1], blobs, "shortcut_projection_" + str(i) + "_spatbn") + + +def copy_layer(module, blobs, i): + assert {0: 3, 3: 4, 7: 6, 13: 3}[i] == len(module) + + for basicblock in module: + copy_basicblock(basicblock, blobs, i) + i += 1 + + +def init_canary(model): + nan = float("nan") + + for m in model.modules(): + if isinstance(m, nn.Conv3d): + assert m.bias is None + nn.init.constant_(m.weight, nan) + elif isinstance(m, nn.BatchNorm3d): + nn.init.constant_(m.weight, nan) + nn.init.constant_(m.running_mean, nan) + nn.init.constant_(m.running_var, nan) + nn.init.constant_(m.bias, nan) + elif isinstance(m, nn.Linear): + nn.init.constant_(m.weight, nan) + nn.init.constant_(m.bias, nan) + + +def check_canary(model): + for m in model.modules(): + if isinstance(m, nn.Conv3d): + assert m.bias is None + assert not torch.isnan(m.weight).any() + elif isinstance(m, nn.BatchNorm3d): + assert not torch.isnan(m.weight).any() + assert not torch.isnan(m.running_mean).any() + assert not torch.isnan(m.running_var).any() + assert not torch.isnan(m.bias).any() + elif isinstance(m, nn.Linear): + assert not torch.isnan(m.weight).any() + assert not torch.isnan(m.bias).any() + + +def main(args): + blobs = blobs_from_pkl(args.pkl) + + model = r2plus1d_34(num_classes=args.classes) + + init_canary(model) + + copy_stem(model.stem, blobs) + + layers = [model.layer1, model.layer2, model.layer3, model.layer4] + blocks = [0, 3, 7, 13] + + for layer, i in zip(layers, blocks): + copy_layer(layer, blobs, i) + + copy_fc(model.fc, blobs) + + assert not blobs + check_canary(model) + + # Export to pytorch .pth and self-contained onnx .pb files + + batch = torch.rand(1, 3, args.frames, 112, 112) # NxCxTxHxW + torch.save(model.state_dict(), args.out.with_suffix(".pth")) + torch.onnx.export(model, batch, args.out.with_suffix(".pb")) + + # Check pth roundtrip into fresh model + + model = r2plus1d_34(num_classes=args.classes) + model.load_state_dict(torch.load(args.out.with_suffix(".pth"))) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + arg = parser.add_argument + + arg("pkl", type=Path, help=".pkl file to read the R(2+1)D 34 layer weights from") + arg("out", type=Path, help="prefix to save converted R(2+1)D 34 layer weights to") + arg("--frames", type=int, choices=(8, 32), required=True, help="clip frames for video model") + arg("--classes", type=int, choices=(400, 487), required=True, help="classes in last layer") + + main(parser.parse_args()) diff --git a/extract.py b/extract.py new file mode 100755 index 0000000..a365ec3 --- /dev/null +++ b/extract.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 + +import sys +import math +import json +import argparse +from pathlib import Path + +import cv2 +import numpy as np + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, 
IterableDataset, get_worker_info + +from torchvision.transforms import Compose +from torchvision.models.video.resnet import VideoResNet, BasicBlock, R2Plus1dStem, Conv2Plus1D + +from einops.layers.torch import Rearrange + + +def r2plus1d_34(num_classes): + model = VideoResNet(block=BasicBlock, + conv_makers=[Conv2Plus1D] * 4, + layers=[3, 4, 6, 3], + stem=R2Plus1dStem) + + model.fc = nn.Linear(model.fc.in_features, out_features=num_classes) + + # Fix difference in PyTorch vs Caffe2 architecture + # https://github.com/facebookresearch/VMZ/issues/89 + model.layer2[0].conv2[0] = Conv2Plus1D(128, 128, 288) + model.layer3[0].conv2[0] = Conv2Plus1D(256, 256, 576) + model.layer4[0].conv2[0] = Conv2Plus1D(512, 512, 1152) + + # We need exact Caffe2 momentum for BatchNorm scaling + for m in model.modules(): + if isinstance(m, nn.BatchNorm3d): + m.eps = 1e-3 + m.momentum = 0.9 + + return model + + +class FrameRange: + def __init__(self, video, first, last): + assert first <= last + + for i in range(first): + ret, _ = video.read() + + if not ret: + raise RuntimeError("seeking to frame at index {} failed".format(i)) + + self.video = video + self.it = first + self.last = last + + def __next__(self): + if self.it >= self.last or not self.video.isOpened(): + raise StopIteration + + ok, frame = self.video.read() + + if not ok: + raise RuntimeError("decoding frame at index {} failed".format(self.it)) + + self.it += 1 + + return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + +class BatchedRange: + def __init__(self, rng, n): + self.rng = rng + self.n = n + + def __next__(self): + ret = [] + + for i in range(self.n): + ret.append(next(self.rng)) + + return ret + + +class TransformedRange: + def __init__(self, rng, fn): + self.rng = rng + self.fn = fn + + def __next__(self): + return self.fn(next(self.rng)) + + +class VideoDataset(IterableDataset): + def __init__(self, path, clip, transform=None): + super().__init__() + + self.path = path + self.clip = clip + self.transform = transform + + video = cv2.VideoCapture(str(path)) + frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + video.release() + + self.first = 0 + self.last = frames + + def __iter__(self): + info = get_worker_info() + + video = cv2.VideoCapture(str(self.path)) + + if info is None: + rng = FrameRange(video, self.first, self.last) + else: + per = int(math.ceil((self.last - self.first) / float(info.num_workers))) + wid = info.id + + first = self.first + wid * per + last = min(first + per, self.last) + + rng = FrameRange(video, first, last) + + if self.transform is not None: + fn = self.transform + else: + fn = lambda v: v + + return TransformedRange(BatchedRange(rng, self.clip), fn) + + +class WebcamDataset(IterableDataset): + def __init__(self, clip, transform=None): + super().__init__() + + self.clip = clip + self.transform = transform + self.video = cv2.VideoCapture(0) + + def __iter__(self): + info = get_worker_info() + + if info is not None: + raise RuntimeError("multiple workers not supported in WebcamDataset") + + # treat webcam as fixed frame range for now: 10 minutes + rng = FrameRange(self.video, 0, 30 * 60 * 10) + + if self.transform is not None: + fn = self.transform + else: + fn = lambda v: v + + return TransformedRange(BatchedRange(rng, self.clip), fn) + + +class ToTensor: + def __call__(self, x): + return torch.from_numpy(np.array(x)).float() / 255. 
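+
+# A note on the transform pipeline assembled in main() below: BatchedRange
+# hands ToTensor a list of HxWxC uint8 RGB frames, np.array stacks them into
+# a TxHxWxC float tensor scaled to [0, 1], and the einops Rearrange step then
+# permutes it to CxTxHxW before Resize/Normalize/CenterCrop run on the clip.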
+ + +class Resize: + def __init__(self, size, mode="bilinear"): + self.size = size + self.mode = mode + + def __call__(self, video): + return torch.nn.functional.interpolate(video, size=self.size, + mode=self.mode, align_corners=False) + + +class CenterCrop: + def __init__(self, size): + self.size = size + + def __call__(self, video): + h, w = video.shape[-2:] + th, tw = self.size + + i = int(round((h - th) / 2.)) + j = int(round((w - tw) / 2.)) + + return video[..., i:(i + th), j:(j + tw)] + + +class Normalize: + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, video): + shape = (-1,) + (1,) * (video.dim() - 1) + + mean = torch.as_tensor(self.mean).reshape(shape) + std = torch.as_tensor(self.std).reshape(shape) + + return (video - mean) / std + + +def main(args): + if args.labels: + with args.labels.open() as f: + labels = json.load(f) + else: + labels = list(range(args.classes)) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + + model = r2plus1d_34(num_classes=args.classes) + model = model.to(device) + + weights = torch.load(args.model, map_location=device) + model.load_state_dict(weights) + + model = nn.DataParallel(model) + model.eval() + + transform = Compose([ + ToTensor(), + Rearrange("t h w c -> c t h w"), + Resize((128, 171)), + Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]), + CenterCrop((112, 112)), + ]) + + #dataset = WebcamDataset(args.frames, transform=transform) + + dataset = VideoDataset(args.video, args.frames, transform=transform) + loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) + + for inputs in loader: + # NxCxTxHxW + assert inputs.size() == (args.batch_size, 3, args.frames, 112, 112) + + inputs = inputs.to(device) + + outputs = model(inputs) + + _, preds = torch.max(outputs, dim=1) + preds = preds.data.cpu().numpy() + + scores = nn.functional.softmax(outputs, dim=1) + scores = scores.data.cpu().numpy() + + for pred, score in zip(preds, scores): + index = pred.item() + label = labels[index] + score = round(score.max().item(), 3) + + print("label='{}' score={}".format(label, score), file=sys.stderr) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + arg = parser.add_argument + + arg("model", type=Path, help=".pth file to load model weights from") + arg("video", type=Path, help="video file to run feature extraction on") + arg("--frames", type=int, choices=(8, 32), required=True, help="clip frames for video model") + arg("--classes", type=int, choices=(400, 487), required=True, help="classes in last layer") + arg("--batch-size", type=int, default=1, help="number of sequences per batch for inference") + arg("--num-workers", type=int, default=0, help="number of workers for data loading") + arg("--labels", type=Path, help="JSON file with label map array") + + main(parser.parse_args()) diff --git a/requirements.in b/requirements.in new file mode 100644 index 0000000..9a562cc --- /dev/null +++ b/requirements.in @@ -0,0 +1,5 @@ +numpy +torch +torchvision +opencv-contrib-python-headless +einops diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..173f33c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,106 @@ +# +# This file is autogenerated by pip-compile +# To update, run: +# +# pip-compile --generate-hashes +# +einops==0.1.0 \ + --hash=sha256:4ab512fe059c0841e1a315449ca9d7f35eaa05c8c095a14f2c1b92b2b77684d2 \ + 
--hash=sha256:4fd64864fcb8159074da3213b9327c242536784416cbf423745ef8579850d30b +numpy==1.17.2 \ + --hash=sha256:05dbfe72684cc14b92568de1bc1f41e5f62b00f714afc9adee42f6311738091f \ + --hash=sha256:0d82cb7271a577529d07bbb05cb58675f2deb09772175fab96dc8de025d8ac05 \ + --hash=sha256:10132aa1fef99adc85a905d82e8497a580f83739837d7cbd234649f2e9b9dc58 \ + --hash=sha256:12322df2e21f033a60c80319c25011194cd2a21294cc66fee0908aeae2c27832 \ + --hash=sha256:16f19b3aa775dddc9814e02a46b8e6ae6a54ed8cf143962b4e53f0471dbd7b16 \ + --hash=sha256:3d0b0989dd2d066db006158de7220802899a1e5c8cf622abe2d0bd158fd01c2c \ + --hash=sha256:438a3f0e7b681642898fd7993d38e2bf140a2d1eafaf3e89bb626db7f50db355 \ + --hash=sha256:5fd214f482ab53f2cea57414c5fb3e58895b17df6e6f5bca5be6a0bb6aea23bb \ + --hash=sha256:73615d3edc84dd7c4aeb212fa3748fb83217e00d201875a47327f55363cef2df \ + --hash=sha256:7bd355ad7496f4ce1d235e9814ec81ee3d28308d591c067ce92e49f745ba2c2f \ + --hash=sha256:7d077f2976b8f3de08a0dcf5d72083f4af5411e8fddacd662aae27baa2601196 \ + --hash=sha256:a4092682778dc48093e8bda8d26ee8360153e2047826f95a3f5eae09f0ae3abf \ + --hash=sha256:b458de8624c9f6034af492372eb2fee41a8e605f03f4732f43fc099e227858b2 \ + --hash=sha256:e70fc8ff03a961f13363c2c95ef8285e0cf6a720f8271836f852cc0fa64e97c8 \ + --hash=sha256:ee8e9d7cad5fe6dde50ede0d2e978d81eafeaa6233fb0b8719f60214cf226578 \ + --hash=sha256:f4a4f6aba148858a5a5d546a99280f71f5ee6ec8182a7d195af1a914195b21a2 +opencv-contrib-python-headless==4.1.1.26 \ + --hash=sha256:083c1d0dce23b86c627ad8c7eddc93b19431431ea7413be78673950e8a67966d \ + --hash=sha256:08db29152b2a124445e233ec90786a93150e565cdc83f37208e6ccdee87493a4 \ + --hash=sha256:1545a6d521d2de1294949a9b25ac2117dae617d87cddb8415b6518c5d3f21240 \ + --hash=sha256:33e75d168439c77fcea6fd983d132f4ea6fe6f873e4dde4586278965c36a8680 \ + --hash=sha256:3fd6edd50bff5e50c95799e717c096d2262b3967013a24badfdd809660a0fa19 \ + --hash=sha256:4a1f9c199e0c98b19bae8e03a490f0b613cfb119cd1611364cb3b6bc03c5c05d \ + --hash=sha256:4eb771b366307d8259be8cbd2ad744d477ed7bd3667f767951e8fe392c526e92 \ + --hash=sha256:5072d82175c41f9cc5df504ad78c9807e12e0a358a63a1b791c8cc9c0501b173 \ + --hash=sha256:5576f2884bfe33c73280c4f3d76cee4c71337f5504aacf6631ae5694a9718ee5 \ + --hash=sha256:58627c73e703a306391c102d79c94827ce5e5f401aef4f98ca864b4b8cb57841 \ + --hash=sha256:6e21b9dd145a04b73c2e1b5ab39b65b58fb42710bb18e07223bc70685748d5e6 \ + --hash=sha256:763b62ab72761c0ce3b78a0f985f3bfdcd067d573719da297d5826824acd29f5 \ + --hash=sha256:82981868a3ce8fb6b1f8c332deb86779567fcacc98026d18d3fac11503305760 \ + --hash=sha256:972787b61efcaf0c2d833e83601bc06511d8afe17c444e643f6e4b237d564157 \ + --hash=sha256:a135288b970165ffe9c4ad571c11d9f140b0e9ef53d6cb49d275ff309715df55 \ + --hash=sha256:ab8f0d900f1a0a88a7135050876da15fd8d0e023224e8100839cb3946980afbc \ + --hash=sha256:b03886e5eb1b84126b8ba38fed118b272941044c1b0d15f04550130509c1d6cd \ + --hash=sha256:b3a4cb11fe8f389278143d732b1bc1237e2ed4b373896a39018ac3d2ebb31069 \ + --hash=sha256:b65e99712fbf927237a7a9207ce4166c93c6bec21a7203bd08b08ef9b937501e \ + --hash=sha256:cbd9e52c1de91a40e294b3b96fb4c3758c133e4cc92b985beaeea9bde3bea3c5 \ + --hash=sha256:d713c76569f44768fd4843bc9dcb227cba30407a2365b042b88e32d23a1ecc55 \ + --hash=sha256:db33f3cdd5d59b8a6ab8dea544ef6ecbef2a448e579ac1ff3523074eedc86f05 \ + --hash=sha256:defdcc4ca6b86f9032e61258e6846ce786c64f586b10a70feda7a1218a1e3378 \ + --hash=sha256:e3f8a7c03ab35c98d402f2758ab88fcb9f9003404eec8085096136976536adc7 \ + --hash=sha256:eca35aca76e7e1debd051083399bbf8319dfdb47ca13df56b0d8acb5c2215a22 \ + 
--hash=sha256:f9d57c94410e91af940f331aa9351065ba9d470d05646b8fd289da1170051bd2 \ + --hash=sha256:ffec278ef8c6a0341b656dd967c2109c861e39106b0067583756575c54c4caf2 +pillow==6.1.0 \ + --hash=sha256:0804f77cb1e9b6dbd37601cee11283bba39a8d44b9ddb053400c58e0c0d7d9de \ + --hash=sha256:0ab7c5b5d04691bcbd570658667dd1e21ca311c62dcfd315ad2255b1cd37f64f \ + --hash=sha256:0b3e6cf3ea1f8cecd625f1420b931c83ce74f00c29a0ff1ce4385f99900ac7c4 \ + --hash=sha256:365c06a45712cd723ec16fa4ceb32ce46ad201eb7bbf6d3c16b063c72b61a3ed \ + --hash=sha256:38301fbc0af865baa4752ddae1bb3cbb24b3d8f221bf2850aad96b243306fa03 \ + --hash=sha256:3aef1af1a91798536bbab35d70d35750bd2884f0832c88aeb2499aa2d1ed4992 \ + --hash=sha256:3fe0ab49537d9330c9bba7f16a5f8b02da615b5c809cdf7124f356a0f182eccd \ + --hash=sha256:45a619d5c1915957449264c81c008934452e3fd3604e36809212300b2a4dab68 \ + --hash=sha256:49f90f147883a0c3778fd29d3eb169d56416f25758d0f66775db9184debc8010 \ + --hash=sha256:571b5a758baf1cb6a04233fb23d6cf1ca60b31f9f641b1700bfaab1194020555 \ + --hash=sha256:5ac381e8b1259925287ccc5a87d9cf6322a2dc88ae28a97fe3e196385288413f \ + --hash=sha256:6153db744a743c0c8c91b8e3b9d40e0b13a5d31dbf8a12748c6d9bfd3ddc01ad \ + --hash=sha256:6fd63afd14a16f5d6b408f623cc2142917a1f92855f0df997e09a49f0341be8a \ + --hash=sha256:70acbcaba2a638923c2d337e0edea210505708d7859b87c2bd81e8f9902ae826 \ + --hash=sha256:70b1594d56ed32d56ed21a7fbb2a5c6fd7446cdb7b21e749c9791eac3a64d9e4 \ + --hash=sha256:76638865c83b1bb33bcac2a61ce4d13c17dba2204969dedb9ab60ef62bede686 \ + --hash=sha256:7b2ec162c87fc496aa568258ac88631a2ce0acfe681a9af40842fc55deaedc99 \ + --hash=sha256:7cee2cef07c8d76894ebefc54e4bb707dfc7f258ad155bd61d87f6cd487a70ff \ + --hash=sha256:7d16d4498f8b374fc625c4037742fbdd7f9ac383fd50b06f4df00c81ef60e829 \ + --hash=sha256:b50bc1780681b127e28f0075dfb81d6135c3a293e0c1d0211133c75e2179b6c0 \ + --hash=sha256:bd0582f831ad5bcad6ca001deba4568573a4675437db17c4031939156ff339fa \ + --hash=sha256:cfd40d8a4b59f7567620410f966bb1f32dc555b2b19f82a91b147fac296f645c \ + --hash=sha256:e3ae410089de680e8f84c68b755b42bc42c0ceb8c03dbea88a5099747091d38e \ + --hash=sha256:e9046e559c299b395b39ac7dbf16005308821c2f24a63cae2ab173bd6aa11616 \ + --hash=sha256:ef6be704ae2bc8ad0ebc5cb850ee9139493b0fc4e81abcc240fb392a63ebc808 \ + --hash=sha256:f8dc19d92896558f9c4317ee365729ead9d7bbcf2052a9a19a3ef17abbb8ac5b \ + # via torchvision +six==1.12.0 \ + --hash=sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c \ + --hash=sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73 \ + # via torchvision +torch==1.2.0 \ + --hash=sha256:0698d0a48014b9b8f36d93e69901eca2e7ec712cd2033908f7a77e7d86a4f0d7 \ + --hash=sha256:2ac8e58b069232f079bd289aa160366a9367ae1a4616a2c1007dceed19ff9bfa \ + --hash=sha256:43a0e28c448ddeea65fb9e956bc743389592afac824095bdbc08e8a87364c639 \ + --hash=sha256:661ad06b4616663149bd504e8c0271196d0386712e21a92619d95ba88138794a \ + --hash=sha256:880a0c22692eaebbce808a5bf2255ab7d345ab43c40795be0a421c6250ba0fb4 \ + --hash=sha256:a13bf6f78a49d844b85c142b8cd62d2e1833a11ed21ea0bc6b1ac73d24c76415 \ + --hash=sha256:a8c21f82fd03b67927078ea917040478c3263753fe1906fc19d0f5f0c7d9aa10 \ + --hash=sha256:b87fd224a7de3bc01ce87eb947698797b4514e27115b0aa60a56991515dd9dd6 \ + --hash=sha256:f63d489c54b4f170ce8335727bbb196ceb9acd0e7805477bbef8fabc914bc0f9 +torchvision==0.4.0 \ + --hash=sha256:3a8e9403252fefdf6e8f9993ae111d28eb4ad1e73f696f03de485d7f77d88067 \ + --hash=sha256:6fff5a31d50de3a59dcceda2a48de9df33a5f43357dc3e0da0ffbb97699aec52 \ + 
--hash=sha256:740b3718470aa4ec0b389df876eb25117df1952dd2e8105b7828a02aa5bce73b \ + --hash=sha256:8114c33b736ee430496eef4fe03b25be8b939b2abd2a968558737bb9aed1928b \ + --hash=sha256:904ef213594672f2ed7fafa3ab010cbf2a4704a951a7bf221cf36b3d2e3acd62 \ + --hash=sha256:afff8e987564192bc7f139d8b089541d4471ad6fc99e977e8bc8dbb4e0873041 \ + --hash=sha256:d7939f2ca401de3067a30b6f4dcef63d13d24a4cd1ddc2d3a9af3413ce658d03 \ + --hash=sha256:d8c2402704ce8ef8e87e4922160388c7ca010ef27700082014d6bd694cf1cc51 \ + --hash=sha256:e00de7571d83f968f5aea7a59e84e3262669acef0a077ce4bd705eca2df68167