diff --git a/utils/weight_transfer.py b/utils/weight_transfer.py index 86d3891..6d36697 100644 --- a/utils/weight_transfer.py +++ b/utils/weight_transfer.py @@ -43,7 +43,7 @@ def transfer_ConvTranspose2d(m1, m2, input_index=None, output_index=None): assert isinstance(m1, nn.ConvTranspose2d) and isinstance(m2, nn.ConvTranspose2d) assert output_index is None p = m1.weight.data - if input_index is not None: + if input_index is None: q = p.abs().sum([1, 2, 3]) _, idxs = q.topk(m2.in_channels, largest=True) p = p[idxs]