Home
The models are inspired by [1], [2]. They are just experiments I have tried and do not fully implement the architectures described in the referenced papers.
The input consists of 720p images from the YouTube-8M dataset (credit goes to gsssrao for the downloader and frame generator scripts).
Two datasets were used for training:
- 121,827 frames (download: weights)
- 2,286 frames (download: dataset / weights): the quickest way to start experimenting!
The 720p (1280x720) frames are padded to 1280x768 (24 rows of padding at the top and 24 at the bottom) so that each can be split into 60 patches of 128x128 (6 rows x 10 columns).
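The padding and patch grid arithmetic can be checked with a short sketch. It assumes, as the "24,24 height pad" suggests, that the 48 missing rows are split evenly between top and bottom; `pad_amounts` and `patch_boxes` are illustrative helpers, not functions from the repository.

```python
from typing import List, Tuple

PATCH = 128  # patch side length in pixels


def pad_amounts(height: int, patch: int = PATCH) -> Tuple[int, int]:
    """Return (top, bottom) padding that makes `height` a multiple of `patch`."""
    total = (-height) % patch  # rows needed to reach the next multiple of `patch`
    top = total // 2           # assumption: split evenly between top and bottom
    return top, total - top


def patch_boxes(width: int, height: int, patch: int = PATCH) -> List[Tuple[int, int, int, int]]:
    """List (row, col, y0, x0) for every patch of an already padded image."""
    assert width % patch == 0 and height % patch == 0
    return [
        (i, j, i * patch, j * patch)
        for i in range(height // patch)
        for j in range(width // patch)
    ]


top, bottom = pad_amounts(720)
boxes = patch_boxes(1280, 720 + top + bottom)
print(top, bottom, len(boxes))  # 24 24 60
```

The width needs no padding because 1280 is already a multiple of 128.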
The model only sees a single patch per forward pass (i.e. there are 60 forward passes and 60 optimization steps per image).
The loss is computed per patch as `MSELoss(orig_patch_ij, out_patch_ij)`, and the per-image loss is the average of the 60 patch losses.
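The per-patch training scheme can be sketched in plain Python. `model` and `optimize` are hypothetical stand-ins for the network's forward pass and an optimizer update (in PyTorch that would be a backward pass followed by an optimizer step); patches are flat lists of pixel values, and `mse` mirrors what `MSELoss` computes with its default mean reduction.

```python
from typing import Callable, List, Sequence

Patch = Sequence[float]


def mse(orig: Patch, out: Patch) -> float:
    """Mean squared error between two equally sized patches (like MSELoss)."""
    assert len(orig) == len(out)
    return sum((a - b) ** 2 for a, b in zip(orig, out)) / len(orig)


def train_on_image(
    patches: Sequence[Patch],
    model: Callable[[Patch], Patch],    # hypothetical forward pass
    optimize: Callable[[float], None],  # hypothetical optimizer update
) -> float:
    """One forward pass and one optimization step per patch; return the mean loss."""
    losses: List[float] = []
    for patch in patches:          # the model only ever sees one patch at a time
        out = model(patch)
        loss = mse(patch, out)
        optimize(loss)             # one optimization step per patch
        losses.append(loss)
    return sum(losses) / len(losses)  # average loss per image
```

For a full image, `patches` would hold the 60 patches, so `optimize` is called 60 times.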
| Model | Patch latent size | Compressed size per image |
|---|---|---|
| cae_16x8x8_zero_pad_bin | 16x8x8 | 7.5 KB |
| cae_16x8x8_refl_pad_bin | 16x8x8 | 7.5 KB |
| cae_16x16x16_zero_pad_bin | 16x16x16 | 30 KB |
| cae_32x32x32_zero_pad_bin | 32x32x32 | 240 KB |
All models implement stochastic binarization [2]; that is, the encoded representation is binary.
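A common formulation of stochastic binarization maps each value x in [-1, 1] (the range of the encoder's final tanh) to +1 with probability (1 + x) / 2 and to -1 otherwise, so the expected output equals x. The sketch below assumes that formulation; the repository's exact implementation may differ, and `stochastic_binarize` is an illustrative name.

```python
import random
from typing import List, Sequence


def stochastic_binarize(values: Sequence[float], rng: random.Random) -> List[int]:
    """Map each x in [-1, 1] to +1 with probability (1 + x) / 2, else -1."""
    bits = []
    for x in values:
        assert -1.0 <= x <= 1.0, "expects tanh-range inputs"
        p_one = (1.0 + x) / 2.0
        bits.append(1 if rng.random() < p_one else -1)
    return bits


rng = random.Random(0)
print(stochastic_binarize([-1.0, 1.0, 0.0], rng))  # the first two are always -1 and 1
```

Values at the extremes are deterministic, while a value of 0 becomes +1 or -1 with equal probability.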
The number of bits per patch equals the number of elements in the patch latent (e.g. 16x8x8 = 1,024 bits), so the compressed size of an image is `60 * bits_per_patch / 8 / 1024` KB.
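The compressed sizes in the table follow directly from that formula. This sketch recomputes them under the same assumptions (one bit per latent element, 60 patches per padded image); `compressed_kb` is an illustrative helper.

```python
from math import prod
from typing import Tuple

N_PATCHES = 60  # 1280x768 image split into 128x128 patches

LATENTS = {
    "cae_16x8x8_zero_pad_bin": (16, 8, 8),
    "cae_16x8x8_refl_pad_bin": (16, 8, 8),
    "cae_16x16x16_zero_pad_bin": (16, 16, 16),
    "cae_32x32x32_zero_pad_bin": (32, 32, 32),
}


def compressed_kb(latent_shape: Tuple[int, ...], n_patches: int = N_PATCHES) -> float:
    """Size in KB of one image's binary code: bits -> bytes -> kilobytes."""
    bits_per_patch = prod(latent_shape)  # one bit per latent element
    return n_patches * bits_per_patch / 8 / 1024


for name, shape in LATENTS.items():
    print(f"{name}: {compressed_kb(shape)} KB")
```

This prints 7.5, 7.5, 30.0 and 240.0 KB, matching the table.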
The benefits of stochastic binarization, as mentioned in [2], are:
(1) bit vectors are trivially serializable/deserializable for image transmission over the wire;
(2) the compression rate of the network can be controlled simply by constraining the bit allowance;
(3) a binary bottleneck helps force the network to learn efficient representations, unlike standard floating-point layers, which may have many redundant bit patterns that do not affect the output.
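Benefit (1) can be illustrated by packing a +1/-1 code into bytes, 8 latent elements per byte, and unpacking it without loss. The +1 to 1 and -1 to 0 mapping, the most-significant-bit-first order, and the helper names are illustrative choices, not the repository's wire format.

```python
from typing import List


def pack_bits(code: List[int]) -> bytes:
    """Pack a +1/-1 code into bytes (MSB first, +1 -> 1, -1 -> 0)."""
    assert len(code) % 8 == 0
    out = bytearray()
    for i in range(0, len(code), 8):
        byte = 0
        for b in code[i:i + 8]:
            byte = (byte << 1) | (1 if b > 0 else 0)
        out.append(byte)
    return bytes(out)


def unpack_bits(data: bytes) -> List[int]:
    """Inverse of pack_bits: recover the +1/-1 code from the bytes."""
    return [1 if (byte >> (7 - k)) & 1 else -1 for byte in data for k in range(8)]


code = [1, -1] * 512        # a 16x8x8 patch latent has 1,024 elements
payload = pack_bits(code)
print(len(payload))         # 128 bytes for one patch
```

At 128 bytes per patch, the 60 patches of an image take 7,680 bytes, i.e. the 7.5 KB listed for the 16x8x8 models.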
Since the best-performing model is `cae_32x32x32_zero_pad_bin`, we only describe its (high-level) architecture.
Output shapes and block descriptions can be found in the code.
```
encoder:
x => conv1 --> conv2 --> enc_block1 --> (+) --> enc_block2 --> (+) --> enc_block3 --> (+) --> conv3 (tanh) => enc
                 |-----------------------^   |------------------^   |------------------^

decoder:
enc => up_conv1 --> dec_block1 --> (+) --> dec_block2 --> (+) --> dec_block3 --> (+) --> up_conv2 --> up_conv3 (tanh) => x
           |------------------------^   |------------------^   |------------------^
```
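The wiring in the diagrams can be sketched as function composition: each `enc_block`/`dec_block` sits on a residual path, so its input is added back to its output at `(+)`. Scalar functions stand in for the tensor-valued layers here; the placeholder layers are hypothetical, and the real convolutions, strides and channel counts are defined in the repository code.

```python
import math
from typing import Callable, List, Sequence

Layer = Callable[[float], float]


def residual(block: Layer) -> Layer:
    """Skip connection: the block's input is added to its output at (+)."""
    return lambda x: x + block(x)


def chain(layers: Sequence[Layer]) -> Layer:
    """Apply layers left to right, like the --> arrows in the diagram."""
    def run(x: float) -> float:
        for layer in layers:
            x = layer(x)
        return x
    return run


def build_encoder(conv1: Layer, conv2: Layer, blocks: List[Layer], conv3: Layer) -> Layer:
    # x => conv1 --> conv2 --> (enc_block --> (+)) x3 --> conv3 (tanh) => enc
    return chain([conv1, conv2, *[residual(b) for b in blocks], conv3])


def build_decoder(up_conv1: Layer, blocks: List[Layer], up_conv2: Layer, up_conv3: Layer) -> Layer:
    # enc => up_conv1 --> (dec_block --> (+)) x3 --> up_conv2 --> up_conv3 (tanh) => x
    return chain([up_conv1, *[residual(b) for b in blocks], up_conv2, up_conv3])


# Toy usage with hypothetical scalar layers, to trace the dataflow only.
encoder = build_encoder(lambda x: 2 * x, lambda x: x - 1, [lambda x: 0.5 * x] * 3, math.tanh)
print(encoder(1.0))  # 1 -> 2 -> 1 -> 1.5 -> 2.25 -> 3.375 -> tanh(3.375)
```

Building the residual paths with a wrapper keeps each block's own definition free of the skip logic, which mirrors how the diagram draws the skips separately from the blocks.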