Skip to content

Commit

Permalink
Changes for Pytorch 1.5.0
Browse files Browse the repository at this point in the history
  • Loading branch information
ducksoup committed Apr 22, 2020
1 parent 5b017ac commit 24fc791
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ To install PyTorch, please refer to https://github.com/pytorch/pytorch#installat

To install the package containing the iABN layers:
```bash
pip install git+https://github.com/mapillary/inplace-abn@v1.0.11
pip install git+https://github.com/mapillary/inplace-abn@v1.0.12
```
Note that some parts of InPlace-ABN have native C++/CUDA implementations, meaning that the command above will need to
compile them.
Expand Down
8 changes: 5 additions & 3 deletions inplace_abn/abn.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation
super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, activation_param)

def forward(self, x):
return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
self.training, self.momentum, self.eps, self.activation, self.activation_param)
x, _, _ = inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
self.training, self.momentum, self.eps, self.activation, self.activation_param)
return x


class InPlaceABNSync(ABN):
Expand Down Expand Up @@ -147,6 +148,7 @@ def set_group(self, group):
self.group = group

def forward(self, x):
return inplace_abn_sync(
x, _, _ = inplace_abn_sync(
x, self.weight, self.bias, self.running_mean, self.running_var, self.training, self.momentum, self.eps,
self.activation, self.activation_param, self.group)
return x
6 changes: 4 additions & 2 deletions inplace_abn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,13 @@ def forward(ctx, x, weight, bias, running_mean, running_var,

# Save for backward
ctx.save_for_backward(x, var, count, weight, bias)
return x

ctx.mark_non_differentiable(running_mean, running_var)
return x, running_mean, running_var

@staticmethod
@once_differentiable
def backward(ctx, dy_act):
def backward(ctx, dy_act, _drunning_mean, _drunning_var):
y_act, var, count, weight, bias = ctx.saved_tensors

# Call backward_reduce if we need to compute at least one of the gradients
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def find_sources(root_dir, with_cuda=True):
"cxx": ["-O3"],
"nvcc": []
},
include_dirs=["include/"],
include_dirs=[path.join(here, "include")],
define_macros=[("WITH_CUDA", 1)]
)
]
Expand All @@ -41,7 +41,7 @@ def find_sources(root_dir, with_cuda=True):
name="inplace_abn._backend",
sources=find_sources("src", False),
extra_compile_args=["-O3"],
include_dirs=["include/"]
include_dirs=[path.join(here, "include")]
)
]

Expand Down

0 comments on commit 24fc791

Please sign in to comment.