[OP] optimize the performance of max pool grad grad
This commit is contained in:
parent 914ebfeb1b
commit 247a0d3d86
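Summary, as read from the diff below: the CPU and GPU MaxPoolGradGrad kernels previously ran in two passes. A MaxPoolWithArgmax kernel wrote the pooled values and argmax indices into two workspace buffers, and a per-batch Gather then picked the incoming gradient at those indices. This commit replaces both passes with fused MaxPoolGradGrad / MaxPool3DGradGrad kernels (nnacl on CPU, CUDA on GPU) that write grad[max_index] directly to the output, eliminating the workspace buffers, the Gather pass, and the extra trip through memory.

The following is a minimal 1-D sketch of that before/after structure, not MindSpore code: stride is assumed equal to the window, padding is omitted, and all names are illustrative.

/* Sketch: two-pass (argmax + gather) vs. fused max-pool grad-grad. */
#include <assert.h>
#include <stddef.h>

/* Old style: pass 1 records argmax indices in a workspace,
 * pass 2 gathers the incoming gradient at those indices. */
static void two_pass(const float *in, const float *grad, float *out, int *idx,
                     size_t out_len, size_t win) {
  for (size_t p = 0; p < out_len; ++p) { /* pass 1: argmax */
    size_t start = p * win, max_i = start;
    for (size_t i = start + 1; i < start + win; ++i) {
      if (in[i] > in[max_i]) max_i = i;
    }
    idx[p] = (int)max_i;
  }
  for (size_t p = 0; p < out_len; ++p) { /* pass 2: gather */
    out[p] = grad[idx[p]];
  }
}

/* New style: one pass, no index workspace; read grad at the running max. */
static void fused(const float *in, const float *grad, float *out,
                  size_t out_len, size_t win) {
  for (size_t p = 0; p < out_len; ++p) {
    size_t start = p * win, max_i = start;
    for (size_t i = start + 1; i < start + win; ++i) {
      if (in[i] > in[max_i]) max_i = i;
    }
    out[p] = grad[max_i]; /* grad at the argmax position */
  }
}

int main(void) {
  const float in[6] = {1, 3, 2, 5, 4, 0};
  const float grad[6] = {10, 20, 30, 40, 50, 60};
  float a[3], b[3];
  int idx[3];
  two_pass(in, grad, a, idx, 3, 2);
  fused(in, grad, b, 3, 2);
  for (int i = 0; i < 3; ++i) assert(a[i] == b[i]); /* both give {20, 40, 50} */
  return 0;
}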
@@ -21,15 +21,13 @@
 #include "utils/ms_utils.h"
 #include "utils/profile.h"
 #include "mindspore/ccsrc/kernel/common_utils.h"
-#include "nnacl/fp32/maxpool_with_argmax.h"
-#include "nnacl/base/gather_base.h"
+#include "nnacl/fp32_grad/maxpool_grad_grad.h"

 namespace mindspore {
 namespace kernel {
 namespace {
 constexpr size_t kMaxPoolGradGradInputsNum = 3;
 constexpr size_t kMaxPoolGradGradOutputsNum = 1;
-constexpr size_t kMaxPoolGradGradWorkSpaceNum = 2;
 constexpr size_t kGradIndex = 2;
 constexpr size_t kPadHalf = 2;
@@ -125,12 +123,6 @@ void MaxPoolGradGradCpuKernelMod::CalPad() {
   }
 }

-void MaxPoolGradGradCpuKernelMod::InitWorkspace() {
-  workspace_size_list_.push_back(input_size_list_[1]);
-  output_elements_ = std::accumulate(out_shapes_.begin(), out_shapes_.end(), 1, std::multiplies<size_t>());
-  workspace_size_list_.push_back(sizeof(int32_t) * output_elements_);
-}
-
 int MaxPoolGradGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                         const std::vector<KernelTensorPtr> &inputs,
                                         const std::vector<KernelTensorPtr> &outputs,
@@ -152,6 +144,7 @@ int MaxPoolGradGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
   param_->output_channel_ = LongToInt(out_shapes_[kDim1]);
   param_->output_h_ = LongToInt(out_shapes_[height_index_]);
   param_->output_w_ = LongToInt(out_shapes_[width_index_]);
+  output_elements_ = std::accumulate(out_shapes_.begin(), out_shapes_.end(), 1, std::multiplies<size_t>());

   if (dim_ == kMaxPool3DGradGradDim) {
     reinterpret_cast<Pooling3DParameter *>(param_)->input_d_ = LongToInt(in_shapes_[depth_index_]);
@@ -162,7 +155,6 @@ int MaxPoolGradGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,

   CheckInputVaild();
   CalPad();
-  InitWorkspace();
   return KRET_OK;
 }

@@ -171,18 +163,17 @@ bool MaxPoolGradGradCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
                                          const std::vector<AddressPtr> &outputs) {
   CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMaxPoolGradGradInputsNum, kernel_name_);
   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMaxPoolGradGradOutputsNum, kernel_name_);
-  CHECK_KERNEL_WORKSPACE_SIZE(workspace.size(), kMaxPoolGradGradWorkSpaceNum, kernel_name_);
   auto *input_addr = reinterpret_cast<float *>(inputs[0]->addr);
-  auto *output_addr = reinterpret_cast<float *>(workspace[0]->addr);
-  auto *index_addr = reinterpret_cast<int32_t *>(workspace[1]->addr);
+  auto *grad_addr = reinterpret_cast<float *>(inputs[kGradIndex]->addr);
+  auto *dx_addr = reinterpret_cast<float *>(outputs[0]->addr);

-  auto task = [input_addr, output_addr, index_addr, this](size_t start, size_t end) {
+  auto task = [input_addr, grad_addr, dx_addr, this](size_t start, size_t end) {
     auto ret = static_cast<int>(NNACL_OK);
     if (dim_ == kMaxPool2DGradGradDim) {
-      ret = MaxPoolWithArgmax(input_addr, output_addr, index_addr, start, end, param_);
+      ret = MaxPoolGradGrad(input_addr, grad_addr, dx_addr, start, end, param_);
     } else if (dim_ == kMaxPool3DGradGradDim) {
-      ret = MaxPool3DWithArgmax(input_addr, output_addr, index_addr, start, end,
-                                reinterpret_cast<Pooling3DParameter *>(param_));
+      ret =
+        MaxPool3DGradGrad(input_addr, grad_addr, dx_addr, start, end, reinterpret_cast<Pooling3DParameter *>(param_));
     }
     if (ret != static_cast<int>(NNACL_OK)) {
       MS_LOG(ERROR) << "For '" << kernel_name_
@@ -192,23 +183,6 @@ bool MaxPoolGradGradCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
     return true;
   };
   ParallelLaunchAutoSearch(task, output_elements_, this, &parallel_search_info_, pool_);
-
-  int64_t outer_size = 1;
-  int64_t inner_size = 1;
-  int64_t indices_element_size = SizeToLong(output_batch_stride_);
-  int64_t limit = SizeToLong(input_batch_stride_);
-  size_t byte_inner_size = inner_size * sizeof(float);
-  size_t byte_out_stride = indices_element_size * byte_inner_size;
-
-  for (int b = 0; b < param_->input_batch_; b++) {
-    auto *index_t = index_addr + b * output_batch_stride_;
-    auto *grad_t = reinterpret_cast<float *>(inputs[kGradIndex]->addr) + b * input_batch_stride_;
-    auto *dx_t = reinterpret_cast<float *>(outputs[0]->addr) + b * output_batch_stride_;
-    int ret = Gather(grad_t, outer_size, byte_inner_size, limit, index_t, indices_element_size, dx_t, byte_out_stride);
-    if (ret != 0) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', error_code[" << ret << "]";
-    }
-  }
   return true;
 }
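The effect on the CPU Launch path is visible above: the old code ran the argmax task under ParallelLaunchAutoSearch and then performed a serial, per-batch Gather over the recorded indices; the fused task now writes dx directly inside the single parallel pass, so the index and output workspaces and the entire Gather loop are gone.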
@@ -60,7 +60,6 @@ class MaxPoolGradGradCpuKernelMod : public NativeCpuKernelMod {
  private:
   void CheckInputVaild();
   void CalPad();
-  void InitWorkspace();

   std::vector<int64_t> kernels_;
   std::vector<int64_t> strides_;
@@ -9,14 +9,14 @@
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#include "nnacl/fp32/maxpool_with_argmax.h"
+#include "nnacl/fp32_grad/maxpool_grad_grad.h"

-int MaxPoolWithArgmax(const float *input, float *output, int *index, size_t start, size_t end,
-                      PoolingParameter *param) {
+int MaxPoolGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end,
+                    PoolingParameter *param) {
   const int channel = param->input_channel_;
   const int input_height = param->input_h_;
   const int input_width = param->input_w_;
@@ -53,13 +53,13 @@ int MaxPoolWithArgmax(const float *input, float *output, int *index, size_t star
     h_start = MSMAX(h_start, 0);
     w_start = MSMAX(w_start, 0);

-    int input_start = pos_n * channel * input_height * input_width;
-    int max_idx = pos_c * input_height * input_width + h_start * input_width + w_start;
+    int input_start = pos_n * channel * input_height * input_width + pos_c * input_height * input_width;
+    int max_idx = h_start * input_width + w_start;
     float max_data = input[input_start + max_idx];

     for (int h_cur = h_start; h_cur < h_end; ++h_cur) {
       for (int w_cur = w_start; w_cur < w_end; ++w_cur) {
-        int input_idx = pos_c * input_height * input_width + h_cur * input_width + w_cur;
+        int input_idx = h_cur * input_width + w_cur;
         float input_data = input[input_start + input_idx];
         if (input_data > max_data) {
           max_idx = input_idx;
@@ -67,14 +67,13 @@ int MaxPoolWithArgmax(const float *input, float *output, int *index, size_t star
         }
       }
     }
-    output[pos] = max_data;
-    index[pos] = max_idx;
+    output[pos] = grad[input_start + max_idx];
   }
   return NNACL_OK;
 }

-int MaxPool3DWithArgmax(const float *input, float *output, int *index, size_t start, size_t end,
-                        Pooling3DParameter *param) {
+int MaxPool3DGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end,
+                      Pooling3DParameter *param) {
   PoolingParameter *param_2d = (PoolingParameter *)(param);
   const int channel = param_2d->input_channel_;
   const int input_depth = param->input_d_;
@@ -124,16 +123,15 @@ int MaxPool3DWithArgmax(const float *input, float *output, int *index, size_t st
     h_start = MSMAX(h_start, 0);
     w_start = MSMAX(w_start, 0);

-    int input_start = pos_n * channel * input_depth * input_height * input_width;
-    int max_idx = pos_c * input_depth * input_height * input_width + d_start * input_height * input_width +
-                  h_start * input_width + w_start;
+    int input_start =
+      pos_n * channel * input_depth * input_height * input_width + pos_c * input_depth * input_height * input_width;
+    int max_idx = d_start * input_height * input_width + h_start * input_width + w_start;
     float max_data = input[input_start + max_idx];

     for (int d_cur = d_start; d_cur < d_end; ++d_cur) {
       for (int h_cur = h_start; h_cur < h_end; ++h_cur) {
         for (int w_cur = w_start; w_cur < w_end; ++w_cur) {
-          int input_idx = pos_c * input_depth * input_height * input_width + d_cur * input_height * input_width +
-                          h_cur * input_width + w_cur;
+          int input_idx = d_cur * input_height * input_width + h_cur * input_width + w_cur;
           float input_data = input[input_start + input_idx];
           if (input_data > max_data) {
             max_idx = input_idx;
@@ -142,8 +140,7 @@ int MaxPool3DWithArgmax(const float *input, float *output, int *index, size_t st
         }
       }
     }
-    output[pos] = max_data;
-    index[pos] = max_idx;
+    output[pos] = grad[input_start + max_idx];
   }
   return NNACL_OK;
 }
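Note the index rebasing in both nnacl kernels: input_start now folds the channel offset into the base (pos_n * channel * plane + pos_c * plane), so max_idx and input_idx are offsets within a single channel plane, and the final write output[pos] = grad[input_start + max_idx] reads the first-order gradient at the argmax position instead of storing max_data and max_idx for a later gather.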
@@ -14,8 +14,8 @@
 * limitations under the License.
 */

-#ifndef MINDSPORE_NNACL_FP32_MAXPOOL_WITH_ARGMAX_H_
-#define MINDSPORE_NNACL_FP32_MAXPOOL_WITH_ARGMAX_H_
+#ifndef MINDSPORE_NNACL_FP32_GRAD_MAXPOOL_GRAD_GARD_H_
+#define MINDSPORE_NNACL_FP32_GRAD_MAXPOOL_GRAD_GARD_H_

 #include "nnacl/op_base.h"
 #include "nnacl/pooling_parameter.h"
@@ -24,12 +24,13 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-int MaxPoolWithArgmax(const float *input, float *output, int *index, size_t start, size_t end, PoolingParameter *param);
+int MaxPoolGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end,
+                    PoolingParameter *param);

-int MaxPool3DWithArgmax(const float *input, float *output, int *index, size_t start, size_t end,
-                        Pooling3DParameter *param);
+int MaxPool3DGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end,
+                      Pooling3DParameter *param);
 #ifdef __cplusplus
 }
 #endif

-#endif  // MINDSPORE_NNACL_FP32_MAXPOOL_WITH_ARGMAX_H_
+#endif  // MINDSPORE_NNACL_FP32_GRAD_MAXPOOL_GRAD_GARD_H_
@@ -0,0 +1,152 @@
+/**
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxpool_grad_grad_impl.cuh"
+#include "include/cuda_fp16.h"
+
+template <typename T>
+__global__ void MaxPoolGradGrad(const T *input, const T *grad, const int n, const int c, const int h, const int w,
+                                const int windowHeight, const int windowWidth, const int strideHeight,
+                                const int strideWidth, const int padTop, const int padLeft, const int outputHeight,
+                                const int outputWidth, const int outputNCHW, const int outputCHW, const int outputHW,
+                                T *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outputNCHW); pos += blockDim.x * gridDim.x) {
+    const int posn = pos / outputCHW;
+    const int posc = pos / outputHW % c;
+    const int posh = pos / outputWidth % outputHeight;
+    const int posw = pos % outputWidth;
+    int hstart = posh * strideHeight - padTop;
+    int wstart = posw * strideWidth - padLeft;
+    const int hend = min(hstart + windowHeight, h);
+    const int wend = min(wstart + windowWidth, w);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+
+    int inputStart = posn * c * h * w + posc * h * w;
+    int maxIdx = hstart * w + wstart;
+    T maxData = input[inputStart + maxIdx];
+    for (int hcur = hstart; hcur < hend; ++hcur) {
+      for (int wcur = wstart; wcur < wend; ++wcur) {
+        int inputIdx = hcur * w + wcur;
+        T inputData = input[inputStart + inputIdx];
+        if (inputData > maxData) {
+          maxIdx = inputIdx;
+          maxData = inputData;
+        }
+      }
+    }
+    output[pos] = grad[inputStart + maxIdx];
+  }
+}
+
+template <typename T>
+void CalMaxPoolGradGrad(const T *input, const T *grad, const int n, const int c, const int h, const int w,
+                        const int windowHeight, const int windowWidth, const int strideHeight, const int strideWidth,
+                        const int padTop, const int padLeft, const int outputHeight, const int outputWidth, T *output,
+                        const uint32_t &device_id, cudaStream_t cuda_stream) {
+  const int outputNCHW = n * c * outputHeight * outputWidth;
+  const int outputCHW = c * outputHeight * outputWidth;
+  const int outputHW = outputHeight * outputWidth;
+  MaxPoolGradGrad<<<CUDA_BLOCKS(device_id, n * c * outputHeight * outputWidth), CUDA_THREADS(device_id), 0,
+                    cuda_stream>>>(input, grad, n, c, h, w, windowHeight, windowWidth, strideHeight, strideWidth,
+                                   padTop, padLeft, outputHeight, outputWidth, outputNCHW, outputCHW, outputHW, output);
+}
+
+template CUDA_LIB_EXPORT void CalMaxPoolGradGrad<float>(const float *input, const float *grad, const int n, const int c,
+                                                        const int h, const int w, const int windowHeight,
+                                                        const int windowWidth, const int strideHeight,
+                                                        const int strideWidth, const int padTop, const int padLeft,
+                                                        const int outputHeight, const int outputWidth, float *output,
+                                                        const uint32_t &device_id, cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void CalMaxPoolGradGrad<half>(const half *input, const half *grad, const int n, const int c,
+                                                       const int h, const int w, const int windowHeight,
+                                                       const int windowWidth, const int strideHeight,
+                                                       const int strideWidth, const int padTop, const int padLeft,
+                                                       const int outputHeight, const int outputWidth, half *output,
+                                                       const uint32_t &device_id, cudaStream_t cuda_stream);
+
+template <typename T>
+__global__ void MaxPool3DGradGrad(const T *input, const T *grad, const int n, const int c, const int d, const int h,
+                                  const int w, const int windowDepth, const int windowHeight, const int windowWidth,
+                                  const int strideDepth, const int strideHeight, const int strideWidth,
+                                  const int padFront, const int padTop, const int padLeft, const int outputDepth,
+                                  const int outputHeight, const int outputWidth, T *output) {
+  const int outputNCDHW = n * c * outputDepth * outputHeight * outputWidth;
+  const int outputCDHW = c * outputDepth * outputHeight * outputWidth;
+  const int outputDHW = outputDepth * outputHeight * outputWidth;
+  const int outputHW = outputHeight * outputWidth;
+
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outputNCDHW); pos += blockDim.x * gridDim.x) {
+    const int posn = pos / outputCDHW;
+    const int posc = pos / outputDHW % c;
+    const int posd = pos / outputHW % outputDepth;
+    const int posh = pos / outputWidth % outputHeight;
+    const int posw = pos % outputWidth;
+
+    int dstart = posd * strideDepth - padFront;
+    int hstart = posh * strideHeight - padTop;
+    int wstart = posw * strideWidth - padLeft;
+    const int dend = min(dstart + windowDepth, d);
+    const int hend = min(hstart + windowHeight, h);
+    const int wend = min(wstart + windowWidth, w);
+    dstart = max(dstart, 0);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+
+    int inputStart = posn * c * d * h * w + posc * d * h * w;
+    int maxIdx = dstart * h * w + hstart * w + wstart;
+    T maxData = input[inputStart + maxIdx];
+    for (int dcur = dstart; dcur < dend; ++dcur) {
+      for (int hcur = hstart; hcur < hend; ++hcur) {
+        for (int wcur = wstart; wcur < wend; ++wcur) {
+          int inputIdx = dcur * h * w + hcur * w + wcur;
+          T inputData = input[inputStart + inputIdx];
+          if (inputData > maxData) {
+            maxIdx = inputIdx;
+            maxData = inputData;
+          }
+        }
+      }
+    }
+    output[pos] = grad[inputStart + maxIdx];
+  }
+}
+
+template <typename T>
+void CalMaxPool3DGradGrad(const T *input, const T *grad, const int n, const int c, const int d, const int h,
+                          const int w, const int windowDepth, const int windowHeight, const int windowWidth,
+                          const int strideDepth, const int strideHeight, const int strideWidth, const int padFront,
+                          const int padTop, const int padLeft, const int outputDepth, const int outputHeight,
+                          const int outputWidth, T *output, const uint32_t &device_id, cudaStream_t cuda_stream) {
+  MaxPool3DGradGrad<<<CUDA_BLOCKS(device_id, n * c * outputDepth * outputHeight * outputWidth), CUDA_THREADS(device_id),
+                      0, cuda_stream>>>(input, grad, n, c, d, h, w, windowDepth, windowHeight, windowWidth, strideDepth,
+                                        strideHeight, strideWidth, padFront, padTop, padLeft, outputDepth, outputHeight,
+                                        outputWidth, output);
+}
+
+template CUDA_LIB_EXPORT void CalMaxPool3DGradGrad<float>(
+  const float *input, const float *grad, const int n, const int c, const int d, const int h, const int w,
+  const int windowDepth, const int windowHeight, const int windowWidth, const int strideDepth, const int strideHeight,
+  const int strideWidth, const int padFront, const int padTop, const int padLeft, const int outputDepth,
+  const int outputHeight, const int outputWidth, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void CalMaxPool3DGradGrad<half>(
+  const half *input, const half *grad, const int n, const int c, const int d, const int h, const int w,
+  const int windowDepth, const int windowHeight, const int windowWidth, const int strideDepth, const int strideHeight,
+  const int strideWidth, const int padFront, const int padTop, const int padLeft, const int outputDepth,
+  const int outputHeight, const int outputWidth, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
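Each CUDA thread recovers its (n, c, h, w) coordinates from the flat output offset pos by integer division and modulo against the precomputed outputCHW / outputHW / outputWidth extents. A standalone check of that decomposition, with made-up extents (not values from the commit):

#include <assert.h>

int main(void) {
  /* Example extents: c = 2 channels, 3x4 output per channel. */
  const int c = 2, outputHeight = 3, outputWidth = 4;
  const int outputHW = outputHeight * outputWidth; /* 12 */
  const int outputCHW = c * outputHW;              /* 24 */

  const int pos = 30;                                /* flat NCHW offset */
  const int posn = pos / outputCHW;                  /* 1 */
  const int posc = pos / outputHW % c;               /* 0 */
  const int posh = pos / outputWidth % outputHeight; /* 1 */
  const int posw = pos % outputWidth;                /* 2 */

  /* Recomposing the coordinates must give back the flat offset. */
  assert(((posn * c + posc) * outputHeight + posh) * outputWidth + posw == pos);
  return 0;
}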
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_GRAD_GRAD_IMPL_CUH_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_GRAD_GRAD_IMPL_CUH_
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
+template <typename T>
+CUDA_LIB_EXPORT void CalMaxPoolGradGrad(const T *input, const T *grad, const int n, const int c, const int h,
+                                        const int w, const int windowHeight, const int windowWidth,
+                                        const int strideHeight, const int strideWidth, const int padTop,
+                                        const int padLeft, const int outputHeight, const int outputWidth,
+                                        T *output, const uint32_t &device_id, cudaStream_t cuda_stream);
+
+template <typename T>
+CUDA_LIB_EXPORT void CalMaxPool3DGradGrad(const T *input, const T *grad, const int n, const int c, const int d,
+                                          const int h, const int w, const int windowDepth, const int windowHeight,
+                                          const int windowWidth, const int strideDepth, const int strideHeight,
+                                          const int strideWidth, const int padFront, const int padTop,
+                                          const int padLeft, const int outputDepth, const int outputHeight,
+                                          const int outputWidth, T *output, const uint32_t &device_id,
+                                          cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_GRAD_GRAD_IMPL_CUH_
@@ -18,8 +18,8 @@
 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_WITH_ARGMAX_GRAD_IMPL_CUH_
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
 template <typename T, typename S>
-CUDA_LIB_EXPORT void CalMaxPoolWithArgmaxGrad(const T* dy, const S* index, const int n, const int c, const int xHeight,
-                                              const int xWidth, const int dyHeight, const int dyWidth, T* dx,
+CUDA_LIB_EXPORT void CalMaxPoolWithArgmaxGrad(const T *dy, const S *index, const int n, const int c, const int xHeight,
+                                              const int xWidth, const int dyHeight, const int dyWidth, T *dx,
                                               cudaStream_t cuda_stream);

 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_WITH_ARGMAX_GRAD_IMPL_CUH_
@@ -75,75 +75,3 @@ template CUDA_LIB_EXPORT void CalMaxPoolWithArgmax<half, int>(
   const half *input, const int n, const int c, const int h, const int w, const int windowHeight, const int windowWidth,
   const int strideHeight, const int strideWidth, const int padTop, const int padLeft, const int outputHeight,
   const int outputWidth, half *output, int *index, const uint32_t &device_id, cudaStream_t cuda_stream);
-
-template <typename T, typename S>
-__global__ void MaxPool3DWithArgmax(const T *input, const int n, const int c, const int d, const int h, const int w,
-                                    const int windowDepth, const int windowHeight, const int windowWidth,
-                                    const int strideDepth, const int strideHeight, const int strideWidth,
-                                    const int padFront, const int padTop, const int padLeft, const int outputDepth,
-                                    const int outputHeight, const int outputWidth, T *output, S *index) {
-  const int outputNCDHW = n * c * outputDepth * outputHeight * outputWidth;
-  const int outputCDHW = c * outputDepth * outputHeight * outputWidth;
-  const int outputDHW = outputDepth * outputHeight * outputWidth;
-  const int outputHW = outputHeight * outputWidth;
-
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outputNCDHW); pos += blockDim.x * gridDim.x) {
-    const int posn = pos / outputCDHW;
-    const int posc = pos / outputDHW % c;
-    const int posd = pos / outputHW % outputDepth;
-    const int posh = pos / outputWidth % outputHeight;
-    const int posw = pos % outputWidth;
-
-    int dstart = posd * strideDepth - padFront;
-    int hstart = posh * strideHeight - padTop;
-    int wstart = posw * strideWidth - padLeft;
-    const int dend = min(dstart + windowDepth, d);
-    const int hend = min(hstart + windowHeight, h);
-    const int wend = min(wstart + windowWidth, w);
-    dstart = max(dstart, 0);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-
-    S inputStart = posn * c * d * h * w;
-    S maxIdx = posc * d * h * w + dstart * h * w + hstart * w + wstart;
-    T maxData = input[inputStart + maxIdx];
-    for (int dcur = dstart; dcur < dend; ++dcur) {
-      for (int hcur = hstart; hcur < hend; ++hcur) {
-        for (int wcur = wstart; wcur < wend; ++wcur) {
-          S inputIdx = posc * d * h * w + dcur * h * w + hcur * w + wcur;
-          T inputData = input[inputStart + inputIdx];
-          if (inputData > maxData) {
-            maxIdx = inputIdx;
-            maxData = inputData;
-          }
-        }
-      }
-    }
-    output[pos] = maxData;
-    index[pos] = maxIdx;
-  }
-}
-
-template <typename T, typename S>
-void CalMaxPool3DWithArgmax(const T *input, const int n, const int c, const int d, const int h, const int w,
-                            const int windowDepth, const int windowHeight, const int windowWidth, const int strideDepth,
-                            const int strideHeight, const int strideWidth, const int padFront, const int padTop,
-                            const int padLeft, const int outputDepth, const int outputHeight, const int outputWidth,
-                            T *output, S *index, const uint32_t &device_id, cudaStream_t cuda_stream) {
-  MaxPool3DWithArgmax<<<CUDA_BLOCKS(device_id, n * c * d * outputHeight * outputWidth), CUDA_THREADS(device_id), 0,
-                        cuda_stream>>>(input, n, c, d, h, w, windowDepth, windowHeight, windowWidth, strideDepth,
-                                       strideHeight, strideWidth, padFront, padTop, padLeft, outputDepth, outputHeight,
-                                       outputWidth, output, index);
-}
-
-template CUDA_LIB_EXPORT void CalMaxPool3DWithArgmax<float, int>(
-  const float *input, const int n, const int c, const int d, const int h, const int w, const int windowDepth,
-  const int windowHeight, const int windowWidth, const int strideDepth, const int strideHeight, const int strideWidth,
-  const int padFront, const int padTop, const int padLeft, const int outputDepth, const int outputHeight,
-  const int outputWidth, float *output, int *index, const uint32_t &device_id, cudaStream_t cuda_stream);
-
-template CUDA_LIB_EXPORT void CalMaxPool3DWithArgmax<half, int>(
-  const half *input, const int n, const int c, const int d, const int h, const int w, const int windowDepth,
-  const int windowHeight, const int windowWidth, const int strideDepth, const int strideHeight, const int strideWidth,
-  const int padFront, const int padTop, const int padLeft, const int outputDepth, const int outputHeight,
-  const int outputWidth, half *output, int *index, const uint32_t &device_id, cudaStream_t cuda_stream);
@@ -24,13 +24,4 @@ CUDA_LIB_EXPORT void CalMaxPoolWithArgmax(const T *input, const int n, const int
                                           const int outputHeight, const int outputWidth, T *output, S *index,
                                           const uint32_t &device_id, cudaStream_t cuda_stream);

-template <typename T, typename S>
-CUDA_LIB_EXPORT void CalMaxPool3DWithArgmax(const T *input, const int n, const int c, const int d, const int h,
-                                            const int w, const int windowDepth, const int windowHeight,
-                                            const int windowWidth, const int strideDepth, const int strideHeight,
-                                            const int strideWidth, const int padFront, const int padTop,
-                                            const int padLeft, const int outputDepth, const int outputHeight,
-                                            const int outputWidth, T *output, S *index, const uint32_t &device_id,
-                                            cudaStream_t cuda_stream);
-
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_WITH_ARGMAX_IMPL_CUH_
@@ -20,15 +20,13 @@
 #include "mindspore/core/ops/grad/max_pool_grad_grad.h"
 #include "abstract/utils.h"
 #include "plugin/factory/ms_factory.h"
-#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxpool_with_argmax_impl.cuh"
-#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/gather.cuh"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxpool_grad_grad_impl.cuh"

 namespace mindspore {
 namespace kernel {
 namespace {
 constexpr size_t kMaxPoolGradGradInputsNum = 3;
 constexpr size_t kMaxPoolGradGradOutputsNum = 1;
-constexpr size_t kMaxPoolGradGradWorkSpaceNum = 2;
 constexpr size_t kGradIndex = 2;
 constexpr size_t kPadHalf = 2;
 }  // namespace
@@ -131,49 +129,28 @@ int MaxPoolGradGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
   output_batch_stride_ = std::accumulate(out_shapes_.begin() + 1, out_shapes_.end(), 1, std::multiplies<size_t>());

   CalPad();
-
-  workspace_size_list_.clear();
-  workspace_size_list_.push_back(input_size_list_[1]);
-  auto output_elements = std::accumulate(out_shapes_.begin(), out_shapes_.end(), 1, std::multiplies<size_t>());
-  workspace_size_list_.push_back(sizeof(int32_t) * output_elements);
   return KRET_OK;
 }

 template <typename T>
 bool MaxPoolGradGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
-                                               const std::vector<AddressPtr> &workspace,
                                                const std::vector<AddressPtr> &outputs) {
   T *input_addr = GetDeviceAddress<T>(inputs, kIndex0);
-  T *output_addr = GetDeviceAddress<T>(workspace, kIndex0);
-  auto *index_addr = GetDeviceAddress<int32_t>(workspace, kIndex1);
+  T *grad_addr = GetDeviceAddress<T>(inputs, kGradIndex);
+  T *output_addr = GetDeviceAddress<T>(outputs, kIndex0);
   if (dim_ == kMaxPool2DGradGradDim) {
-    CalMaxPoolWithArgmax<T, int32_t>(input_addr, batch_, channel_, input_height_, input_width_, window_height_,
-                                     window_width_, stride_height_, stride_width_, pad_top_, pad_left_, output_height_,
-                                     output_width_, output_addr, index_addr, device_id_,
-                                     reinterpret_cast<cudaStream_t>(cuda_stream_));
+    CalMaxPoolGradGrad<T>(input_addr, grad_addr, batch_, channel_, input_height_, input_width_, window_height_,
+                          window_width_, stride_height_, stride_width_, pad_top_, pad_left_, output_height_,
+                          output_width_, output_addr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
   } else if (dim_ == kMaxPool3DGradGradDim) {
-    CalMaxPool3DWithArgmax<T, int32_t>(
-      input_addr, batch_, channel_, input_depth_, input_height_, input_width_, window_depth_, window_height_,
-      window_width_, stride_depth_, stride_height_, stride_width_, pad_front_, pad_top_, pad_left_, output_depth_,
-      output_height_, output_width_, output_addr, index_addr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
+    CalMaxPool3DGradGrad<T>(input_addr, grad_addr, batch_, channel_, input_depth_, input_height_, input_width_,
+                            window_depth_, window_height_, window_width_, stride_depth_, stride_height_, stride_width_,
+                            pad_front_, pad_top_, pad_left_, output_depth_, output_height_, output_width_, output_addr,
+                            device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
   } else {
     MS_LOG(ERROR) << "For '" << kernel_name_ << "', only supports 2D or 3D max pooling.";
     return false;
   }
-
-  T *grad_addr = GetDeviceAddress<T>(inputs, kIndex2);
-  T *dx = GetDeviceAddress<T>(outputs, kIndex0);
-  size_t dim_before_axis = 1;
-  size_t dim_at_axis_input = input_batch_stride_;
-  size_t dim_at_axis_output = output_batch_stride_;
-  size_t dim_after_axis = 1;
-  for (int b = 0; b < batch_; b++) {
-    int32_t *index_t = index_addr + b * output_batch_stride_;
-    T *grad_t = grad_addr + b * input_batch_stride_;
-    T *dx_t = dx + b * output_batch_stride_;
-    Gather<T, int32_t>(grad_t, index_t, dx_t, dim_before_axis, dim_at_axis_input, dim_at_axis_output, dim_after_axis,
-                       reinterpret_cast<cudaStream_t>(cuda_stream_), device_id_);
-  }
   return true;
 }
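The GPU win is larger still: the old LaunchKernel issued CalMaxPoolWithArgmax and then one Gather kernel launch per batch element inside the b-loop, while the fused CalMaxPoolGradGrad / CalMaxPool3DGradGrad is a single launch covering the whole NCHW/NCDHW output, and the two workspace allocations in Resize disappear along with it.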
@@ -35,7 +35,7 @@ class MaxPoolGradGradGpuKernelMod : public NativeGpuKernelMod {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
     cuda_stream_ = cuda_stream;
-    return kernel_func_(this, inputs, workspace, outputs);
+    return kernel_func_(this, inputs, outputs);
   }

   bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@@ -48,11 +48,9 @@ class MaxPoolGradGradGpuKernelMod : public NativeGpuKernelMod {

  private:
   template <typename T>
-  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-                    const std::vector<AddressPtr> &outputs);
-  using MaxPoolGradGradFunc =
-    std::function<bool(MaxPoolGradGradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
-                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  using MaxPoolGradGradFunc = std::function<bool(MaxPoolGradGradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
+                                                 const std::vector<kernel::AddressPtr> &)>;

   void *cuda_stream_{nullptr};
   MaxPoolGradGradFunc kernel_func_{};
@@ -21,7 +21,7 @@ from mindspore import Tensor
 from mindspore.ops.operations import _grad_ops as G


-context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


 class NetPoolGradGrad(nn.Cell):
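The only test change is the device target: the existing NetPoolGradGrad graph test now runs in GRAPH_MODE on GPU, exercising the new fused CUDA kernels.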