bug in calc_gradient_penalty? #47

Open
lezhang-thu opened this issue May 17, 2019 · 1 comment
lezhang-thu commented May 17, 2019

gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA

Right before this line, shouldn't "gradients = gradients.view(BATCH_SIZE, -1)" be added?
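For context, a minimal sketch of the penalty with the proposed flattening. This is not the repository's exact function: the alpha broadcast shape assumes rank-4 image inputs, and the BATCH_SIZE/LAMBDA values mirror the repo's globals.

import torch
import torch.autograd as autograd

BATCH_SIZE = 64  # assumed, as in the repo's scripts
LAMBDA = 10      # gradient penalty coefficient from the WGAN-GP paper

def calc_gradient_penalty(netD, real_data, fake_data):
    # Sample random interpolation points between real and fake data.
    alpha = torch.rand(BATCH_SIZE, 1, 1, 1, device=real_data.device)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)

    disc_interpolates = netD(interpolates)

    # Gradients of the critic output w.r.t. the interpolated inputs.
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    # Proposed addition: flatten so the 2-norm is taken per sample.
    gradients = gradients.view(BATCH_SIZE, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty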


t-ae commented Dec 12, 2019

It's OK in gan_mnist.py, because real_data and fake_data have shape [BATCH_SIZE, OUTPUT_DIM] and the Discriminator reshapes the input first:

input = input.view(-1, 1, 28, 28)

But in gan_cifar10.py, the Discriminator requires the input tensor to have rank 4, which means gradients also has rank 4.

wgan-gp/gan_cifar10.py

Lines 89 to 93 in ae47a18

def forward(self, input):
    output = self.main(input)
    output = output.view(-1, 4*4*4*DIM)
    output = self.linear(output)
    return output

So the computation of gradient_penalty is not correct.
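A quick shape check illustrates the problem (sizes are illustrative for CIFAR-10):

import torch

gradients = torch.randn(64, 3, 32, 32)  # rank-4, as in gan_cifar10.py

per_pixel = gradients.norm(2, dim=1)                 # shape [64, 32, 32]: channel norm per pixel
per_sample = gradients.view(64, -1).norm(2, dim=1)   # shape [64]: the intended per-sample norm

So ((gradients.norm(2, dim=1) - 1) ** 2).mean() drives every pixel's channel-wise gradient norm toward 1, rather than constraining the full per-sample gradient norm as WGAN-GP prescribes.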

In the original(?) implementation, the Discriminator requires a rank-2 tensor:
https://github.com/igul222/improved_wgan_training/blob/master/gan_cifar.py#L71

In gan_language.py, the original Discriminator takes a rank-3 tensor, but the norm is computed along two axes:
https://github.com/igul222/improved_wgan_training/blob/master/gan_language.py#L107
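For comparison, a sketch of that two-axis reduction in PyTorch (shapes are illustrative; the TensorFlow original reduces over axes [1, 2]):

import torch

gradients = torch.randn(64, 32, 96)  # [BATCH_SIZE, SEQ_LEN, len(charmap)], rank 3

# Equivalent of tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2])):
slopes = torch.sqrt((gradients ** 2).sum(dim=(1, 2)))  # shape [64]
gradient_penalty = ((slopes - 1) ** 2).mean()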
