roi align grad v1

This commit is contained in:
Jonathan Yan 2020-07-24 00:36:32 -04:00
parent 402378a6d9
commit ad40e00228
6 changed files with 278 additions and 24 deletions

View File

@ -16,13 +16,6 @@
#include "roi_align_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
inline __device__ T gpu_atomic_add(const T val, T *address);
template <>
inline __device__ float gpu_atomic_add(const float val, float *address) {
return atomicAdd(address, val);
}
template <typename T>
__device__ void bilinear_interpolate(const int height, const int width, T y, T x, int *x_low, int *y_low, int *x_high,
@ -201,11 +194,11 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
// Shift half point RIGHT for y / x, while previous scaled roi shift half point LEFT
const T y =
roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h);
const T y = roi_start_h + static_cast<T>(ph) * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x =
roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
const T x = roi_start_w + static_cast<T>(pw) * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
// bilinear interpolate by shifted y / x
// calculate bilinear interpolation
int x_low, y_low, x_high, y_high;
@ -217,12 +210,28 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
T g4 = top_diff_this_bin * w4 / count_points_in_grid_cell;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
gpu_atomic_add(static_cast<T>(g1), dx + offset + y_low * width + x_low);
gpu_atomic_add(static_cast<T>(g2), dx + offset + y_low * width + x_high);
gpu_atomic_add(static_cast<T>(g3), dx + offset + y_high * width + x_low);
gpu_atomic_add(static_cast<T>(g4), dx + offset + y_high * width + x_high);
atomicAdd(dx + offset + y_low * width + x_low, static_cast<T>(g1));
atomicAdd(dx + offset + y_low * width + x_high, static_cast<T>(g2));
atomicAdd(dx + offset + y_high * width + x_low, static_cast<T>(g3));
atomicAdd(dx + offset + y_high * width + x_high, static_cast<T>(g4));
}
}
}
}
}
template <typename T>
void ROIAlignGrad(const T *dy, const T *roi_boxes, int roi_rows, int roi_cols, T *dx, const T spatial_scale,
const int sample_num, int roi_end_mode, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, cudaStream_t cuda_stream) {
size_t size = roi_rows * channels * pooled_height * pooled_width;
ROIAlignGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
size, dy, roi_boxes, roi_cols, dx, spatial_scale, sample_num, roi_end_mode, channels, height, width, pooled_height,
pooled_width);
return;
}
template void ROIAlignGrad<float>(const float *dy, const float *roi_boxes, int roi_rows, int roi_cols, float *dx,
const float spatial_scale, const int sample_num, int roi_end_mode, const int channels,
const int height, const int width, const int pooled_height, const int pooled_width,
cudaStream_t cuda_stream);

View File

@ -21,4 +21,9 @@ void ROIAlign(const T *x, const T *roi_boxes, int roi_rows, int roi_cols, T *out
const int sample_num, int roi_end_mode, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, cudaStream_t cuda_stream);
template <typename T>
void ROIAlignGrad(const T *dy, const T *roi_boxes, int roi_rows, int roi_cols, T *dx, const T spatial_scale,
const int sample_num, int roi_end_mode, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_

View File

@ -49,14 +49,14 @@ class ROIAlignGpuFwdKernel : public GpuKernel {
// Get the number of input args
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but RioAlign needs 2 input.";
MS_LOG(ERROR) << "Input number is " << input_num << ", but ROIAlign needs 2 input.";
return false;
}
// Get the number of output args
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but RioAlign needs 1 output.";
MS_LOG(ERROR) << "Output number is " << output_num << ", but ROIAlign needs 1 output.";
return false;
}
@ -65,17 +65,18 @@ class ROIAlignGpuFwdKernel : public GpuKernel {
auto rois_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto x_shape_size = x_shape.size();
if (x_shape_size < 2) {
MS_LOG(ERROR) << "x shape szie is " << x_shape_size << ", but at lease 2D.";
if (x_shape_size != 4) {
MS_LOG(ERROR) << "x shape size is " << x_shape_size << ", but shoud be 4.";
return false;
}
// Get channels, height & width
channels_ = x_shape_size >= 3 ? x_shape[x_shape_size - 3] : 1;
height_ = x_shape[x_shape_size - 2];
width_ = x_shape[x_shape_size - 1];
x_shape_ = {channels_, height_, width_};
x_size_ = channels_ * height_ * width_ * sizeof(T);
int batch_N = x_shape[0];
channels_ = x_shape[1];
height_ = x_shape[2];
width_ = x_shape[3];
x_shape_ = {batch_N, channels_, height_, width_};
x_size_ = batch_N * channels_ * height_ * width_ * sizeof(T);
// Get rois rows and cols
roi_rows_ = rois_shape[0];

View File

@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/roi_align_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
ROIAlignGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ROIAlignGradGpuFwdKernel, float)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,141 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GRAD_GPU_KERNEL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GRAD_GPU_KERNEL_H
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class ROIAlignGradGpuFwdKernel : public GpuKernel {
public:
ROIAlignGradGpuFwdKernel() : dy_size_(0), rois_size_(0), output_size_(0) {}
~ROIAlignGradGpuFwdKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
const T *dy = GetDeviceAddress<T>(inputs, 0);
const T *rois = GetDeviceAddress<T>(inputs, 1);
T *dx = GetDeviceAddress<T>(outputs, 0);
ROIAlignGrad(dy, rois, roi_rows_, roi_cols_, dx, spatial_scale_, sample_num_, roi_end_mode_, channels_, height_,
width_, pooled_height_, pooled_width_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
// Get the number of input args
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but ROIAlignGrad needs 2 input.";
return false;
}
// Get the number of output args
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but ROIAlignGrad needs 1 output.";
return false;
}
// Get the input shapes
auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto rois_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto dy_shape_size = dy_shape.size();
if (dy_shape_size != 4) {
MS_LOG(ERROR) << "dy shape size is " << dy_shape_size << ", but shoud be 4.";
return false;
}
// Parse y diff
dy_shape_ = {static_cast<int>(dy_shape[0]), static_cast<int>(dy_shape[1]), static_cast<int>(dy_shape[2]),
static_cast<int>(dy_shape[3])};
dy_size_ = dy_shape_[0] * dy_shape_[1] * dy_shape_[2] * dy_shape_[3] * sizeof(T);
// Get rois rows and cols
roi_rows_ = rois_shape[0];
roi_cols_ = rois_shape[1];
rois_shape_ = {roi_rows_, roi_cols_};
rois_size_ = roi_rows_ * roi_cols_ * sizeof(T);
// Get primitive args
xdiff_shape_ = GetAttr<std::vector<int>>(kernel_node, "xdiff_shape");
pooled_height_ = GetAttr<int>(kernel_node, "pooled_height");
pooled_width_ = GetAttr<int>(kernel_node, "pooled_width");
spatial_scale_ = static_cast<T>(GetAttr<float>(kernel_node, "spatial_scale"));
sample_num_ = GetAttr<int>(kernel_node, "sample_num");
roi_end_mode_ = 1;
// Get channels, height & width
channels_ = xdiff_shape_[1];
height_ = xdiff_shape_[2];
width_ = xdiff_shape_[3];
// Get output_shape
output_shape_ = {roi_rows_, channels_, height_, width_};
output_size_ = roi_rows_ * channels_ * height_ * width_ * sizeof(T);
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(dy_size_);
input_size_list_.push_back(rois_size_);
output_size_list_.push_back(output_size_);
}
private:
std::vector<int> xdiff_shape_;
int pooled_height_;
int pooled_width_;
T spatial_scale_;
int sample_num_;
int roi_end_mode_;
int roi_rows_;
int roi_cols_;
int channels_;
int height_;
int width_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
std::vector<int> dy_shape_;
std::vector<int> rois_shape_;
std::vector<int> output_shape_;
size_t dy_size_;
size_t rois_size_;
size_t output_size_;
}; // namespace kernel
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GRAD_GPU_KERNEL_H

View File

@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetROIAlignGrad(nn.Cell):
def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num):
super(NetROIAlignGrad, self).__init__()
self.roiAlignGrad = G.ROIAlignGrad(
xdiff_shape,
pooled_height,
pooled_width,
spatial_scale,
sample_num)
def construct(self, dy, rois):
return self.roiAlignGrad(dy, rois)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_roi_align_grad():
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32))
dy = Tensor(np.array([[[
[.1, .2, .3],
[.1, .2, .3],
[.1, .2, .3]
]]], np.float32))
xdiff_shape = (1, 1, 6, 6)
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
roi_align_grad = NetROIAlignGrad(
xdiff_shape,
pooled_height,
pooled_width,
spatial_scale,
sample_num)
output = roi_align_grad(dy, rois)
print(output)
expect = ([[[[0.0563, 0.0563, 0.0750, 0.0938, 0.1125, 0.0563],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
[0.0188, 0.0188, 0.0250, 0.0312, 0.0375, 0.0188]]]])
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)