From 43094bf78e38e1bb3d9784038855d07d1886156b Mon Sep 17 00:00:00 2001 From: Jonathan Yan Date: Wed, 5 Aug 2020 13:21:28 -0400 Subject: [PATCH] suport half for roi align --- .../gpu/cuda_impl/roi_align_impl.cu | 29 +++++--- .../kernel_compiler/gpu/cuda_impl/util.cuh | 11 ++- .../gpu/nn/roi_align_grad_gpu_kernel.cc | 5 ++ .../gpu/nn/roi_align_grad_gpu_kernel.h | 1 + .../st/ops/gpu/test_roi_align_grad_half_op.py | 71 +++++++++++++++++++ tests/st/ops/gpu/test_roi_align_half_op.py | 49 +++++++++++++ tests/st/ops/gpu/test_roi_align_op.py | 10 --- 7 files changed, 155 insertions(+), 21 deletions(-) create mode 100644 tests/st/ops/gpu/test_roi_align_grad_half_op.py create mode 100644 tests/st/ops/gpu/test_roi_align_half_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu index 5706aa15fcb..789abcf0f73 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu @@ -15,8 +15,12 @@ */ #include "roi_align_impl.cuh" +#include "util.cuh" #include "runtime/device/gpu/cuda_common.h" +inline __device__ int roi_cast_int(float x) { return static_cast(x); } +inline __device__ int roi_cast_int(half x) { return __half2int_rd(x); } + template __device__ void bilinear_interpolate(const int height, const int width, T y, T x, int *x_low, int *y_low, int *x_high, int *y_high, T *w1, T *w2, T *w3, T *w4) { @@ -33,8 +37,8 @@ __device__ void bilinear_interpolate(const int height, const int width, T y, T x x = x <= static_cast(.0) ? static_cast(.0) : x; // top left point - *y_low = static_cast(y); - *x_low = static_cast(x); + *y_low = roi_cast_int(y); + *x_low = roi_cast_int(x); // bottom right point if (*y_low >= height - 1) { @@ -102,8 +106,8 @@ __device__ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const *offset = (roi_batch_ind * channels + (*c)) * height * width; // grid (int) by Sample ratio if defined, otherwise by pooled H/W - *roi_bin_grid_h = (sample_num > 0) ? sample_num : static_cast(roi_height / static_cast(pooled_height)); - *roi_bin_grid_w = (sample_num > 0) ? sample_num : static_cast(roi_width / static_cast(pooled_width)); + *roi_bin_grid_h = (sample_num > 0) ? sample_num : roi_cast_int(roi_height / static_cast(pooled_height)); + *roi_bin_grid_w = (sample_num > 0) ? sample_num : roi_cast_int(roi_width / static_cast(pooled_width)); return; } @@ -209,11 +213,15 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes, T g3 = top_diff_this_bin * w3 / count_points_in_grid_cell; T g4 = top_diff_this_bin * w4 / count_points_in_grid_cell; + T *dx_1 = dx + offset + y_low * width + x_low; + T *dx_2 = dx + offset + y_low * width + x_high; + T *dx_3 = dx + offset + y_high * width + x_low; + T *dx_4 = dx + offset + y_high * width + x_high; if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - atomicAdd(dx + offset + y_low * width + x_low, static_cast(g1)); - atomicAdd(dx + offset + y_low * width + x_high, static_cast(g2)); - atomicAdd(dx + offset + y_high * width + x_low, static_cast(g3)); - atomicAdd(dx + offset + y_high * width + x_high, static_cast(g4)); + ms_atomic_add(dx_1, g1); + ms_atomic_add(dx_2, g2); + ms_atomic_add(dx_3, g3); + ms_atomic_add(dx_4, g4); } } } @@ -235,3 +243,8 @@ template void ROIAlignGrad(const float *dy, const float *roi_boxes, int r const float spatial_scale, const int sample_num, int roi_end_mode, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, cudaStream_t cuda_stream); + +template void ROIAlignGrad(const half *dy, const half *roi_boxes, int roi_rows, int roi_cols, half *dx, + const half spatial_scale, const int sample_num, int roi_end_mode, const int channels, + const int height, const int width, const int pooled_height, const int pooled_width, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh index 9da273a6618..2b216baa8e9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh @@ -14,6 +14,9 @@ * limitations under the License. */ +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_ + #include inline __device__ float ms_atomic_add(float *address, float val) { return atomicAdd(address, val); } @@ -25,12 +28,12 @@ inline __device__ half ms_atomic_add(half *address, half val) { reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2)); unsigned int old = *aligned; unsigned int assumed; - unsigned short old_as_us; //NOLINT + unsigned short old_as_us; // NOLINT do { assumed = old; - old_as_us = static_cast(reinterpret_cast(address) & 2 ? old >> 16 : old & 0xffff); //NOLINT + old_as_us = static_cast(reinterpret_cast(address) & 2 ? old >> 16 : old & 0xffff); // NOLINT half sum = __float2half_rn(__half2float(__ushort_as_half(old_as_us)) + static_cast(val)); - unsigned short sum_as_us = __half_as_ushort(sum); //NOLINT + unsigned short sum_as_us = __half_as_ushort(sum); // NOLINT unsigned int sum_as_ui = reinterpret_cast(address) & 2 ? (sum_as_us << 16) | (old & 0xffff) : (old & 0xffff0000) | sum_as_us; old = atomicCAS(aligned, assumed, sum_as_ui); @@ -38,3 +41,5 @@ inline __device__ half ms_atomic_add(half *address, half val) { __half_raw raw = {old_as_us}; return half(raw); } + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.cc index 5d08e3d4702..42f310f5e56 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.cc @@ -23,5 +23,10 @@ MS_REG_GPU_KERNEL_ONE( KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ROIAlignGradGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE( + ROIAlignGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ROIAlignGradGpuFwdKernel, half) + } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h index 5d63083e03d..4a445630675 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h @@ -42,6 +42,7 @@ class ROIAlignGradGpuFwdKernel : public GpuKernel { ROIAlignGrad(dy, rois, roi_rows_, roi_cols_, dx, spatial_scale_, sample_num_, roi_end_mode_, channels_, height_, width_, pooled_height_, pooled_width_, reinterpret_cast(stream_ptr)); + return true; } diff --git a/tests/st/ops/gpu/test_roi_align_grad_half_op.py b/tests/st/ops/gpu/test_roi_align_grad_half_op.py new file mode 100644 index 00000000000..f8dbd57c118 --- /dev/null +++ b/tests/st/ops/gpu/test_roi_align_grad_half_op.py @@ -0,0 +1,71 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class NetROIAlignGrad(nn.Cell): + def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num): + super(NetROIAlignGrad, self).__init__() + self.roiAlignGrad = G.ROIAlignGrad( + xdiff_shape, + pooled_height, + pooled_width, + spatial_scale, + sample_num) + + def construct(self, dy, rois): + return self.roiAlignGrad(dy, rois) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_roi_align_grad_half(): + rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float16)) + + dy = Tensor(np.array([[[ + [.1, .2, .3], + [.1, .2, .3], + [.1, .2, .3] + ]]], np.float16)) + + xdiff_shape = (1, 1, 6, 6) + pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + roi_align_grad = NetROIAlignGrad( + xdiff_shape, + pooled_height, + pooled_width, + spatial_scale, + sample_num) + output = roi_align_grad(dy, rois) + print(output) + expect = ([[[[0.0563, 0.0563, 0.0750, 0.0938, 0.1125, 0.0563], + [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], + [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], + [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], + [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], + [0.0188, 0.0188, 0.0250, 0.0312, 0.0375, 0.0188]]]]) + np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) diff --git a/tests/st/ops/gpu/test_roi_align_half_op.py b/tests/st/ops/gpu/test_roi_align_half_op.py new file mode 100644 index 00000000000..d74fb6cdd65 --- /dev/null +++ b/tests/st/ops/gpu/test_roi_align_half_op.py @@ -0,0 +1,49 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore import Tensor +from mindspore.ops import operations as P + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_roi_align_half(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + x = Tensor(np.array([[ + [[1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18], + [19, 20, 21, 22, 23, 24], + [25, 26, 27, 28, 29, 30], + [31, 32, 33, 34, 35, 36]] + ]], np.float16)) + + rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float16)) + + # test case 1 + pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3 + roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) + output = roi_align(x, rois) + print(output) + expect = [[[[1.2333, 2.1000, 3.3000, 4.5000], + [6.4333, 7.3000, 8.5000, 9.7000], + [13.6333, 14.5000, 15.7000, 16.9000], + [20.8333, 21.7000, 22.9000, 24.1000]]]] + np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=1) diff --git a/tests/st/ops/gpu/test_roi_align_op.py b/tests/st/ops/gpu/test_roi_align_op.py index e31d54ef050..5a7f4d49528 100644 --- a/tests/st/ops/gpu/test_roi_align_op.py +++ b/tests/st/ops/gpu/test_roi_align_op.py @@ -47,16 +47,6 @@ def test_roi_align(): [25.25, 27., 29.]]]] assert (output.asnumpy() == expect).all() - # test case 1 - pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 - roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) - output = roi_align(x, rois) - print(output) - expect = [[[[2.75, 4.5, 6.5], - [13.25, 15., 17.], - [25.25, 27., 29.]]]] - assert (output.asnumpy() == expect).all() - # test case 2 pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3 roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)