Skip to content
This repository has been archived by the owner on Jan 26, 2022. It is now read-only.

light_head_rcnn loss_bbox error #48

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/e2e_faster_rcnn_R-50-C4_2x.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ MODEL:
FASTER_RCNN: True
RESNETS:
IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/resnet50_caffe.pth'
NUM_GPUS: 8
NUM_GPUS: 4
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
Expand Down
36 changes: 36 additions & 0 deletions configs/e2e_light_head_rcnn_R-50-C4.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# End-to-end Light-Head R-CNN training/testing config (generalized_rcnn with a
# ResNet-50 conv4 body), as named by configs/e2e_light_head_rcnn_R-50-C4.yml.
MODEL:
TYPE: generalized_rcnn
CONV_BODY: ResNet.ResNet50_conv4_body
# Enables the Light-Head R-CNN code path (MODEL.LIGHT_HEAD_RCNN defaults to
# False in lib/core/config.py).
LIGHT_HEAD_RCNN: True
RESNETS:
IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/resnet50_caffe.pth'
NUM_GPUS: 4
SOLVER:
TYPE: 'SGD'
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.005
GAMMA: 0.1
# 2x schedule (note TRAIN.IMS_PER_BATCH: 1)
MAX_ITER: 360000
STEPS: [0, 240000, 320000]
RPN:
RPN_ON: True
# CLS_ACTIVATION: softmax
SIZES: (32, 64, 128, 256, 512)
# Light-Head R-CNN head options (the LIGHT_HEAD_RCNN section of
# lib/core/config.py); both values below repeat the defaults declared there.
LIGHT_HEAD_RCNN:
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_METHOD: PSRoIPool
TRAIN:
SCALES: (800,)
MAX_SIZE: 1333
IMS_PER_BATCH: 1
FG_THRESH: 0.7
BG_THRESH_HI: 0.5
BATCH_SIZE_PER_IM: 1024
TEST:
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 6000
RPN_POST_NMS_TOP_N: 1000
3 changes: 2 additions & 1 deletion configs/e2e_mask_rcnn_R-101-FPN_2x.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ MODEL:
MASK_ON: True
RESNETS:
IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/resnet101_caffe.pth'
NUM_GPUS: 8
NUM_GPUS: 4
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
Expand Down Expand Up @@ -41,3 +41,4 @@ TEST:
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
VIS: True
10 changes: 10 additions & 0 deletions lib/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@
# Indicates the model makes instance mask predictions (as in Mask R-CNN)
__C.MODEL.MASK_ON = False

__C.MODEL.LIGHT_HEAD_RCNN = False

# Indicates the model makes keypoint predictions (as in Mask R-CNN for
# keypoints)
__C.MODEL.KEYPOINTS_ON = False
Expand Down Expand Up @@ -633,6 +635,14 @@
__C.FAST_RCNN.ROI_XFORM_RESOLUTION = 14


# ---------------------------------------------------------------------------- #
# LIGHT_HEAD_RCNN options
# ---------------------------------------------------------------------------- #
__C.LIGHT_HEAD_RCNN = AttrDict()
__C.LIGHT_HEAD_RCNN.ROI_BOX_HEAD = ''
__C.LIGHT_HEAD_RCNN.MLP_HEAD_DIM = 1024
__C.LIGHT_HEAD_RCNN.ROI_XFORM_RESOLUTION = 7
__C.LIGHT_HEAD_RCNN.ROI_XFORM_METHOD = 'PSRoIPool'
# ---------------------------------------------------------------------------- #
# RPN options
# ---------------------------------------------------------------------------- #
Expand Down
8 changes: 4 additions & 4 deletions lib/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def im_detect_bbox(model, im, target_scale, target_max_size, boxes=None):

inputs, im_scale = _get_blobs(im, boxes, target_scale, target_max_size)

if cfg.DEDUP_BOXES > 0 and not cfg.MODEL.FASTER_RCNN:
if cfg.DEDUP_BOXES > 0 and not cfg.MODEL.FASTER_RCNN and not cfg.MODEL.LIGHT_HEAD_RCNN:
v = np.array([1, 1e3, 1e6, 1e9, 1e12])
hashes = np.round(inputs['rois'] * cfg.DEDUP_BOXES).dot(v)
_, index, inv_index = np.unique(
Expand All @@ -126,7 +126,7 @@ def im_detect_bbox(model, im, target_scale, target_max_size, boxes=None):
boxes = boxes[index, :]

# Add multi-level rois for FPN
if cfg.FPN.MULTILEVEL_ROIS and not cfg.MODEL.FASTER_RCNN:
if cfg.FPN.MULTILEVEL_ROIS and not cfg.MODEL.FASTER_RCNN and not cfg.MODEL.LIGHT_HEAD_RCNN:
_add_multilevel_rois_for_test(inputs, 'rois')

if cfg.PYTORCH_VERSION_LESS_THAN_040:
Expand All @@ -138,7 +138,7 @@ def im_detect_bbox(model, im, target_scale, target_max_size, boxes=None):

return_dict = model(**inputs)

if cfg.MODEL.FASTER_RCNN:
if cfg.MODEL.FASTER_RCNN or cfg.MODEL.LIGHT_HEAD_RCNN:
rois = return_dict['rois'].data.cpu().numpy()
# unscale back to raw image space
boxes = rois[:, 1:5] / im_scale
Expand Down Expand Up @@ -168,7 +168,7 @@ def im_detect_bbox(model, im, target_scale, target_max_size, boxes=None):
# Simply repeat the boxes, once for each class
pred_boxes = np.tile(boxes, (1, scores.shape[1]))

if cfg.DEDUP_BOXES > 0 and not cfg.MODEL.FASTER_RCNN:
if cfg.DEDUP_BOXES > 0 and not cfg.MODEL.FASTER_RCNN and not cfg.MODEL.LIGHT_HEAD_RCNN:
# Map scores and predictions back to the original set of boxes
scores = scores[inv_index, :]
pred_boxes = pred_boxes[inv_index, :]
Expand Down
13 changes: 13 additions & 0 deletions lib/make.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ nvcc -c -o roi_crop_cuda_kernel.cu.o roi_crop_cuda_kernel.cu \
cd ../
python build.py

# compile psroi_pooling
cd ../../
cd model/psroi_pooling/src
echo "Compiling psroi pooling kernels by nvcc..."
nvcc -c -o psroi_pooling_kernel.cu.o psroi_pooling_kernel.cu \
-D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC $CUDA_ARCH
cd ../
python build.py

# compile roi_align (based on Caffe2's implementation)
cd ../../
cd modeling/roi_xfrom/roi_align/src
Expand All @@ -58,3 +67,7 @@ nvcc -c -o roi_align_kernel.cu.o roi_align_kernel.cu \
-D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC $CUDA_ARCH
cd ../
python build.py




Empty file.
Empty file.
15 changes: 15 additions & 0 deletions lib/model/psroi_align_pooling/_ext/psroi_align_pooling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

from torch.utils.ffi import _wrap_function
from ._psroi_align_pooling import lib as _lib, ffi as _ffi

__all__ = []


def _import_symbols(locals):
    """Re-export every attribute of the compiled ``_lib`` into ``locals``.

    Callable attributes are wrapped with ``_wrap_function`` (bound to
    ``_ffi``); everything else is copied through unchanged. Each exported
    name is also appended to the module's ``__all__``.
    """
    for name in dir(_lib):
        attr = getattr(_lib, name)
        locals[name] = _wrap_function(attr, _ffi) if callable(attr) else attr
        __all__.append(name)


# Populate this module's namespace at import time.
_import_symbols(locals())
37 changes: 37 additions & 0 deletions lib/model/psroi_align_pooling/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import torch
from torch.utils.ffi import create_extension

# Inputs handed to create_extension(). The CUDA wrapper source, its header and
# the WITH_CUDA define are only added when torch.cuda.is_available() is True.
with_cuda = torch.cuda.is_available()
sources = []
headers = []
defines = []

if with_cuda:
    print('Including CUDA code.')
    sources.append('src/psroi_align_pooling_cuda.c')
    headers.append('src/psroi_align_pooling_cuda.h')
    defines.append(('WITH_CUDA', None))

# Directory containing this script; the pre-compiled kernel object (produced
# by nvcc in make.sh) is referenced by absolute path.
this_file = os.path.dirname(os.path.realpath(__file__))
print(this_file)
extra_objects = [
    os.path.join(this_file, 'src/psroi_align_pooling_kernel.cu.o'),
]
print(extra_objects)

ffi = create_extension(
    '_ext.psroi_align_pooling',
    sources=sources,
    headers=headers,
    define_macros=defines,
    with_cuda=with_cuda,
    relative_to=__file__,
    extra_objects=extra_objects,
    extra_compile_args=['-DDEBUG'],
)

if __name__ == '__main__':
    ffi.build()
Empty file.
69 changes: 69 additions & 0 deletions lib/model/psroi_align_pooling/functions/psroi_align_pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import torch
from torch.autograd import Function
from .._ext import psroi_align_pooling

class PSRoiAlignPoolingFunction(Function):
    """Legacy-style autograd function around the PSRoIAlign CUDA extension.

    The instance carries the pooling configuration plus the buffers produced
    by ``forward`` (``mapping_channel``, ``argmax_position``, ``rois``,
    ``feature_size``) so that ``backward`` can hand them back to the
    extension.
    """

    def __init__(self, pooled_height, pooled_width, sample_height,
                 sample_width, spatial_scale, group_size, output_dim=None):
        """Store the pooling configuration.

        Args:
            pooled_height, pooled_width: output grid size (coerced to int).
            sample_height, sample_width: sampling grid size passed through to
                the extension (coerced to int).
            spatial_scale: scale factor passed through to the extension
                (coerced to float).
            group_size: group size passed through to the extension (coerced
                to int).
            output_dim: optional number of output channels. When omitted it
                is derived in ``forward`` from the feature channel count.
                Accepting it keeps callers that pass a seventh positional
                argument from failing with ``TypeError``.
        """
        self.pooled_height = int(pooled_height)
        self.pooled_width = int(pooled_width)
        self.sample_height = int(sample_height)
        self.sample_width = int(sample_width)
        self.spatial_scale = float(spatial_scale)
        self.group_size = int(group_size)
        # Kept separately from ``output_dim`` so that a reused instance
        # re-derives the value on every forward when none was requested.
        self._requested_output_dim = (
            None if output_dim is None else int(output_dim))
        self.output_dim = self._requested_output_dim
        self.output = None
        self.mapping_channel = None
        self.argmax_position = None
        self.rois = None
        self.feature_size = None

    def forward(self, features, rois):
        """Pool ``features`` over ``rois`` with the CUDA extension.

        ``features`` must be 4-D (its size is unpacked into four values) and
        the first dimension of ``rois`` is the number of RoIs. Returns a CUDA
        tensor of shape
        ``(num_rois, output_dim, pooled_height, pooled_width)``.
        """
        batch_size, num_channels, data_height, data_width = features.size()
        if self._requested_output_dim is None:
            self.output_dim = (
                num_channels // self.pooled_height // self.pooled_width)
        else:
            # NOTE(review): presumably must be consistent with the channel
            # count the kernel expects -- confirm against the CUDA source.
            self.output_dim = self._requested_output_dim

        num_rois = rois.size()[0]
        out_shape = (num_rois, self.output_dim,
                     self.pooled_height, self.pooled_width)
        output = torch.zeros(*out_shape).cuda()
        # Bookkeeping buffers filled by the forward kernel and handed back to
        # the backward kernel.
        mapping_channel = torch.cuda.IntTensor(*out_shape).zero_()
        argmax_position = torch.cuda.IntTensor(*out_shape).zero_()

        psroi_align_pooling.psroi_align_pooling_forward_cuda(
            self.pooled_height,
            self.pooled_width,
            self.sample_height,
            self.sample_width,
            self.spatial_scale,
            self.group_size,
            self.output_dim,
            features,
            rois,
            output,
            mapping_channel,
            argmax_position,
        )

        self.output = output
        self.mapping_channel = mapping_channel
        self.argmax_position = argmax_position
        self.rois = rois
        self.feature_size = features.size()

        return output

    def backward(self, grad_output):
        """Return ``(grad_features, None)``; ``rois`` receives no gradient.

        Requires that ``forward`` has run (``feature_size`` is set) and that
        ``grad_output`` lives on a CUDA device.
        """
        assert(self.feature_size is not None and grad_output.is_cuda)

        batch_size, num_channels, data_height, data_width = self.feature_size

        grad_input = torch.zeros(
            batch_size, num_channels, data_height, data_width).cuda()
        psroi_align_pooling.psroi_align_pooling_backward_cuda(
            self.pooled_height,
            self.pooled_width,
            self.sample_height,
            self.sample_width,
            self.spatial_scale,
            self.group_size,
            self.output_dim,
            grad_output,
            self.rois,
            grad_input,
            self.mapping_channel,
            self.argmax_position,
        )
        return grad_input, None
9 changes: 9 additions & 0 deletions lib/model/psroi_align_pooling/make.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/usr/bin/env bash
# Build the psroi_align_pooling extension: compile the CUDA kernel with nvcc,
# then run build.py to produce the FFI extension.

# Stop at the first failing command. Without this, a failed `cd src` or nvcc
# run would still reach `cd ../` and run build.py from the wrong directory or
# against a missing kernel object.
set -e

# Not referenced below; nvcc is resolved from PATH.
CUDA_PATH=/usr/local/cuda/

cd src
echo "Compiling psroi_align_pooling kernels by nvcc..."
nvcc -c -o psroi_align_pooling_kernel.cu.o psroi_align_pooling_kernel.cu.cc -x cu -Xcompiler -fPIC -arch=sm_60

cd ../
python build.py
Empty file.
17 changes: 17 additions & 0 deletions lib/model/psroi_align_pooling/modules/psroi_align_pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from torch.nn.modules.module import Module
from ..functions.psroi_align_pooling import PSRoiAlignPoolingFunction

class PSRoIAlignPool(Module):
    """``nn.Module`` wrapper around ``PSRoiAlignPoolingFunction``.

    Stores the pooling configuration and builds a fresh function object on
    every forward call.
    """

    def __init__(self, pooled_height, pooled_width, sample_height,
                 sample_width, spatial_scale, group_size, output_dim):
        """Store the pooling configuration.

        Integer-valued settings are coerced with ``int`` and
        ``spatial_scale`` with ``float``. ``output_dim`` is kept on the
        module but not forwarded to the function (see ``forward``).
        """
        super(PSRoIAlignPool, self).__init__()

        self.pooled_width = int(pooled_width)
        self.pooled_height = int(pooled_height)
        self.sample_height = int(sample_height)
        self.sample_width = int(sample_width)
        self.spatial_scale = float(spatial_scale)
        self.group_size = int(group_size)
        self.output_dim = int(output_dim)

    def forward(self, features, rois):
        """Apply PSRoIAlign pooling to ``features`` over ``rois``."""
        # output_dim is deliberately not forwarded: the function derives it
        # from the feature channel count in its own forward pass, and passing
        # it as a seventh positional argument raised TypeError.
        pool_fn = PSRoiAlignPoolingFunction(
            self.pooled_height,
            self.pooled_width,
            self.sample_height,
            self.sample_width,
            self.spatial_scale,
            self.group_size,
        )
        return pool_fn(features, rois)
Loading