[OP] optimize the performance of max pool grad grad

This commit is contained in:
yangruoqi713 2022-06-30 19:56:58 +08:00
parent 914ebfeb1b
commit 247a0d3d86
12 changed files with 235 additions and 182 deletions
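This commit fuses the previous two-pass implementation of MaxPoolGradGrad into a single kernel on both CPU and GPU. Before, the op ran MaxPoolWithArgmax into two workspace buffers (the pooled values plus int32 argmax indices) and then gathered the incoming gradient through those indices batch by batch; now a dedicated MaxPoolGradGrad kernel finds the argmax of each pooling window and reads the gradient at that position directly, so both workspace buffers and the second pass disappear. Below is a minimal standalone C sketch of the fused inner loop, simplified from the 2D NNACL kernel in this diff (a single plane with no padding; the real kernel walks a [start, end) range of NCHW output positions so the work can be chunked across threads):

/* Fused max-pool grad-grad for one (H, W) plane: argmax and gather in one pass. */
void MaxPoolGradGradPlane(const float *input, const float *grad, float *output,
                          int in_h, int in_w, int win_h, int win_w,
                          int stride_h, int stride_w, int out_h, int out_w) {
  for (int oh = 0; oh < out_h; ++oh) {
    for (int ow = 0; ow < out_w; ++ow) {
      int h_start = oh * stride_h;
      int w_start = ow * stride_w;
      int h_end = (h_start + win_h < in_h) ? (h_start + win_h) : in_h;
      int w_end = (w_start + win_w < in_w) ? (w_start + win_w) : in_w;
      int max_idx = h_start * in_w + w_start;
      float max_data = input[max_idx];
      for (int h = h_start; h < h_end; ++h) {
        for (int w = w_start; w < w_end; ++w) {
          int idx = h * in_w + w;
          if (input[idx] > max_data) {
            max_idx = idx;
            max_data = input[idx];
          }
        }
      }
      /* Instead of storing max_idx into an index workspace and running a
       * separate Gather pass, read the incoming gradient at the argmax now. */
      output[oh * out_w + ow] = grad[max_idx];
    }
  }
}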

View File

@@ -21,15 +21,13 @@
#include "utils/ms_utils.h"
#include "utils/profile.h"
#include "mindspore/ccsrc/kernel/common_utils.h"
#include "nnacl/fp32/maxpool_with_argmax.h"
#include "nnacl/base/gather_base.h"
#include "nnacl/fp32_grad/maxpool_grad_grad.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kMaxPoolGradGradInputsNum = 3;
constexpr size_t kMaxPoolGradGradOutputsNum = 1;
constexpr size_t kMaxPoolGradGradWorkSpaceNum = 2;
constexpr size_t kGradIndex = 2;
constexpr size_t kPadHalf = 2;
@@ -125,12 +123,6 @@ void MaxPoolGradGradCpuKernelMod::CalPad() {
}
}
void MaxPoolGradGradCpuKernelMod::InitWorkspace() {
workspace_size_list_.push_back(input_size_list_[1]);
output_elements_ = std::accumulate(out_shapes_.begin(), out_shapes_.end(), 1, std::multiplies<size_t>());
workspace_size_list_.push_back(sizeof(int32_t) * output_elements_);
}
int MaxPoolGradGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
@@ -152,6 +144,7 @@ int MaxPoolGradGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
param_->output_channel_ = LongToInt(out_shapes_[kDim1]);
param_->output_h_ = LongToInt(out_shapes_[height_index_]);
param_->output_w_ = LongToInt(out_shapes_[width_index_]);
output_elements_ = std::accumulate(out_shapes_.begin(), out_shapes_.end(), 1, std::multiplies<size_t>());
if (dim_ == kMaxPool3DGradGradDim) {
reinterpret_cast<Pooling3DParameter *>(param_)->input_d_ = LongToInt(in_shapes_[depth_index_]);
@@ -162,7 +155,6 @@ int MaxPoolGradGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
CheckInputValid();
CalPad();
InitWorkspace();
return KRET_OK;
}
@@ -171,18 +163,17 @@ bool MaxPoolGradGradCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMaxPoolGradGradInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMaxPoolGradGradOutputsNum, kernel_name_);
CHECK_KERNEL_WORKSPACE_SIZE(workspace.size(), kMaxPoolGradGradWorkSpaceNum, kernel_name_);
auto *input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto *output_addr = reinterpret_cast<float *>(workspace[0]->addr);
auto *index_addr = reinterpret_cast<int32_t *>(workspace[1]->addr);
auto *grad_addr = reinterpret_cast<float *>(inputs[kGradIndex]->addr);
auto *dx_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto task = [input_addr, output_addr, index_addr, this](size_t start, size_t end) {
auto task = [input_addr, grad_addr, dx_addr, this](size_t start, size_t end) {
auto ret = static_cast<int>(NNACL_OK);
if (dim_ == kMaxPool2DGradGradDim) {
ret = MaxPoolWithArgmax(input_addr, output_addr, index_addr, start, end, param_);
ret = MaxPoolGradGrad(input_addr, grad_addr, dx_addr, start, end, param_);
} else if (dim_ == kMaxPool3DGradGradDim) {
ret = MaxPool3DWithArgmax(input_addr, output_addr, index_addr, start, end,
reinterpret_cast<Pooling3DParameter *>(param_));
ret =
MaxPool3DGradGrad(input_addr, grad_addr, dx_addr, start, end, reinterpret_cast<Pooling3DParameter *>(param_));
}
if (ret != static_cast<int>(NNACL_OK)) {
MS_LOG(ERROR) << "For '" << kernel_name_
@@ -192,23 +183,6 @@ bool MaxPoolGradGradCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
return true;
};
ParallelLaunchAutoSearch(task, output_elements_, this, &parallel_search_info_, pool_);
int64_t outer_size = 1;
int64_t inner_size = 1;
int64_t indices_element_size = SizeToLong(output_batch_stride_);
int64_t limit = SizeToLong(input_batch_stride_);
size_t byte_inner_size = inner_size * sizeof(float);
size_t byte_out_stride = indices_element_size * byte_inner_size;
for (int b = 0; b < param_->input_batch_; b++) {
auto *index_t = index_addr + b * output_batch_stride_;
auto *grad_t = reinterpret_cast<float *>(inputs[kGradIndex]->addr) + b * input_batch_stride_;
auto *dx_t = reinterpret_cast<float *>(outputs[0]->addr) + b * output_batch_stride_;
int ret = Gather(grad_t, outer_size, byte_inner_size, limit, index_t, indices_element_size, dx_t, byte_out_stride);
if (ret != 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', error_code[" << ret << "]";
}
}
return true;
}
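Dropping InitWorkspace is not just cleanup: the old path allocated input_size_list_[1] bytes for the pooled values plus sizeof(int32_t) * output_elements_ for the indices, roughly 8 bytes per output element, and then made a second pass over that memory. As a hypothetical illustration, an NCHW output of 32 x 64 x 56 x 56 has 6,422,528 elements, so each launch previously consumed about 51.4 MB (49 MiB) of workspace that the fused kernel no longer needs, in addition to skipping the per-batch Gather calls over it.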

View File

@@ -60,7 +60,6 @@ class MaxPoolGradGradCpuKernelMod : public NativeCpuKernelMod {
private:
void CheckInputValid();
void CalPad();
void InitWorkspace();
std::vector<int64_t> kernels_;
std::vector<int64_t> strides_;

View File

@@ -9,14 +9,14 @@
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/maxpool_with_argmax.h"
#include "nnacl/fp32_grad/maxpool_grad_grad.h"
int MaxPoolWithArgmax(const float *input, float *output, int *index, size_t start, size_t end,
PoolingParameter *param) {
int MaxPoolGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end,
PoolingParameter *param) {
const int channel = param->input_channel_;
const int input_height = param->input_h_;
const int input_width = param->input_w_;
@@ -53,13 +53,13 @@ int MaxPoolWithArgmax(const float *input, float *output, int *index, size_t star
h_start = MSMAX(h_start, 0);
w_start = MSMAX(w_start, 0);
int input_start = pos_n * channel * input_height * input_width;
int max_idx = pos_c * input_height * input_width + h_start * input_width + w_start;
int input_start = pos_n * channel * input_height * input_width + pos_c * input_height * input_width;
int max_idx = h_start * input_width + w_start;
float max_data = input[input_start + max_idx];
for (int h_cur = h_start; h_cur < h_end; ++h_cur) {
for (int w_cur = w_start; w_cur < w_end; ++w_cur) {
int input_idx = pos_c * input_height * input_width + h_cur * input_width + w_cur;
int input_idx = h_cur * input_width + w_cur;
float input_data = input[input_start + input_idx];
if (input_data > max_data) {
max_idx = input_idx;
@@ -67,14 +67,13 @@
}
}
}
output[pos] = max_data;
index[pos] = max_idx;
output[pos] = grad[input_start + max_idx];
}
return NNACL_OK;
}
int MaxPool3DWithArgmax(const float *input, float *output, int *index, size_t start, size_t end,
Pooling3DParameter *param) {
int MaxPool3DGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end,
Pooling3DParameter *param) {
PoolingParameter *param_2d = (PoolingParameter *)(param);
const int channel = param_2d->input_channel_;
const int input_depth = param->input_d_;
@@ -124,16 +123,15 @@ int MaxPool3DWithArgmax(const float *input, float *output, int *index, size_t st
h_start = MSMAX(h_start, 0);
w_start = MSMAX(w_start, 0);
int input_start = pos_n * channel * input_depth * input_height * input_width;
int max_idx = pos_c * input_depth * input_height * input_width + d_start * input_height * input_width +
h_start * input_width + w_start;
int input_start =
pos_n * channel * input_depth * input_height * input_width + pos_c * input_depth * input_height * input_width;
int max_idx = d_start * input_height * input_width + h_start * input_width + w_start;
float max_data = input[input_start + max_idx];
for (int d_cur = d_start; d_cur < d_end; ++d_cur) {
for (int h_cur = h_start; h_cur < h_end; ++h_cur) {
for (int w_cur = w_start; w_cur < w_end; ++w_cur) {
int input_idx = pos_c * input_depth * input_height * input_width + d_cur * input_height * input_width +
h_cur * input_width + w_cur;
int input_idx = d_cur * input_height * input_width + h_cur * input_width + w_cur;
float input_data = input[input_start + input_idx];
if (input_data > max_data) {
max_idx = input_idx;
@@ -142,8 +140,7 @@ int MaxPool3DWithArgmax(const float *input, float *output, int *index, size_t st
}
}
}
output[pos] = max_data;
index[pos] = max_idx;
output[pos] = grad[input_start + max_idx];
}
return NNACL_OK;
}

View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_FP32_MAXPOOL_WITH_ARGMAX_H_
#define MINDSPORE_NNACL_FP32_MAXPOOL_WITH_ARGMAX_H_
#ifndef MINDSPORE_NNACL_FP32_GRAD_MAXPOOL_GRAD_GRAD_H_
#define MINDSPORE_NNACL_FP32_GRAD_MAXPOOL_GRAD_GRAD_H_
#include "nnacl/op_base.h"
#include "nnacl/pooling_parameter.h"
@@ -24,12 +24,13 @@
#ifdef __cplusplus
extern "C" {
#endif
int MaxPoolWithArgmax(const float *input, float *output, int *index, size_t start, size_t end, PoolingParameter *param);
int MaxPoolGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end,
PoolingParameter *param);
int MaxPool3DWithArgmax(const float *input, float *output, int *index, size_t start, size_t end,
Pooling3DParameter *param);
int MaxPool3DGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end,
Pooling3DParameter *param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_FP32_MAXPOOL_WITH_ARGMAX_H_
#endif // MINDSPORE_NNACL_FP32_GRAD_MAXPOOL_GRAD_GRAD_H_
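For orientation, a hedged sketch of driving the new 2D entry point: the [start, end) pair selects a contiguous range of output positions, which is how the CPU kernel mod chunks the work for ParallelLaunchAutoSearch. Only the PoolingParameter fields that appear in this diff are certain; the window and stride field names below are assumptions about nnacl's pooling_parameter.h.

/* Hypothetical driver: one 4x4 plane, 2x2 window, stride 2, no padding. */
float input[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
float grad[16] = {0};   /* incoming gradient, same shape as input */
float dx[4] = {0};      /* result: grad at each window's argmax */
PoolingParameter param = {0};
param.input_batch_ = 1;
param.input_channel_ = 1;
param.input_h_ = 4;
param.input_w_ = 4;
param.output_h_ = 2;
param.output_w_ = 2;
param.window_h_ = 2;    /* assumed field name */
param.window_w_ = 2;    /* assumed field name */
param.stride_h_ = 2;    /* assumed field name */
param.stride_w_ = 2;    /* assumed field name */
/* Process all four output positions in one chunk; a parallel caller would
 * split [0, 4) into per-thread [start, end) ranges instead. */
int ret = MaxPoolGradGrad(input, grad, dx, 0, 4, &param);
/* ret == NNACL_OK on success; dx[i] == grad[argmax of window i]. */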

View File

@@ -0,0 +1,152 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxpool_grad_grad_impl.cuh"
#include "include/cuda_fp16.h"
template <typename T>
__global__ void MaxPoolGradGrad(const T *input, const T *grad, const int n, const int c, const int h, const int w,
const int windowHeight, const int windowWidth, const int strideHeight,
const int strideWidth, const int padTop, const int padLeft, const int outputHeight,
const int outputWidth, const int outputNCHW, const int outputCHW, const int outputHW,
T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outputNCHW); pos += blockDim.x * gridDim.x) {
const int posn = pos / outputCHW;
const int posc = pos / outputHW % c;
const int posh = pos / outputWidth % outputHeight;
const int posw = pos % outputWidth;
int hstart = posh * strideHeight - padTop;
int wstart = posw * strideWidth - padLeft;
const int hend = min(hstart + windowHeight, h);
const int wend = min(wstart + windowWidth, w);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
int inputStart = posn * c * h * w + posc * h * w;
int maxIdx = hstart * w + wstart;
T maxData = input[inputStart + maxIdx];
for (int hcur = hstart; hcur < hend; ++hcur) {
for (int wcur = wstart; wcur < wend; ++wcur) {
int inputIdx = hcur * w + wcur;
T inputData = input[inputStart + inputIdx];
if (inputData > maxData) {
maxIdx = inputIdx;
maxData = inputData;
}
}
}
output[pos] = grad[inputStart + maxIdx];
}
}
template <typename T>
void CalMaxPoolGradGrad(const T *input, const T *grad, const int n, const int c, const int h, const int w,
const int windowHeight, const int windowWidth, const int strideHeight, const int strideWidth,
const int padTop, const int padLeft, const int outputHeight, const int outputWidth, T *output,
const uint32_t &device_id, cudaStream_t cuda_stream) {
const int outputNCHW = n * c * outputHeight * outputWidth;
const int outputCHW = c * outputHeight * outputWidth;
const int outputHW = outputHeight * outputWidth;
MaxPoolGradGrad<<<CUDA_BLOCKS(device_id, n * c * outputHeight * outputWidth), CUDA_THREADS(device_id), 0,
cuda_stream>>>(input, grad, n, c, h, w, windowHeight, windowWidth, strideHeight, strideWidth,
padTop, padLeft, outputHeight, outputWidth, outputNCHW, outputCHW, outputHW, output);
}
template CUDA_LIB_EXPORT void CalMaxPoolGradGrad<float>(const float *input, const float *grad, const int n, const int c,
const int h, const int w, const int windowHeight,
const int windowWidth, const int strideHeight,
const int strideWidth, const int padTop, const int padLeft,
const int outputHeight, const int outputWidth, float *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalMaxPoolGradGrad<half>(const half *input, const half *grad, const int n, const int c,
const int h, const int w, const int windowHeight,
const int windowWidth, const int strideHeight,
const int strideWidth, const int padTop, const int padLeft,
const int outputHeight, const int outputWidth, half *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template <typename T>
__global__ void MaxPool3DGradGrad(const T *input, const T *grad, const int n, const int c, const int d, const int h,
const int w, const int windowDepth, const int windowHeight, const int windowWidth,
const int strideDepth, const int strideHeight, const int strideWidth,
const int padFront, const int padTop, const int padLeft, const int outputDepth,
const int outputHeight, const int outputWidth, T *output) {
const int outputNCDHW = n * c * outputDepth * outputHeight * outputWidth;
const int outputCDHW = c * outputDepth * outputHeight * outputWidth;
const int outputDHW = outputDepth * outputHeight * outputWidth;
const int outputHW = outputHeight * outputWidth;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outputNCDHW); pos += blockDim.x * gridDim.x) {
const int posn = pos / outputCDHW;
const int posc = pos / outputDHW % c;
const int posd = pos / outputHW % outputDepth;
const int posh = pos / outputWidth % outputHeight;
const int posw = pos % outputWidth;
int dstart = posd * strideDepth - padFront;
int hstart = posh * strideHeight - padTop;
int wstart = posw * strideWidth - padLeft;
const int dend = min(dstart + windowDepth, d);
const int hend = min(hstart + windowHeight, h);
const int wend = min(wstart + windowWidth, w);
dstart = max(dstart, 0);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
int inputStart = posn * c * d * h * w + posc * d * h * w;
int maxIdx = dstart * h * w + hstart * w + wstart;
T maxData = input[inputStart + maxIdx];
for (int dcur = dstart; dcur < dend; ++dcur) {
for (int hcur = hstart; hcur < hend; ++hcur) {
for (int wcur = wstart; wcur < wend; ++wcur) {
int inputIdx = dcur * h * w + hcur * w + wcur;
T inputData = input[inputStart + inputIdx];
if (inputData > maxData) {
maxIdx = inputIdx;
maxData = inputData;
}
}
}
}
output[pos] = grad[inputStart + maxIdx];
}
}
template <typename T>
void CalMaxPool3DGradGrad(const T *input, const T *grad, const int n, const int c, const int d, const int h,
const int w, const int windowDepth, const int windowHeight, const int windowWidth,
const int strideDepth, const int strideHeight, const int strideWidth, const int padFront,
const int padTop, const int padLeft, const int outputDepth, const int outputHeight,
const int outputWidth, T *output, const uint32_t &device_id, cudaStream_t cuda_stream) {
MaxPool3DGradGrad<<<CUDA_BLOCKS(device_id, n * c * outputDepth * outputHeight * outputWidth), CUDA_THREADS(device_id),
0, cuda_stream>>>(input, grad, n, c, d, h, w, windowDepth, windowHeight, windowWidth, strideDepth,
strideHeight, strideWidth, padFront, padTop, padLeft, outputDepth, outputHeight,
outputWidth, output);
}
template CUDA_LIB_EXPORT void CalMaxPool3DGradGrad<float>(
const float *input, const float *grad, const int n, const int c, const int d, const int h, const int w,
const int windowDepth, const int windowHeight, const int windowWidth, const int strideDepth, const int strideHeight,
const int strideWidth, const int padFront, const int padTop, const int padLeft, const int outputDepth,
const int outputHeight, const int outputWidth, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalMaxPool3DGradGrad<half>(
const half *input, const half *grad, const int n, const int c, const int d, const int h, const int w,
const int windowDepth, const int windowHeight, const int windowWidth, const int strideDepth, const int strideHeight,
const int strideWidth, const int padFront, const int padTop, const int padLeft, const int outputDepth,
const int outputHeight, const int outputWidth, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);

View File

@@ -0,0 +1,36 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_GRAD_GRAD_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_GRAD_GRAD_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T>
CUDA_LIB_EXPORT void CalMaxPoolGradGrad(const T *input, const T *grad, const int n, const int c, const int h,
const int w, const int windowHeight, const int windowWidth,
const int strideHeight, const int strideWidth, const int padTop,
const int padLeft, const int outputHeight, const int outputWidth,
T *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void CalMaxPool3DGradGrad(const T *input, const T *grad, const int n, const int c, const int d,
const int h, const int w, const int windowDepth, const int windowHeight,
const int windowWidth, const int strideDepth, const int strideHeight,
const int strideWidth, const int padFront, const int padTop,
const int padLeft, const int outputDepth, const int outputHeight,
const int outputWidth, T *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_GRAD_GRAD_IMPL_CUH_

View File

@@ -18,8 +18,8 @@
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_WITH_ARGMAX_GRAD_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T, typename S>
CUDA_LIB_EXPORT void CalMaxPoolWithArgmaxGrad(const T* dy, const S* index, const int n, const int c, const int xHeight,
const int xWidth, const int dyHeight, const int dyWidth, T* dx,
CUDA_LIB_EXPORT void CalMaxPoolWithArgmaxGrad(const T *dy, const S *index, const int n, const int c, const int xHeight,
const int xWidth, const int dyHeight, const int dyWidth, T *dx,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_WITH_ARGMAX_GRAD_IMPL_CUH_

View File

@@ -75,75 +75,3 @@ template CUDA_LIB_EXPORT void CalMaxPoolWithArgmax<half, int>(
const half *input, const int n, const int c, const int h, const int w, const int windowHeight, const int windowWidth,
const int strideHeight, const int strideWidth, const int padTop, const int padLeft, const int outputHeight,
const int outputWidth, half *output, int *index, const uint32_t &device_id, cudaStream_t cuda_stream);
template <typename T, typename S>
__global__ void MaxPool3DWithArgmax(const T *input, const int n, const int c, const int d, const int h, const int w,
const int windowDepth, const int windowHeight, const int windowWidth,
const int strideDepth, const int strideHeight, const int strideWidth,
const int padFront, const int padTop, const int padLeft, const int outputDepth,
const int outputHeight, const int outputWidth, T *output, S *index) {
const int outputNCDHW = n * c * outputDepth * outputHeight * outputWidth;
const int outputCDHW = c * outputDepth * outputHeight * outputWidth;
const int outputDHW = outputDepth * outputHeight * outputWidth;
const int outputHW = outputHeight * outputWidth;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outputNCDHW); pos += blockDim.x * gridDim.x) {
const int posn = pos / outputCDHW;
const int posc = pos / outputDHW % c;
const int posd = pos / outputHW % outputDepth;
const int posh = pos / outputWidth % outputHeight;
const int posw = pos % outputWidth;
int dstart = posd * strideDepth - padFront;
int hstart = posh * strideHeight - padTop;
int wstart = posw * strideWidth - padLeft;
const int dend = min(dstart + windowDepth, d);
const int hend = min(hstart + windowHeight, h);
const int wend = min(wstart + windowWidth, w);
dstart = max(dstart, 0);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
S inputStart = posn * c * d * h * w;
S maxIdx = posc * d * h * w + dstart * h * w + hstart * w + wstart;
T maxData = input[inputStart + maxIdx];
for (int dcur = dstart; dcur < dend; ++dcur) {
for (int hcur = hstart; hcur < hend; ++hcur) {
for (int wcur = wstart; wcur < wend; ++wcur) {
S inputIdx = posc * d * h * w + dcur * h * w + hcur * w + wcur;
T inputData = input[inputStart + inputIdx];
if (inputData > maxData) {
maxIdx = inputIdx;
maxData = inputData;
}
}
}
}
output[pos] = maxData;
index[pos] = maxIdx;
}
}
template <typename T, typename S>
void CalMaxPool3DWithArgmax(const T *input, const int n, const int c, const int d, const int h, const int w,
const int windowDepth, const int windowHeight, const int windowWidth, const int strideDepth,
const int strideHeight, const int strideWidth, const int padFront, const int padTop,
const int padLeft, const int outputDepth, const int outputHeight, const int outputWidth,
T *output, S *index, const uint32_t &device_id, cudaStream_t cuda_stream) {
MaxPool3DWithArgmax<<<CUDA_BLOCKS(device_id, n * c * d * outputHeight * outputWidth), CUDA_THREADS(device_id), 0,
cuda_stream>>>(input, n, c, d, h, w, windowDepth, windowHeight, windowWidth, strideDepth,
strideHeight, strideWidth, padFront, padTop, padLeft, outputDepth, outputHeight,
outputWidth, output, index);
}
template CUDA_LIB_EXPORT void CalMaxPool3DWithArgmax<float, int>(
const float *input, const int n, const int c, const int d, const int h, const int w, const int windowDepth,
const int windowHeight, const int windowWidth, const int strideDepth, const int strideHeight, const int strideWidth,
const int padFront, const int padTop, const int padLeft, const int outputDepth, const int outputHeight,
const int outputWidth, float *output, int *index, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalMaxPool3DWithArgmax<half, int>(
const half *input, const int n, const int c, const int d, const int h, const int w, const int windowDepth,
const int windowHeight, const int windowWidth, const int strideDepth, const int strideHeight, const int strideWidth,
const int padFront, const int padTop, const int padLeft, const int outputDepth, const int outputHeight,
const int outputWidth, half *output, int *index, const uint32_t &device_id, cudaStream_t cuda_stream);

View File

@@ -24,13 +24,4 @@ CUDA_LIB_EXPORT void CalMaxPoolWithArgmax(const T *input, const int n, const int
const int outputHeight, const int outputWidth, T *output, S *index,
const uint32_t &device_id, cudaStream_t cuda_stream);
template <typename T, typename S>
CUDA_LIB_EXPORT void CalMaxPool3DWithArgmax(const T *input, const int n, const int c, const int d, const int h,
const int w, const int windowDepth, const int windowHeight,
const int windowWidth, const int strideDepth, const int strideHeight,
const int strideWidth, const int padFront, const int padTop,
const int padLeft, const int outputDepth, const int outputHeight,
const int outputWidth, T *output, S *index, const uint32_t &device_id,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_WITH_ARGMAX_IMPL_CUH_

View File

@@ -20,15 +20,13 @@
#include "mindspore/core/ops/grad/max_pool_grad_grad.h"
#include "abstract/utils.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxpool_with_argmax_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/gather.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxpool_grad_grad_impl.cuh"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kMaxPoolGradGradInputsNum = 3;
constexpr size_t kMaxPoolGradGradOutputsNum = 1;
constexpr size_t kMaxPoolGradGradWorkSpaceNum = 2;
constexpr size_t kGradIndex = 2;
constexpr size_t kPadHalf = 2;
} // namespace
@@ -131,49 +129,28 @@ int MaxPoolGradGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
output_batch_stride_ = std::accumulate(out_shapes_.begin() + 1, out_shapes_.end(), 1, std::multiplies<size_t>());
CalPad();
workspace_size_list_.clear();
workspace_size_list_.push_back(input_size_list_[1]);
auto output_elements = std::accumulate(out_shapes_.begin(), out_shapes_.end(), 1, std::multiplies<size_t>());
workspace_size_list_.push_back(sizeof(int32_t) * output_elements);
return KRET_OK;
}
template <typename T>
bool MaxPoolGradGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *input_addr = GetDeviceAddress<T>(inputs, kIndex0);
T *output_addr = GetDeviceAddress<T>(workspace, kIndex0);
auto *index_addr = GetDeviceAddress<int32_t>(workspace, kIndex1);
T *grad_addr = GetDeviceAddress<T>(inputs, kGradIndex);
T *output_addr = GetDeviceAddress<T>(outputs, kIndex0);
if (dim_ == kMaxPool2DGradGradDim) {
CalMaxPoolWithArgmax<T, int32_t>(input_addr, batch_, channel_, input_height_, input_width_, window_height_,
window_width_, stride_height_, stride_width_, pad_top_, pad_left_, output_height_,
output_width_, output_addr, index_addr, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
CalMaxPoolGradGrad<T>(input_addr, grad_addr, batch_, channel_, input_height_, input_width_, window_height_,
window_width_, stride_height_, stride_width_, pad_top_, pad_left_, output_height_,
output_width_, output_addr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
} else if (dim_ == kMaxPool3DGradGradDim) {
CalMaxPool3DWithArgmax<T, int32_t>(
input_addr, batch_, channel_, input_depth_, input_height_, input_width_, window_depth_, window_height_,
window_width_, stride_depth_, stride_height_, stride_width_, pad_front_, pad_top_, pad_left_, output_depth_,
output_height_, output_width_, output_addr, index_addr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
CalMaxPool3DGradGrad<T>(input_addr, grad_addr, batch_, channel_, input_depth_, input_height_, input_width_,
window_depth_, window_height_, window_width_, stride_depth_, stride_height_, stride_width_,
pad_front_, pad_top_, pad_left_, output_depth_, output_height_, output_width_, output_addr,
device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
} else {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', only 2D and 3D max pooling are supported.";
return false;
}
T *grad_addr = GetDeviceAddress<T>(inputs, kIndex2);
T *dx = GetDeviceAddress<T>(outputs, kIndex0);
size_t dim_before_axis = 1;
size_t dim_at_axis_input = input_batch_stride_;
size_t dim_at_axis_output = output_batch_stride_;
size_t dim_after_axis = 1;
for (int b = 0; b < batch_; b++) {
int32_t *index_t = index_addr + b * output_batch_stride_;
T *grad_t = grad_addr + b * input_batch_stride_;
T *dx_t = dx + b * output_batch_stride_;
Gather<T, int32_t>(grad_t, index_t, dx_t, dim_before_axis, dim_at_axis_input, dim_at_axis_output, dim_after_axis,
reinterpret_cast<cudaStream_t>(cuda_stream_), device_id_);
}
return true;
}
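The GPU LaunchKernel benefits twice from the fusion: the old code issued one CalMaxPoolWithArgmax (or CalMaxPool3DWithArgmax) launch plus batch_ separate Gather launches in the loop over b, while the new code issues a single CalMaxPoolGradGrad or CalMaxPool3DGradGrad launch, so kernel-launch overhead no longer scales with batch size and the workspace round trip through device memory is gone.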

View File

@@ -35,7 +35,7 @@ class MaxPoolGradGradGpuKernelMod : public NativeGpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, workspace, outputs);
return kernel_func_(this, inputs, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@@ -48,11 +48,9 @@ class MaxPoolGradGradGpuKernelMod : public NativeGpuKernelMod {
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using MaxPoolGradGradFunc =
std::function<bool(MaxPoolGradGradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using MaxPoolGradGradFunc = std::function<bool(MaxPoolGradGradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
void *cuda_stream_{nullptr};
MaxPoolGradGradFunc kernel_func_{};

View File

@@ -21,7 +21,7 @@ from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetPoolGradGrad(nn.Cell):