!7482 roialign gpu operator output is zero

Merge pull request !7482 from JonathanY/roialign_zero
This commit is contained in:
mindspore-ci-bot 2020-10-21 23:42:06 +08:00 committed by Gitee
commit 6cc37db833
5 changed files with 37 additions and 30 deletions

View File

@ -91,7 +91,7 @@ __device__ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const
}
// Scale and shift ROI
T roi_offset = roi_end_mode == 1 ? static_cast<T>(0.5) : static_cast<T>(.0);
T roi_offset = roi_end_mode == 0 ? static_cast<T>(0.5) : static_cast<T>(.0);
*roi_start_w = roi_box[0] * spatial_scale - roi_offset;
*roi_start_h = roi_box[1] * spatial_scale - roi_offset;
T roi_end_w = roi_box[2] * spatial_scale - roi_offset;
@ -121,10 +121,9 @@ __global__ void ROIAlignKernel(size_t size, const T *input, const T *roi_boxes,
thread_idx += blockDim.x * gridDim.x) {
int n = thread_idx / pooled_width / pooled_height / channels;
const T *roi_box = roi_boxes + n * roi_cols;
if (roi_box[0] < static_cast<T>(0.001) && roi_box[1] < static_cast<T>(0.001) &&
roi_box[2] < static_cast<T>(0.001) && roi_box[3] < static_cast<T>(0.001) &&
roi_box[0] > static_cast<T>(-0.001) && roi_box[1] > static_cast<T>(-0.001) &&
roi_box[2] > static_cast<T>(-0.001) && roi_box[3] > static_cast<T>(-0.001)) {
// Skip if roi box is a line
if (roi_box[1] < static_cast<T>(0.001) && roi_box[3] < static_cast<T>(0.001) &&
roi_box[1] > static_cast<T>(-0.001) && roi_box[3] > static_cast<T>(-0.001)) {
continue;
}
@ -136,8 +135,6 @@ __global__ void ROIAlignKernel(size_t size, const T *input, const T *roi_boxes,
pooled_height, pooled_width, &offset, &n, &c, &ph, &pw, &roi_bin_grid_h, &roi_bin_grid_w, &bin_size_h,
&bin_size_w, &roi_start_h, &roi_start_w);
if (offset < 0 || offset >= size) continue;
// (n, c, ph, pw) is the base param of pooled map
const T count_points_in_grid_cell = roi_bin_grid_h * roi_bin_grid_w;
@ -209,10 +206,8 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
thread_idx += blockDim.x * gridDim.x) {
int n = thread_idx / pooled_width / pooled_height / channels;
const T *roi_box = roi_boxes + n * roi_cols;
if (roi_box[0] < static_cast<T>(0.001) && roi_box[1] < static_cast<T>(0.001) &&
roi_box[2] < static_cast<T>(0.001) && roi_box[3] < static_cast<T>(0.001) &&
roi_box[0] > static_cast<T>(-0.001) && roi_box[1] > static_cast<T>(-0.001) &&
roi_box[2] > static_cast<T>(-0.001) && roi_box[3] > static_cast<T>(-0.001)) {
if (roi_box[1] < static_cast<T>(0.001) && roi_box[3] < static_cast<T>(0.001) &&
roi_box[1] > static_cast<T>(-0.001) && roi_box[3] > static_cast<T>(-0.001)) {
continue;
}
@ -224,8 +219,6 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
pooled_height, pooled_width, &offset, &n, &c, &ph, &pw, &roi_bin_grid_h, &roi_bin_grid_w, &bin_size_h,
&bin_size_w, &roi_start_h, &roi_start_w);
if (offset < 0 || offset >= size) continue;
// (n, c, ph, pw) is the base param of pooled map
const T count_points_in_grid_cell = roi_bin_grid_h * roi_bin_grid_w;

View File

@ -62,10 +62,17 @@ def test_roi_align_grad_half():
sample_num)
output = roi_align_grad(dy, rois)
print(output)
expect = ([[[[0.0563, 0.0563, 0.0750, 0.0938, 0.1125, 0.0563],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0188, 0.0188, 0.0250, 0.0312, 0.0375, 0.0188]]]])
# the out if aligned is True
# expect = ([[[[0.0563, 0.0563, 0.0750, 0.0938, 0.1125, 0.0563],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0188, 0.0188, 0.0250, 0.0312, 0.0375, 0.0188]]]])
expect = ([[[[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075]]]])
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)

View File

@ -62,10 +62,17 @@ def test_roi_align_grad():
sample_num)
output = roi_align_grad(dy, rois)
print(output)
expect = ([[[[0.0563, 0.0563, 0.0750, 0.0938, 0.1125, 0.0563],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0188, 0.0188, 0.0250, 0.0312, 0.0375, 0.0188]]]])
# the out if aligned is True
# expect = ([[[[0.0563, 0.0563, 0.0750, 0.0938, 0.1125, 0.0563],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0188, 0.0188, 0.0250, 0.0312, 0.0375, 0.0188]]]])
expect = ([[[[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075]]]])
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)

View File

@ -39,7 +39,7 @@ def test_roi_align_half():
# test case 1
pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0)
output = roi_align(x, rois)
print(output)
expect = [[[[1.2333, 2.1000, 3.3000, 4.5000],

View File

@ -39,7 +39,7 @@ def test_roi_align():
# test case 1
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0)
output = roi_align(x, rois)
print(output)
expect = [[[[2.75, 4.5, 6.5],
@ -49,7 +49,7 @@ def test_roi_align():
# test case 2
pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0)
output = roi_align(x, rois)
print(output)
expect = [[[[1.2333, 2.1000, 3.3000, 4.5000],
@ -63,7 +63,7 @@ def test_roi_align():
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0],
[0, 1.0, 0.0, 19.0, 18.0]],
np.float32))
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0)
output = roi_align(x, rois)
print(output)
expect = [[[[3.3333, 5.5000, 7.6667],
@ -77,7 +77,7 @@ def test_roi_align():
# test case 4
pooled_height, pooled_width, spatial_scale, sample_num = 2, 2, 1.0, -1
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32))
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0)
output = roi_align(x, rois)
print(output)
expect = [[[[8.2222, 0.],