
Regarding when skip is computed #1

Open
wpeebles opened this issue Jul 3, 2018 · 2 comments


wpeebles commented Jul 3, 2018

In line 311 of train_svg_lp.py, what is the reasoning behind setting the condition as:

if opt.last_frame_skip or i < opt.n_past:	
    h, skip = h

instead of the following (the only change is that < becomes <=):

if opt.last_frame_skip or i <= opt.n_past:	
    h, skip = h

Since h = encoder(x[i-1]), I believe the strict < will cause the skip features to be from the t = n_past - 1 frame instead of the t = n_past frame (where t is indexed from 1). Is this intended?

@lpcinelli

I started out trying to explain the reason to you, but partway through I realized you were right. My explanation follows.

According to the paper, "For all datasets we add skip connections from the encoder at the last groundtruth frame to the decoder at t, enabling the model to easily generate static background features." So skip should come from the last known groundtruth frame.

If opt.n_past is the number of context frames (that is, the number of frames whose groundtruths we know), the model will be conditioned on frames 0 to opt.n_past - 1. Hence, when predicting new frames, skip should come from frame number opt.n_past - 1. Then that would be:

h, skip = encoder(x[i-1])  # encoder(x[opt.n_past - 1])

We can easily see that this happens when i = opt.n_past, so indeed, as pointed out by @wpeebles, the if clause should be i <= opt.n_past instead of i < opt.n_past.
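The index argument above can be checked with a small pure-Python sketch. It does not use the real model: `last_skip_source` is a hypothetical helper that only mirrors the loop bounds and the if condition, and records which frame index would feed `encoder` the last time `skip` is overwritten.

```python
def last_skip_source(n_past, n_future, inclusive):
    """Return the 0-based index of the frame whose encoder output is kept as skip.

    Hypothetical helper: it mimics only the loop indices of the training loop,
    assuming last_frame_skip is False.
    """
    source = None
    for i in range(1, n_past + n_future):
        refresh = (i <= n_past) if inclusive else (i < n_past)
        if refresh:
            source = i - 1  # mirrors h = encoder(x[i-1])
    return source


print(last_skip_source(3, 5, inclusive=False))  # 1, i.e. x[n_past - 2]
print(last_skip_source(3, 5, inclusive=True))   # 2, i.e. x[n_past - 1]
```

With three context frames, the strict comparison leaves skip pointing at frame 1, while the inclusive one leaves it at frame 2, the last context frame. In this sketch, n_past = 1 with the strict comparison returns None, because the condition never fires.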


DanielTakeshi commented Jan 15, 2021

Hi @wpeebles, I read through the code and I agree with you, as well as with @lpcinelli.

Of course, given how the code is set up, this is unlikely to have much of an impact. It just means that if opt.n_past=3, for example, and we're conditioning on x_0, x_1, x_2 as ground truth, then the last skip connections we preserve come from Encoder(x_1) instead of Encoder(x_2). But it definitely seems like the condition should be if opt.last_frame_skip or i <= opt.n_past: in this segment, because skip is not updated again once the condition becomes false.

svg/train_svg_lp.py

Lines 308 to 320 in 3f19f0b

for i in range(1, opt.n_past+opt.n_future):
    h = encoder(x[i-1])
    h_target = encoder(x[i])[0]
    if opt.last_frame_skip or i < opt.n_past:
        h, skip = h
    else:
        h = h[0]
    z_t, mu, logvar = posterior(h_target)
    _, mu_p, logvar_p = prior(h)
    h_pred = frame_predictor(torch.cat([h, z_t], 1))
    x_pred = decoder([h_pred, skip])
    mse += mse_criterion(x_pred, x[i])
    kld += kl_criterion(mu, logvar, mu_p, logvar_p)

Note that opt.last_frame_skip is False by default, and it probably should stay that way, since otherwise we'd be using skip connections from predicted frames.
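To make the effect of the flag concrete, here is a rough index-only trace. `skip_sources` is a hypothetical helper, not part of the repository. It uses the condition as currently written (strict <) and lists, for every step i, the index of the input that last supplied skip. Whether that input is a ground-truth or a generated frame depends on the surrounding loop, which this sketch does not model.

```python
def skip_sources(n_past, n_future, last_frame_skip):
    """For each step i, return the index of the input that last set skip.

    Hypothetical helper that tracks indices only; no encoder is involved.
    """
    sources = []
    current = None
    for i in range(1, n_past + n_future):
        if last_frame_skip or i < n_past:
            current = i - 1  # skip would come from encoder(x[i-1])
        sources.append(current)
    return sources


print(skip_sources(3, 3, last_frame_skip=False))  # [0, 1, 1, 1, 1]
print(skip_sources(3, 3, last_frame_skip=True))   # [0, 1, 2, 3, 4]
```

With the flag off, skip freezes after the context window. With it on, skip follows the previous input at every step.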

Based on this discussion, it seems like this should also use <= instead of <:

svg/train_svg_lp.py

Lines 257 to 269 in 3f19f0b

def plot_rec(x, epoch):
    frame_predictor.hidden = frame_predictor.init_hidden()
    posterior.hidden = posterior.init_hidden()
    gen_seq = []
    gen_seq.append(x[0])
    x_in = x[0]
    for i in range(1, opt.n_past+opt.n_future):
        h = encoder(x[i-1])
        h_target = encoder(x[i])
        if opt.last_frame_skip or i < opt.n_past:
            h, skip = h
        else:
            h, _ = h

as well as other places that rely on this general pattern.
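Since the same comparison is repeated in several loops, one possible way to keep them consistent is to put the condition in a single helper. The sketch below is only an illustration, not how the repository is organized: `should_refresh_skip`, `stub_encoder`, and `final_skip` are hypothetical names, and the stub encoder returns string labels instead of tensors.

```python
def should_refresh_skip(i, n_past, last_frame_skip=False):
    """Hypothetical helper: True while step i should still overwrite skip."""
    return last_frame_skip or i <= n_past


def stub_encoder(frame):
    """Stand-in for the real encoder; returns an (h, skip) pair of labels."""
    return f"h({frame})", f"skip({frame})"


def final_skip(frames, n_past, last_frame_skip=False):
    """Run the loop shape from the thread and return the skip left at the end."""
    skip = None
    for i in range(1, len(frames)):
        h, candidate = stub_encoder(frames[i - 1])
        if should_refresh_skip(i, n_past, last_frame_skip):
            skip = candidate
    return skip


frames = ["x_0", "x_1", "x_2", "x_3", "x_4"]
print(final_skip(frames, n_past=3))  # skip(x_2)
```

With n_past=3 the retained skip comes from x_2, the last context frame, which matches the behavior argued for in this thread.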
