rewrite gpu implementation of op maxpool_with_argmax and maxpool_with_argmax_grad

This commit is contained in:
zhouyuanshen 2020-12-22 09:23:32 +08:00
parent 4477b97465
commit 09b68f9006
5 changed files with 109 additions and 234 deletions
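For orientation before the per-file diffs: the backward kernel changes from a gather (each dx element scans every overlapping pooling window in dy for a matching argmax) to a scatter (each dy element is atomically added to the dx position recorded in the saved argmax tensor). A minimal, self-contained CUDA sketch of the new idea, using plain atomicAdd on float rather than MindSpore's MsAtomicAdd, with illustrative names throughout:

// Sketch only: one thread per dy element scatters its gradient into dx.
// Assumes NCHW layout and that index[pos] is an offset within one sample's
// C*H*W block (the convention the rewritten forward kernel produces).
__global__ void ZeroFill(int size, float *out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
    out[i] = 0.0f;  // dx must start at zero because the scatter accumulates
  }
}

__global__ void MaxPoolGradScatter(const float *dy, const int *index,
                                   int dyNCHW, int dyCHW, int xCHW, float *dx) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < dyNCHW; pos += blockDim.x * gridDim.x) {
    const int n = pos / dyCHW;                       // sample this dy element belongs to
    atomicAdd(dx + n * xCHW + index[pos], dy[pos]);  // route the gradient to the argmax position
  }
}

With a stride smaller than the window, several dy elements can share one argmax position, so the adds must be atomic; dropping the gather also removes the need for the window, stride, and padding parameters, which is why they disappear from every signature below.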


@@ -18,130 +18,34 @@
#include "maxpool_with_argmax_grad_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
template <typename T, typename S>
__global__ void MaxPoolWithArgmaxGrad(const T* x,
const T* dy,
__global__ void MaxPoolWithArgmaxGrad(const T* dy,
const S* index,
const int n,
const int c,
const int xHeight,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
const int xNCHW,
const int xCHW,
const int xHW,
const int dyCHW,
const int dyHW,
const int dyNCHW,
T* dx) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
pos < (xNCHW);
pos += blockDim.x * gridDim.x) {
const int posn = pos / xCHW;
const int posc = pos / xHW % c;
const int posh = pos / xHeight % xHeight;
const int posw = pos % xWidth;
const S posIdx = posh*xWidth + posw;
int hstart = posh+padTop;
if (hstart < windowHeight) {
hstart = 0;
} else {
hstart = (hstart-windowHeight)/strideHeight + 1;
}
int wstart = posw+padLeft;
if (wstart < windowWidth) {
wstart = 0;
} else {
wstart = (wstart-windowWidth)/strideWidth + 1;
}
const int hend = min((posh+padTop)/strideHeight +1, dyHeight);
const int wend = min((posw+padLeft)/strideWidth +1, dyWidth);
const int channelStart = posn*dyCHW + posc*dyHW;
T dySum = static_cast<T>(0.0);
for (int hcur = hstart; hcur < hend; ++hcur) {
for (int wcur = wstart; wcur < wend; ++wcur) {
const int curIdx = hcur*dyWidth + wcur;
S maxIdx = index[channelStart+curIdx];
if (maxIdx == posIdx) {
dySum += dy[channelStart+curIdx];
}
}
}
dx[pos] = dySum;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (dyNCHW); pos += blockDim.x * gridDim.x) {
const S idx = index[pos];
const int posn = pos / dyCHW;
MsAtomicAdd(dx + posn*xCHW + static_cast<int>(idx), dy[pos]);
}
return;
}
template <>
__global__ void MaxPoolWithArgmaxGrad(const half* x,
const half* dy,
const int* index,
const int n,
const int c,
const int xHeight,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
const int xNCHW,
const int xCHW,
const int xHW,
const int dyCHW,
const int dyHW,
half* dx) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
pos < (xNCHW);
pos += blockDim.x * gridDim.x) {
const int posn = pos / xCHW;
const int posc = pos / xHW % c;
const int posh = pos / xHeight % xHeight;
const int posw = pos % xWidth;
const int posIdx = posh*xWidth + posw;
int hstart = posh+padTop;
if (hstart < windowHeight) {
hstart = 0;
} else {
hstart = (hstart-windowHeight)/strideHeight + 1;
template <typename T>
__global__ void InitOutput(const int size, T *output) {
T zero = 0;
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < size; id += blockDim.x * gridDim.x) {
output[id] = zero;
}
int wstart = posw+padLeft;
if (wstart < windowWidth) {
wstart = 0;
} else {
wstart = (wstart-windowWidth)/strideWidth + 1;
}
const int hend = min((posh+padTop)/strideHeight +1, dyHeight);
const int wend = min((posw+padLeft)/strideWidth +1, dyWidth);
const int channelStart = posn*dyCHW + posc*dyHW;
float dySum = 0.0f;
for (int hcur = hstart; hcur < hend; ++hcur) {
for (int wcur = wstart; wcur < wend; ++wcur) {
const int curIdx = hcur*dyWidth + wcur;
int maxIdx = index[channelStart+curIdx];
if (maxIdx == posIdx) {
dySum += __half2float(dy[channelStart+curIdx]);
}
}
}
dx[pos] = __float2half(dySum);
}
return;
return;
}
template <typename T, typename S>
void CalMaxPoolWithArgmaxGrad(const T* x,
const T* dy,
void CalMaxPoolWithArgmaxGrad(const T* dy,
const S* index,
const int n,
const int c,
@@ -149,12 +53,6 @@ void CalMaxPoolWithArgmaxGrad(const T* x,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
T* dx,
cudaStream_t cuda_stream) {
const int xHW = xHeight*xWidth;
@@ -162,36 +60,22 @@ void CalMaxPoolWithArgmaxGrad(const T* x,
const int xNCHW = n*xCHW;
const int dyHW = dyHeight*dyWidth;
const int dyCHW = c*dyHW;
MaxPoolWithArgmaxGrad<<<GET_BLOCKS(xNCHW),
const int dyNCHW = n*dyCHW;
InitOutput<<<GET_BLOCKS(xNCHW), GET_THREADS, 0, cuda_stream>>>(xNCHW, dx);
MaxPoolWithArgmaxGrad<<<GET_BLOCKS(dyNCHW),
GET_THREADS,
0,
cuda_stream>>>(
x,
dy,
index,
n,
c,
xHeight,
xWidth,
dyHeight,
dyWidth,
windowHeight,
windowWidth,
strideHeight,
strideWidth,
padTop,
padLeft,
xNCHW,
xCHW,
xHW,
dyCHW,
dyHW,
dyNCHW,
dx);
return;
}
template void CalMaxPoolWithArgmaxGrad<float, int>(const float* x,
const float* dy,
template void CalMaxPoolWithArgmaxGrad<float, int>(const float* dy,
const int* index,
const int n,
const int c,
@@ -199,16 +83,9 @@ template void CalMaxPoolWithArgmaxGrad<float, int>(const float* x,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
float* dx,
cudaStream_t cuda_stream);
template void CalMaxPoolWithArgmaxGrad<half, int>(const half* x,
const half* dy,
template void CalMaxPoolWithArgmaxGrad<half, int>(const half* dy,
const int* index,
const int n,
const int c,
@@ -216,11 +93,5 @@ template void CalMaxPoolWithArgmaxGrad<half, int>(const half* x,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
half* dx,
cudaStream_t cuda_stream);
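The window, stride, and pad arguments are gone from CalMaxPoolWithArgmaxGrad because the argmax indices already encode where each gradient goes; the wrapper only needs the shapes to size the zero-fill and the scatter. A hypothetical host-side driver mirroring the new launch order, reusing the two sketch kernels from the top of this page (the grid/block arithmetic is illustrative, not the GET_BLOCKS/GET_THREADS macros):

// Hypothetical driver: zero dx first, then scatter one thread per dy element.
void MaxPoolGradDriver(const float *dy, const int *index, int n, int c,
                       int xH, int xW, int dyH, int dyW, float *dx, cudaStream_t stream) {
  const int xCHW = c * xH * xW;
  const int xNCHW = n * xCHW;
  const int dyCHW = c * dyH * dyW;
  const int dyNCHW = n * dyCHW;
  const int threads = 256;  // illustrative block size
  ZeroFill<<<(xNCHW + threads - 1) / threads, threads, 0, stream>>>(xNCHW, dx);
  MaxPoolGradScatter<<<(dyNCHW + threads - 1) / threads, threads, 0, stream>>>(
      dy, index, dyNCHW, dyCHW, xCHW, dx);
}

Zeroing dx before the scatter matters: atomicAdd accumulates into whatever is already there, and input positions never selected by any window must come out exactly zero.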


@@ -17,9 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_
template <typename T, typename S>
void CalMaxPoolWithArgmaxGrad(const T* x, const T* dy, const S* index, const int n, const int c, const int xHeight,
const int xWidth, const int dyHeight, const int dyWidth, const int windowHeight,
const int windowWidth, const int strideHeight, const int strideWidth, const int padTop,
const int padLeft, T* dx, cudaStream_t cuda_stream);
void CalMaxPoolWithArgmaxGrad(const T* dy, const S* index, const int n, const int c, const int xHeight,
const int xWidth, const int dyHeight, const int dyWidth, T* dx, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_


@@ -42,7 +42,7 @@ __global__ void MaxPoolWithArgmax(const T* input,
pos += blockDim.x * gridDim.x) {
const int posn = pos / outputCHW;
const int posc = pos / outputHW % c;
const int posh = pos / outputHeight % outputHeight;
const int posh = pos / outputWidth % outputHeight;
const int posw = pos % outputWidth;
int hstart = posh * strideHeight - padTop;
int wstart = posw * strideWidth - padLeft;
@@ -50,12 +50,12 @@
const int wend = min(wstart + windowWidth, w);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
S inputStart = posn*c*h*w + posc*h*w;
S maxIdx = hstart*w + wstart;
S inputStart = posn*c*h*w;
S maxIdx = posc*h*w + hstart*w + wstart;
T maxData = input[inputStart+maxIdx];
for (int hcur = hstart; hcur < hend; ++hcur) {
for (int wcur = wstart; wcur < wend; ++wcur) {
S inputIdx = hcur*w + wcur;
S inputIdx = posc*h*w + hcur*w + wcur;
T inputData = input[inputStart+inputIdx];
if (inputData > maxData) {
maxIdx = inputIdx;
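Two things change in the forward kernel above: posh is now derived with pos / outputWidth (the old pos / outputHeight indexing was only correct when the output is square), and the stored argmax gains the channel term posc*h*w, so it becomes an offset within one sample's C*H*W block instead of within a single H*W feature map. The second change is what lets the backward scatter write to dx at n*C*H*W + index without recomputing the channel. A hypothetical pair of helpers contrasting the two conventions (names are illustrative, assuming NCHW layout):

// Hypothetical illustration of the two argmax index conventions (NCHW layout).
__host__ __device__ inline int OldArgmaxIdx(int h, int w, int W) {
  return h * W + w;              // relative to the start of feature map (n, c)
}
__host__ __device__ inline int NewArgmaxIdx(int c, int h, int w, int H, int W) {
  return c * H * W + h * W + w;  // relative to the start of sample n
}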


@@ -48,12 +48,10 @@ class MaxPoolWithArgmaxGradGpuKernel : public GpuKernel {
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
T *x_addr = GetDeviceAddress<T>(inputs, 0);
T *dy_addr = GetDeviceAddress<T>(inputs, 1);
S *index_addr = GetDeviceAddress<S>(inputs, 2);
T *dx_addr = GetDeviceAddress<T>(outputs, 0);
CalMaxPoolWithArgmaxGrad(x_addr, dy_addr, index_addr, n_, c_, x_height_, x_width_, dy_height_, dy_width_,
window_height_, window_width_, stride_height_, stride_width_, pad_top_, pad_left_, dx_addr,
CalMaxPoolWithArgmaxGrad(dy_addr, index_addr, n_, c_, x_height_, x_width_, dy_height_, dy_width_, dx_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@@ -95,57 +93,19 @@ class MaxPoolWithArgmaxGradGpuKernel : public GpuKernel {
x_width_ = SizeToInt(x_shape[3]);
dy_height_ = SizeToInt(dy_shape[2]);
dy_width_ = SizeToInt(dy_shape[3]);
std::vector<int> window;
std::vector<int64_t> window_me =
GetValue<std::vector<int64_t>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
(void)std::transform(window_me.begin(), window_me.end(), std::back_inserter(window),
[](const int64_t &value) { return static_cast<int>(value); });
window_height_ = window[1];
window_width_ = window[2];
std::vector<int> stride;
std::vector<int64_t> stride_me =
GetValue<std::vector<int64_t>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride),
[](const int64_t &value) { return static_cast<int>(value); });
stride_height_ = stride[1];
stride_width_ = stride[2];
pad_mode_ = GetValue<std::string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding"));
pad_top_ = 0;
pad_left_ = 0;
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
SetPad();
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(x_size_);
input_size_list_.push_back(dy_size_);
input_size_list_.push_back(index_size_);
output_size_list_.push_back(dx_size_);
}
private:
void SetPad() {
pad_height_ = std::max<int>(
0, (((x_height_ / stride_height_) * stride_height_ == x_height_ ? (x_height_ / stride_height_)
: (x_height_ / stride_height_) + 1) -
1) *
stride_height_ +
window_height_ - x_height_);
pad_width_ =
std::max<int>(0, (((x_width_ / stride_width_) * stride_width_ == x_width_ ? (x_width_ / stride_width_)
: (x_width_ / stride_width_) + 1) -
1) *
stride_width_ +
window_width_ - x_width_);
pad_top_ = pad_height_ / 2;
pad_left_ = pad_width_ / 2;
}
std::string pad_mode_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
@@ -156,14 +116,6 @@ class MaxPoolWithArgmaxGradGpuKernel : public GpuKernel {
int x_width_;
int dy_height_;
int dy_width_;
int window_height_;
int window_width_;
int pad_height_;
int pad_width_;
int pad_top_;
int pad_left_;
int stride_height_;
int stride_width_;
size_t x_size_;
size_t dy_size_;


@@ -16,29 +16,84 @@
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class Net_Pool(nn.Cell):
def __init__(self):
super(Net_Pool, self).__init__()
self.maxpool_fun = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="VALID")
def construct(self, x):
return self.maxpool_fun(x)
import mindspore.ops.operations as P
from mindspore import context, Tensor
from mindspore.nn import Cell
from mindspore.ops import composite as C
class Net_Pool2(nn.Cell):
def __init__(self):
super(Net_Pool2, self).__init__()
self.maxpool_fun = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="SAME")
class MaxPoolWithArgMax_Net(Cell):
def __init__(self, padding, ksize, strides):
super(MaxPoolWithArgMax_Net, self).__init__()
self.maxpool_with_argmax = P.MaxPoolWithArgmax(padding=padding, ksize=ksize, strides=strides)
def construct(self, x):
return self.maxpool_fun(x)
def construct(self, input_data):
output, argmax = self.maxpool_with_argmax(input_data)
return output, argmax
class Grad(Cell):
def __init__(self, network, argmax):
super(Grad, self).__init__()
self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network
self.sens = (Tensor(np.ones(argmax.shape).astype(np.float32)),
Tensor(np.ones(argmax.shape).astype(np.int32)))
def construct(self, input_data):
gout = self.grad(self.network)(input_data, self.sens)
return gout
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_forward_backward():
x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float32)
expect_output = np.array([[[[5, 6, 7, 7],
[9, 10, 11, 11],
[9, 10, 11, 11]],
[[17, 18, 19, 19],
[21, 22, 23, 23],
[21, 22, 23, 23]],
[[29, 30, 31, 31],
[33, 34, 35, 35],
[33, 34, 35, 35]]]]).astype(np.float32)
expect_argmax = np.array([[[[5, 6, 7, 7],
[9, 10, 11, 11],
[9, 10, 11, 11]],
[[17, 18, 19, 19],
[21, 22, 23, 23],
[21, 22, 23, 23]],
[[29, 30, 31, 31],
[33, 34, 35, 35],
[33, 34, 35, 35]]]]).astype(np.int32)
expect_dx = np.array([[[[0, 0, 0, 0],
[0, 1, 1, 2],
[0, 2, 2, 4]],
[[0, 0, 0, 0],
[0, 1, 1, 2],
[0, 2, 2, 4]],
[[0, 0, 0, 0],
[0, 1, 1, 2],
[0, 2, 2, 4]]]]).astype(np.float32)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = MaxPoolWithArgMax_Net(padding="SAME", ksize=2, strides=1)
output_tensor, argmax_tensor = net(Tensor(x))
assert output_tensor.shape == expect_output.shape
assert argmax_tensor.shape == expect_argmax.shape
error = np.ones(shape=expect_output.shape) * 1.0e-5
diff_output = output_tensor.asnumpy() - expect_output
assert np.all(diff_output < error)
net_grad = Grad(net, argmax_tensor)
dx = net_grad(Tensor(x))[0].asnumpy()
assert dx.shape == expect_dx.shape
diff = dx - expect_dx
assert np.all(diff < error)
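The expected dx can be reasoned out by hand: dy is all ones, so after the scatter each input position simply accumulates the number of 2x2, stride-1, SAME-padded windows whose maximum falls on it, and because the input values increase with the index, every window's maximum is its bottom-right in-bounds element. A small CPU sketch reproducing expect_dx for one channel (a hypothetical helper, not part of the test):

// Hypothetical CPU reference: counts, per input position, how many pooling
// windows pick it as argmax. Prints 0 0 0 0 / 0 1 1 2 / 0 2 2 4, matching
// expect_dx for each channel in the test above.
#include <algorithm>
#include <cstdio>

int main() {
  const int H = 3, W = 4, ksize = 2, stride = 1;
  const int outH = (H + stride - 1) / stride;  // SAME padding keeps ceil(H / stride) rows
  const int outW = (W + stride - 1) / stride;  // and ceil(W / stride) columns
  int dx[3][4] = {};
  for (int oh = 0; oh < outH; ++oh) {
    for (int ow = 0; ow < outW; ++ow) {
      const int h = std::min(oh * stride + ksize - 1, H - 1);  // bottom-right in-bounds
      const int w = std::min(ow * stride + ksize - 1, W - 1);  // element is the max
      dx[h][w] += 1;  // scatter a gradient of 1, as the GPU kernel does with atomicAdd
    }
  }
  for (int h = 0; h < H; ++h) {
    for (int w = 0; w < W; ++w) printf("%d ", dx[h][w]);
    printf("\n");
  }
  return 0;
}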
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@@ -73,8 +128,8 @@ def test_maxpool_with_argmax_2d():
]]]))
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
maxpool2d = Net_Pool()
maxpool2d2 = Net_Pool2()
maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
output2, index2 = maxpool2d2(x)
output, index = maxpool2d(x)
assert (output.asnumpy() == expect_result).all()
@@ -83,8 +138,8 @@ def test_maxpool_with_argmax_2d():
assert (index2.asnumpy() == expect__index_result2).all()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
maxpool2d = Net_Pool()
maxpool2d2 = Net_Pool2()
maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
output2, index2 = maxpool2d2(x)
output, index = maxpool2d(x)
assert (output.asnumpy() == expect_result).all()
@@ -126,8 +181,8 @@ def test_maxpool_with_argmax_2d_fp16():
]]]))
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
maxpool2d = Net_Pool()
maxpool2d2 = Net_Pool2()
maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
output2, index2 = maxpool2d2(x)
output, index = maxpool2d(x)
assert (output.asnumpy() == expect_result).all()
@@ -136,12 +191,11 @@ def test_maxpool_with_argmax_2d_fp16():
assert (index2.asnumpy() == expect__index_result2).all()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
maxpool2d = Net_Pool()
maxpool2d2 = Net_Pool2()
maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
output2, index2 = maxpool2d2(x)
output, index = maxpool2d(x)
assert (output.asnumpy() == expect_result).all()
assert (output2.asnumpy() == expect_result2).all()
assert (index.asnumpy() == expect_index_result).all()
assert (index2.asnumpy() == expect__index_result2).all()