suport half for roi align

This commit is contained in:
Jonathan Yan 2020-08-05 13:21:28 -04:00
parent 6aa65da5d2
commit 43094bf78e
7 changed files with 155 additions and 21 deletions

View File

@ -15,8 +15,12 @@
*/
#include "roi_align_impl.cuh"
#include "util.cuh"
#include "runtime/device/gpu/cuda_common.h"
inline __device__ int roi_cast_int(float x) { return static_cast<int>(x); }
inline __device__ int roi_cast_int(half x) { return __half2int_rd(x); }
template <typename T>
__device__ void bilinear_interpolate(const int height, const int width, T y, T x, int *x_low, int *y_low, int *x_high,
int *y_high, T *w1, T *w2, T *w3, T *w4) {
@ -33,8 +37,8 @@ __device__ void bilinear_interpolate(const int height, const int width, T y, T x
x = x <= static_cast<T>(.0) ? static_cast<T>(.0) : x;
// top left point
*y_low = static_cast<int>(y);
*x_low = static_cast<int>(x);
*y_low = roi_cast_int(y);
*x_low = roi_cast_int(x);
// bottom right point
if (*y_low >= height - 1) {
@ -102,8 +106,8 @@ __device__ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const
*offset = (roi_batch_ind * channels + (*c)) * height * width;
// grid (int) by Sample ratio if defined, otherwise by pooled H/W
*roi_bin_grid_h = (sample_num > 0) ? sample_num : static_cast<int>(roi_height / static_cast<T>(pooled_height));
*roi_bin_grid_w = (sample_num > 0) ? sample_num : static_cast<int>(roi_width / static_cast<T>(pooled_width));
*roi_bin_grid_h = (sample_num > 0) ? sample_num : roi_cast_int(roi_height / static_cast<T>(pooled_height));
*roi_bin_grid_w = (sample_num > 0) ? sample_num : roi_cast_int(roi_width / static_cast<T>(pooled_width));
return;
}
@ -209,11 +213,15 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
T g3 = top_diff_this_bin * w3 / count_points_in_grid_cell;
T g4 = top_diff_this_bin * w4 / count_points_in_grid_cell;
T *dx_1 = dx + offset + y_low * width + x_low;
T *dx_2 = dx + offset + y_low * width + x_high;
T *dx_3 = dx + offset + y_high * width + x_low;
T *dx_4 = dx + offset + y_high * width + x_high;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(dx + offset + y_low * width + x_low, static_cast<T>(g1));
atomicAdd(dx + offset + y_low * width + x_high, static_cast<T>(g2));
atomicAdd(dx + offset + y_high * width + x_low, static_cast<T>(g3));
atomicAdd(dx + offset + y_high * width + x_high, static_cast<T>(g4));
ms_atomic_add(dx_1, g1);
ms_atomic_add(dx_2, g2);
ms_atomic_add(dx_3, g3);
ms_atomic_add(dx_4, g4);
}
}
}
@ -235,3 +243,8 @@ template void ROIAlignGrad<float>(const float *dy, const float *roi_boxes, int r
const float spatial_scale, const int sample_num, int roi_end_mode, const int channels,
const int height, const int width, const int pooled_height, const int pooled_width,
cudaStream_t cuda_stream);
template void ROIAlignGrad<half>(const half *dy, const half *roi_boxes, int roi_rows, int roi_cols, half *dx,
const half spatial_scale, const int sample_num, int roi_end_mode, const int channels,
const int height, const int width, const int pooled_height, const int pooled_width,
cudaStream_t cuda_stream);

View File

@ -14,6 +14,9 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_
#include <cuda_fp16.h>
inline __device__ float ms_atomic_add(float *address, float val) { return atomicAdd(address, val); }
@ -38,3 +41,5 @@ inline __device__ half ms_atomic_add(half *address, half val) {
__half_raw raw = {old_as_us};
return half(raw);
}
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_

View File

@ -23,5 +23,10 @@ MS_REG_GPU_KERNEL_ONE(
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ROIAlignGradGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(
ROIAlignGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ROIAlignGradGpuFwdKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@ -42,6 +42,7 @@ class ROIAlignGradGpuFwdKernel : public GpuKernel {
ROIAlignGrad(dy, rois, roi_rows_, roi_cols_, dx, spatial_scale_, sample_num_, roi_end_mode_, channels_, height_,
width_, pooled_height_, pooled_width_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}

View File

@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetROIAlignGrad(nn.Cell):
def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num):
super(NetROIAlignGrad, self).__init__()
self.roiAlignGrad = G.ROIAlignGrad(
xdiff_shape,
pooled_height,
pooled_width,
spatial_scale,
sample_num)
def construct(self, dy, rois):
return self.roiAlignGrad(dy, rois)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_roi_align_grad_half():
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float16))
dy = Tensor(np.array([[[
[.1, .2, .3],
[.1, .2, .3],
[.1, .2, .3]
]]], np.float16))
xdiff_shape = (1, 1, 6, 6)
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
roi_align_grad = NetROIAlignGrad(
xdiff_shape,
pooled_height,
pooled_width,
spatial_scale,
sample_num)
output = roi_align_grad(dy, rois)
print(output)
expect = ([[[[0.0563, 0.0563, 0.0750, 0.0938, 0.1125, 0.0563],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0188, 0.0188, 0.0250, 0.0312, 0.0375, 0.0188]]]])
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)

View File

@ -0,0 +1,49 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_roi_align_half():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x = Tensor(np.array([[
[[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24],
[25, 26, 27, 28, 29, 30],
[31, 32, 33, 34, 35, 36]]
]], np.float16))
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float16))
# test case 1
pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
output = roi_align(x, rois)
print(output)
expect = [[[[1.2333, 2.1000, 3.3000, 4.5000],
[6.4333, 7.3000, 8.5000, 9.7000],
[13.6333, 14.5000, 15.7000, 16.9000],
[20.8333, 21.7000, 22.9000, 24.1000]]]]
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=1)

View File

@ -47,16 +47,6 @@ def test_roi_align():
[25.25, 27., 29.]]]]
assert (output.asnumpy() == expect).all()
# test case 1
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
output = roi_align(x, rois)
print(output)
expect = [[[[2.75, 4.5, 6.5],
[13.25, 15., 17.],
[25.25, 27., 29.]]]]
assert (output.asnumpy() == expect).all()
# test case 2
pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)