-
Notifications
You must be signed in to change notification settings - Fork 3
/
lpool4.py
66 lines (49 loc) · 1.77 KB
/
lpool4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Local Pooling Operations"""
# Authors: Nicolas Pinto <[email protected]>
# Nicolas Poilvert <[email protected]>
# Giovani Chiachia <[email protected]>
#
# License: BSD
__all__ = ['lpool4']
import numpy as np
from skimage.util.shape import view_as_windows
import numexpr as ne
# -- module-level defaults for lpool4's keyword arguments
DEFAULT_STRIDE = (1, 1)  # unit step along both spatial axes
DEFAULT_ORDER = 1.0      # exponent applied by lpool4 before/after summing

# Warn once, at import time, when the installed numexpr build reports that
# it is not using Intel VML.
if not ne.use_vml:
    import warnings
    warnings.warn("numexpr is NOT using Intel VML!", UserWarning)
def lpool4(arr_in, neighborhood,
           order=DEFAULT_ORDER,
           stride=DEFAULT_STRIDE, arr_out=None):
    """4D Local Pooling Operation.

    For every image and every channel, compute the L_p pooling of each
    ``nbh x nbw`` spatial window ("valid" windows only, no padding)::

        out[n, i, j, d] = (sum_{a, b} arr_in[n, i*sh + a, j*sw + b, d] ** order)
                          ** (1 / order)

    with ``0 <= a < nbh`` and ``0 <= b < nbw``.

    Parameters
    ----------
    arr_in : ndarray, 4D
        Input array indexed as ``(images, height, width, depth)``.
    neighborhood : pair of int
        Window size ``(nbh, nbw)``; each entry must not exceed the matching
        spatial size of `arr_in`.
    order : float, optional
        Pooling exponent ``p``. It is cast to ``arr_in.dtype`` before use.
    stride : pair of int, optional
        Step ``(sh, sw)`` between consecutive windows along height and
        width respectively. Each entry must be >= 1.
    arr_out : ndarray, optional
        Preallocated output with dtype ``arr_in.dtype`` and shape
        ``(images, 1 + (inh - nbh) // sh, 1 + (inw - nbw) // sw, depth)``.
        When given, the result is written into it in place.

    Returns
    -------
    arr_out : ndarray, 4D
        The pooled array (the supplied `arr_out` object when one was given).

    Raises
    ------
    AssertionError
        If the inputs are inconsistent: wrong rank, window larger than the
        input, malformed or non-positive stride, or an `arr_out` with the
        wrong dtype/shape.
    """
    assert arr_in.ndim == 4, "arr_in must be 4D (images, height, width, depth)"
    assert len(neighborhood) == 2, "neighborhood must be a (height, width) pair"
    assert len(stride) == 2, "stride must be a (height, width) pair"

    in_imgs, inh, inw, ind = arr_in.shape
    nbh, nbw = neighborhood
    sh, sw = stride

    assert nbh <= inh, "neighborhood height exceeds input height"
    assert nbw <= inw, "neighborhood width exceeds input width"
    # A zero step is rejected by slicing anyway; a negative one would
    # silently reverse the output, so refuse both up front.
    assert sh >= 1 and sw >= 1, "strides must be >= 1"

    # Number of kept window positions per spatial axis. Floor division is
    # required: slicing the (n - k + 1) valid positions with ``::s`` keeps
    # exactly 1 + (n - k) // s of them.
    out_h = 1 + (inh - nbh) // sh
    out_w = 1 + (inw - nbw) // sw
    expected_shape = (in_imgs, out_h, out_w, ind)

    if arr_out is not None:
        assert arr_out.dtype == arr_in.dtype, "arr_out dtype must match arr_in"
        assert arr_out.shape == expected_shape, "arr_out has the wrong shape"

    # Wrap `order` in a 1-element array of arr_in's dtype so the power
    # expressions below are evaluated in that dtype rather than promoted.
    order = np.array([order], dtype=arr_in.dtype)

    # NOTE: numexpr resolves the names used in the expression strings
    # (`arr_in`, `order`, `_arr_out`) from this function's local variables.
    _arr_out = ne.evaluate('arr_in ** order')

    # Separable pooling, width first. A (1, 1, nbw, 1) window view has shape
    # (N, H, W - nbw + 1, D, 1, 1, nbw, 1); summing axis 6 collapses the nbw
    # window axis, then ``::sw`` keeps every sw-th horizontal position.
    _arr_out = view_as_windows(_arr_out, (1, 1, nbw, 1))
    _arr_out = ne.evaluate('sum(_arr_out, 6)')[:, :, ::sw, :, 0, 0, 0]

    # Then height: a (1, nbh, 1, 1) window puts the nbh axis at index 5;
    # ``::sh`` keeps every sh-th vertical position.
    _arr_out = view_as_windows(_arr_out, (1, nbh, 1, 1))
    _arr_out = ne.evaluate('sum(_arr_out, 5)')[:, ::sh, :, :, 0, 0, 0]

    _arr_out = ne.evaluate('_arr_out ** (1 / order)')

    if arr_out is not None:
        arr_out[:] = _arr_out
    else:
        arr_out = _arr_out

    assert arr_out.shape == expected_shape
    assert arr_out.dtype == arr_in.dtype
    return arr_out