!3201 RoI Align GPU kernel

Merge pull request !3201 from JonathanY/main
This commit is contained in:
mindspore-ci-bot 2020-07-22 13:38:34 +08:00 committed by Gitee
commit d15b4c5d61
5 changed files with 519 additions and 0 deletions

View File

@ -0,0 +1,228 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "roi_align_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
inline __device__ T gpu_atomic_add(const T val, T *address);
template <>
inline __device__ float gpu_atomic_add(const float val, float *address) {
return atomicAdd(address, val);
}
template <typename T>
__device__ void bilinear_interpolate(const int height, const int width, T y, T x, int *x_low, int *y_low, int *x_high,
int *y_high, T *w1, T *w2, T *w3, T *w4) {
// return 0 if out of map boundary
if (y <= static_cast<T>(-1.0) || y >= static_cast<T>(height) || x <= static_cast<T>(-1.0) ||
x >= static_cast<T>(width)) {
*w1 = *w2 = *w3 = *w4 = 0;
*x_low = *x_high = *y_low = *y_high = -1;
return;
}
// low bounder is at least zero
y = y <= static_cast<T>(.0) ? static_cast<T>(.0) : y;
x = x <= static_cast<T>(.0) ? static_cast<T>(.0) : x;
// top left point
*y_low = static_cast<int>(y);
*x_low = static_cast<int>(x);
// bottom right point
if (*y_low >= height - 1) {
*y_high = *y_low = height - 1;
y = static_cast<T>(*y_low);
} else {
*y_high = *y_low + 1;
}
if (*x_low >= width - 1) {
*x_high = *x_low = width - 1;
x = static_cast<T>(*x_low);
} else {
*x_high = *x_low + 1;
}
// distance to nearest points
T lx, ly, hx, hy;
ly = y - static_cast<T>(*y_low), lx = x - static_cast<T>(*x_low);
hy = static_cast<T>(1.) - ly, hx = static_cast<T>(1.) - lx;
// weight is evaluated by the distance to point away.
// the closer to point home, the more weight, the farther to point away.
*w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx;
return;
}
template <typename T>
__device__ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const T spatial_scale, const int sample_num,
int roi_end_mode, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, int *offset, int *n, int *c, int *ph, int *pw,
int *roi_bin_grid_h, int *roi_bin_grid_w, T *bin_size_h, T *bin_size_w, T *roi_start_h,
T *roi_start_w) {
// (n, c, ph, pw) is the base param of pooled map
*pw = thread_idx % pooled_width;
*ph = (thread_idx / pooled_width) % pooled_height;
*c = (thread_idx / pooled_width / pooled_height) % channels;
*n = thread_idx / pooled_width / pooled_height / channels;
// Roi has
// 1. 4 points, or
// 2. indicator + 4 points (1 + 4)
const T *roi_box = roi_boxes + (*n) * roi_cols;
int roi_batch_ind = 0;
if (roi_cols == 5) {
roi_batch_ind = roi_box[0];
roi_box++;
}
// Scale and shift ROI
T roi_offset = roi_end_mode == 1 ? static_cast<T>(0.5) : static_cast<T>(.0);
*roi_start_w = roi_box[0] * spatial_scale - roi_offset;
*roi_start_h = roi_box[1] * spatial_scale - roi_offset;
T roi_end_w = roi_box[2] * spatial_scale - roi_offset;
T roi_end_h = roi_box[3] * spatial_scale - roi_offset;
// New ROI height/width
T roi_width = roi_end_w - (*roi_start_w);
T roi_height = roi_end_h - (*roi_start_h);
// ratio of roi / pooled
*bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
*bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
*offset = (roi_batch_ind * channels + (*c)) * height * width;
// grid (int) by Sample ratio if defined, otherwise by pooled H/W
*roi_bin_grid_h = (sample_num > 0) ? sample_num : static_cast<int>(roi_height / static_cast<T>(pooled_height));
*roi_bin_grid_w = (sample_num > 0) ? sample_num : static_cast<int>(roi_width / static_cast<T>(pooled_width));
return;
}
template <typename T>
__global__ void ROIAlignKernel(size_t size, const T *input, const T *roi_boxes, int roi_cols, T *out_data,
const T spatial_scale, const int sample_num, int roi_end_mode, const int channels,
const int height, const int width, const int pooled_height, const int pooled_width) {
for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size;
thread_idx += blockDim.x * gridDim.x) {
int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
T bin_size_h, bin_size_w, roi_start_h, roi_start_w;
bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width,
pooled_height, pooled_width, &offset, &n, &c, &ph, &pw, &roi_bin_grid_h, &roi_bin_grid_w, &bin_size_h,
&bin_size_w, &roi_start_h, &roi_start_w);
// (n, c, ph, pw) is the base param of pooled map
const T count_points_in_grid_cell = roi_bin_grid_h * roi_bin_grid_w;
T accumulate_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
// Shift half point RIGHT for y / x, while previous scaled roi shift half point LEFT
const T y = roi_start_h + static_cast<T>(ph) * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + static_cast<T>(pw) * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
// bilinear interpolate by shifted y / x
// calculate bilinear interpolation
int x_low, y_low, x_high, y_high;
T w1, w2, w3, w4;
bilinear_interpolate(height, width, y, x, &x_low, &y_low, &x_high, &y_high, &w1, &w2, &w3, &w4);
T v1 = input[y_low * width + x_low + offset];
T v2 = input[y_low * width + x_high + offset];
T v3 = input[y_high * width + x_low + offset];
T v4 = input[y_high * width + x_high + offset];
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
accumulate_val += val;
}
}
accumulate_val /= count_points_in_grid_cell;
out_data[thread_idx] = accumulate_val;
}
}
template <typename T>
void ROIAlign(const T *x, const T *roi_boxes, int roi_rows, int roi_cols, T *out_data, const T spatial_scale,
const int sample_num, int roi_end_mode, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, cudaStream_t cuda_stream) {
size_t size = roi_rows * channels * pooled_height * pooled_width;
ROIAlignKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, x, roi_boxes, roi_cols, out_data,
spatial_scale, sample_num, roi_end_mode, channels,
height, width, pooled_height, pooled_width);
return;
}
template void ROIAlign<float>(const float *x, const float *roi_boxes, int roi_rows, int roi_cols, float *out_data,
const float spatial_scale, const int sample_num, int roi_end_mode, const int channels,
const int height, const int width, const int pooled_height, const int pooled_width,
cudaStream_t cuda_stream);
template void ROIAlign<half>(const half *x, const half *roi_boxes, int roi_rows, int roi_cols, half *out_data,
const half spatial_scale, const int sample_num, int roi_end_mode, const int channels,
const int height, const int width, const int pooled_height, const int pooled_width,
cudaStream_t cuda_stream);
template <typename T>
__global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes, int roi_cols, T *dx,
const T spatial_scale, const int sample_num, int roi_end_mode, const int channels,
const int height, const int width, const int pooled_height, const int pooled_width) {
for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size;
thread_idx += blockDim.x * gridDim.x) {
int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
T bin_size_h, bin_size_w, roi_start_h, roi_start_w;
bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width,
pooled_height, pooled_width, &offset, &n, &c, &ph, &pw, &roi_bin_grid_h, &roi_bin_grid_w, &bin_size_h,
&bin_size_w, &roi_start_h, &roi_start_w);
// (n, c, ph, pw) is the base param of pooled map
const T count_points_in_grid_cell = roi_bin_grid_h * roi_bin_grid_w;
int top_offset = (n * channels + c) * pooled_height * pooled_width;
const T *offset_top_diff = dy + top_offset;
const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
// Shift half point RIGHT for y / x, while previous scaled roi shift half point LEFT
const T y =
roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x =
roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
// bilinear interpolate by shifted y / x
// calculate bilinear interpolation
int x_low, y_low, x_high, y_high;
T w1, w2, w3, w4;
bilinear_interpolate(height, width, y, x, &x_low, &y_low, &x_high, &y_high, &w1, &w2, &w3, &w4);
T g1 = top_diff_this_bin * w1 / count_points_in_grid_cell;
T g2 = top_diff_this_bin * w2 / count_points_in_grid_cell;
T g3 = top_diff_this_bin * w3 / count_points_in_grid_cell;
T g4 = top_diff_this_bin * w4 / count_points_in_grid_cell;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
gpu_atomic_add(static_cast<T>(g1), dx + offset + y_low * width + x_low);
gpu_atomic_add(static_cast<T>(g2), dx + offset + y_low * width + x_high);
gpu_atomic_add(static_cast<T>(g3), dx + offset + y_high * width + x_low);
gpu_atomic_add(static_cast<T>(g4), dx + offset + y_high * width + x_high);
}
}
}
}
}

View File

@ -0,0 +1,24 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_
template <typename T>
void ROIAlign(const T *x, const T *roi_boxes, int roi_rows, int roi_cols, T *out_data, const T spatial_scale,
const int sample_num, int roi_end_mode, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_

View File

@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
ROIAlign,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ROIAlignGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(
ROIAlign,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ROIAlignGpuFwdKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,140 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GPU_KERNEL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GPU_KERNEL_H
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class ROIAlignGpuFwdKernel : public GpuKernel {
public:
ROIAlignGpuFwdKernel() : x_size_(0), rois_size_(0), output_size_(0) {}
~ROIAlignGpuFwdKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
const T *x = GetDeviceAddress<T>(inputs, 0);
const T *rois = GetDeviceAddress<T>(inputs, 1);
T *out_data = GetDeviceAddress<T>(outputs, 0);
ROIAlign(x, rois, roi_rows_, roi_cols_, out_data, spatial_scale_, sample_num_, roi_end_mode_, channels_, height_,
width_, pooled_height_, pooled_width_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
// Get the number of input args
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but RioAlign needs 2 input.";
return false;
}
// Get the number of output args
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but RioAlign needs 1 output.";
return false;
}
// Get the input shapes
auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto rois_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto x_shape_size = x_shape.size();
if (x_shape_size < 2) {
MS_LOG(ERROR) << "x shape szie is " << x_shape_size << ", but at lease 2D.";
return false;
}
// Get channels, height & width
channels_ = x_shape_size >= 3 ? x_shape[x_shape_size - 3] : 1;
height_ = x_shape[x_shape_size - 2];
width_ = x_shape[x_shape_size - 1];
x_shape_ = {channels_, height_, width_};
x_size_ = channels_ * height_ * width_ * sizeof(T);
// Get rois rows and cols
roi_rows_ = rois_shape[0];
roi_cols_ = rois_shape[1];
rois_size_ = roi_rows_ * roi_cols_ * sizeof(T);
rois_shape_ = {roi_rows_, roi_cols_};
// Get primitive args
pooled_height_ = GetAttr<int>(kernel_node, "pooled_height");
pooled_width_ = GetAttr<int>(kernel_node, "pooled_width");
spatial_scale_ = static_cast<T>(GetAttr<float>(kernel_node, "spatial_scale"));
sample_num_ = GetAttr<int>(kernel_node, "sample_num");
roi_end_mode_ = GetAttr<int>(kernel_node, "roi_end_mode");
// Get output_shape
output_shape_ = {roi_rows_, channels_, pooled_height_, pooled_width_};
output_size_ = 1;
for (size_t i = 0; i < 4; i++) {
output_size_ *= output_shape_[i];
}
output_size_ *= sizeof(T);
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(x_size_);
input_size_list_.push_back(rois_size_);
output_size_list_.push_back(output_size_);
}
private:
int pooled_height_;
int pooled_width_;
T spatial_scale_;
int sample_num_;
int roi_end_mode_;
int roi_rows_;
int roi_cols_;
int channels_;
int height_;
int width_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
std::vector<int> x_shape_;
std::vector<int> rois_shape_;
std::vector<int> output_shape_;
size_t x_size_;
size_t rois_size_;
size_t output_size_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GPU_KERNEL_H

View File

@ -0,0 +1,95 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_roi_align():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x = Tensor(np.array([[
[[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24],
[25, 26, 27, 28, 29, 30],
[31, 32, 33, 34, 35, 36]]
]], np.float32))
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32))
# test case 1
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
output = roi_align(x, rois)
print(output)
expect = [[[[2.75, 4.5, 6.5],
[13.25, 15., 17.],
[25.25, 27., 29.]]]]
assert (output.asnumpy() == expect).all()
# test case 1
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
output = roi_align(x, rois)
print(output)
expect = [[[[2.75, 4.5, 6.5],
[13.25, 15., 17.],
[25.25, 27., 29.]]]]
assert (output.asnumpy() == expect).all()
# test case 2
pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
output = roi_align(x, rois)
print(output)
expect = [[[[1.2333, 2.1000, 3.3000, 4.5000],
[6.4333, 7.3000, 8.5000, 9.7000],
[13.6333, 14.5000, 15.7000, 16.9000],
[20.8333, 21.7000, 22.9000, 24.1000]]]]
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)
# test case 3
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.3, 3
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0],
[0, 1.0, 0.0, 19.0, 18.0]],
np.float32))
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
output = roi_align(x, rois)
print(output)
expect = [[[[3.3333, 5.5000, 7.6667],
[16.3333, 18.5000, 20.6667],
[29.3333, 31.5000, 33.6667]]],
[[[4.5000, 6.3000, 8.1000],
[14.9000, 16.7000, 18.5000],
[25.7000, 27.5000, 29.3000]]]]
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)
# test case 4
pooled_height, pooled_width, spatial_scale, sample_num = 2, 2, 1.0, -1
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32))
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
output = roi_align(x, rois)
print(output)
expect = [[[[4.625, 0.],
[0., 0.]]]]
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)