add maxpool_with_argmax/grad cuda kernel

This commit is contained in:
panfengfeng 2020-07-21 06:21:07 -08:00
parent 9dc23eeb98
commit ca881ec03e
15 changed files with 1128 additions and 71 deletions

View File

@ -0,0 +1,226 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "maxpool_with_argmax_grad_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
template <typename T, typename S>
__global__ void MaxPoolWithArgmaxGrad(const T* x,
const T* dy,
const S* index,
const int n,
const int c,
const int xHeight,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
const int xNCHW,
const int xCHW,
const int xHW,
const int dyCHW,
const int dyHW,
T* dx) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
pos < (xNCHW);
pos += blockDim.x * gridDim.x) {
const int posn = pos / xCHW;
const int posc = pos / xHW % c;
const int posh = pos / xHeight % xHeight;
const int posw = pos % xWidth;
const S posIdx = posh*xWidth + posw;
int hstart = posh+padTop;
if (hstart < windowHeight) {
hstart = 0;
} else {
hstart = (hstart-windowHeight)/strideHeight + 1;
}
int wstart = posw+padLeft;
if (wstart < windowWidth) {
wstart = 0;
} else {
wstart = (wstart-windowWidth)/strideWidth + 1;
}
const int hend = min((posh+padTop)/strideHeight +1, dyHeight);
const int wend = min((posw+padLeft)/strideWidth +1, dyWidth);
const int channelStart = posn*dyCHW + posc*dyHW;
T dySum = static_cast<T>(0.0);
for (int hcur = hstart; hcur < hend; ++hcur) {
for (int wcur = wstart; wcur < wend; ++wcur) {
const int curIdx = hcur*dyWidth + wcur;
S maxIdx = index[channelStart+curIdx];
if (maxIdx == posIdx) {
dySum += dy[channelStart+curIdx];
}
}
}
dx[pos] = dySum;
}
return;
}
template <>
__global__ void MaxPoolWithArgmaxGrad(const half* x,
const half* dy,
const int* index,
const int n,
const int c,
const int xHeight,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
const int xNCHW,
const int xCHW,
const int xHW,
const int dyCHW,
const int dyHW,
half* dx) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
pos < (xNCHW);
pos += blockDim.x * gridDim.x) {
const int posn = pos / xCHW;
const int posc = pos / xHW % c;
const int posh = pos / xHeight % xHeight;
const int posw = pos % xWidth;
const int posIdx = posh*xWidth + posw;
int hstart = posh+padTop;
if (hstart < windowHeight) {
hstart = 0;
} else {
hstart = (hstart-windowHeight)/strideHeight + 1;
}
int wstart = posw+padLeft;
if (wstart < windowWidth) {
wstart = 0;
} else {
wstart = (wstart-windowWidth)/strideWidth + 1;
}
const int hend = min((posh+padTop)/strideHeight +1, dyHeight);
const int wend = min((posw+padLeft)/strideWidth +1, dyWidth);
const int channelStart = posn*dyCHW + posc*dyHW;
float dySum = 0.0f;
for (int hcur = hstart; hcur < hend; ++hcur) {
for (int wcur = wstart; wcur < wend; ++wcur) {
const int curIdx = hcur*dyWidth + wcur;
int maxIdx = index[channelStart+curIdx];
if (maxIdx == posIdx) {
dySum += __half2float(dy[channelStart+curIdx]);
}
}
}
dx[pos] = __float2half(dySum);
}
return;
}
template <typename T, typename S>
void CalMaxPoolWithArgmaxGrad(const T* x,
const T* dy,
const S* index,
const int n,
const int c,
const int xHeight,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
T* dx,
cudaStream_t cuda_stream) {
const int xHW = xHeight*xWidth;
const int xCHW = c*xHW;
const int xNCHW = n*xCHW;
const int dyHW = dyHeight*dyWidth;
const int dyCHW = c*dyHW;
MaxPoolWithArgmaxGrad<<<GET_BLOCKS(xNCHW),
GET_THREADS,
0,
cuda_stream>>>(
x,
dy,
index,
n,
c,
xHeight,
xWidth,
dyHeight,
dyWidth,
windowHeight,
windowWidth,
strideHeight,
strideWidth,
padTop,
padLeft,
xNCHW,
xCHW,
xHW,
dyCHW,
dyHW,
dx);
return;
}
template void CalMaxPoolWithArgmaxGrad<float, int>(const float* x,
const float* dy,
const int* index,
const int n,
const int c,
const int xHeight,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
float* dx,
cudaStream_t cuda_stream);
template void CalMaxPoolWithArgmaxGrad<half, int>(const half* x,
const half* dy,
const int* index,
const int n,
const int c,
const int xHeight,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
half* dx,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,25 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_
template <typename T, typename S>
void CalMaxPoolWithArgmaxGrad(const T* x, const T* dy, const S* index, const int n, const int c, const int xHeight,
const int xWidth, const int dyHeight, const int dyWidth, const int windowHeight,
const int windowWidth, const int strideHeight, const int strideWidth, const int padTop,
const int padLeft, T* dx, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_

View File

@ -0,0 +1,149 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "maxpool_with_argmax_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
template <typename T, typename S>
__global__ void MaxPoolWithArgmax(const T* input,
const int n,
const int c,
const int h,
const int w,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
const int outputHeight,
const int outputWidth,
const int outputNCHW,
const int outputCHW,
const int outputHW,
T* output,
S *index) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
pos < (outputNCHW);
pos += blockDim.x * gridDim.x) {
const int posn = pos / outputCHW;
const int posc = pos / outputHW % c;
const int posh = pos / outputHeight % outputHeight;
const int posw = pos % outputWidth;
int hstart = posh * strideHeight - padTop;
int wstart = posw * strideWidth - padLeft;
const int hend = min(hstart + windowHeight, h);
const int wend = min(wstart + windowWidth, w);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
S inputStart = posn*c*h*w + posc*h*w;
S maxIdx = hstart*w + wstart;
T maxData = input[inputStart+maxIdx];
for (int hcur = hstart; hcur < hend; ++hcur) {
for (int wcur = wstart; wcur < wend; ++wcur) {
S inputIdx = hcur*w + wcur;
T inputData = input[inputStart+inputIdx];
if (inputData > maxData) {
maxIdx = inputIdx;
maxData = inputData;
}
}
}
output[pos] = maxData;
index[pos] = maxIdx;
}
return;
}
template <typename T, typename S>
void CalMaxPoolWithArgmax(const T* input,
const int n,
const int c,
const int h,
const int w,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
const int outputHeight,
const int outputWidth,
T* output,
S *index,
cudaStream_t cuda_stream) {
const int outputNCHW = n*c*outputHeight*outputWidth;
const int outputCHW = c*outputHeight*outputWidth;
const int outputHW = outputHeight*outputWidth;
MaxPoolWithArgmax<<<GET_BLOCKS(n*c*outputHeight*outputWidth),
GET_THREADS,
0,
cuda_stream>>>(
input,
n,
c,
h,
w,
windowHeight,
windowWidth,
strideHeight,
strideWidth,
padTop,
padLeft,
outputHeight,
outputWidth,
outputNCHW,
outputCHW,
outputHW,
output,
index);
return;
}
template void CalMaxPoolWithArgmax<float, int>(const float* input,
const int n,
const int c,
const int h,
const int w,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
const int outputHeight,
const int outputWidth,
float* output,
int* index,
cudaStream_t cuda_stream);
template void CalMaxPoolWithArgmax<half, int>(const half* input,
const int n,
const int c,
const int h,
const int w,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
const int outputHeight,
const int outputWidth,
half* output,
int* index,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,25 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_H_
template <typename T, typename S>
void CalMaxPoolWithArgmax(const T* input, const int n, const int c, const int h, const int w, const int windowHeight,
const int windowWidth, const int strideHeight, const int strideWidth, const int padTop,
const int padLeft, const int outputHeight, const int outputWidth, T* output, S *index,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_H_

View File

@ -0,0 +1,30 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
MaxPoolWithArgmax,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
MaxPoolWithArgmaxGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
MaxPoolWithArgmax,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
MaxPoolWithArgmaxGpuFwdKernel, half, int)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,160 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GPU_KERNEL_H_
#include <algorithm>
#include <vector>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class MaxPoolWithArgmaxGpuFwdKernel : public GpuKernel {
public:
MaxPoolWithArgmaxGpuFwdKernel()
: n_(0),
c_(0),
input_height_(0),
input_width_(0),
window_height_(0),
window_width_(0),
pad_height_(0),
pad_width_(0),
pad_top_(0),
pad_left_(0),
stride_height_(0),
stride_width_(0),
output_height_(0),
output_width_(0),
input_size_(0),
output_size_(0) {}
~MaxPoolWithArgmaxGpuFwdKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
S *index_addr = GetDeviceAddress<S>(outputs, 1);
CalMaxPoolWithArgmax(input_addr, n_, c_, input_height_, input_width_, window_height_, window_width_, stride_height_,
stride_width_, pad_top_, pad_left_, output_height_, output_width_, output_addr, index_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but MaxPoolWithArgmax needs 1 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but MaxPoolWithArgmax needs 2 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
input_size_ = sizeof(T);
for (auto x : input_shape) {
input_size_ *= x;
}
output_size_ = sizeof(T);
for (auto x : output_shape) {
output_size_ *= x;
}
n_ = SizeToInt(input_shape[0]);
c_ = SizeToInt(input_shape[1]);
input_height_ = SizeToInt(input_shape[2]);
input_width_ = SizeToInt(input_shape[3]);
output_height_ = SizeToInt(output_shape[2]);
output_width_ = SizeToInt(output_shape[3]);
auto window = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
window_height_ = window[1];
window_width_ = window[2];
auto stride = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
stride_height_ = stride[1];
stride_width_ = stride[2];
pad_mode_ = GetValue<std::string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding"));
pad_top_ = 0;
pad_left_ = 0;
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
SetPad();
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
output_size_list_.push_back(output_size_ / sizeof(T) * sizeof(S));
}
private:
void SetPad() {
pad_height_ = std::max<int>(
0, (((input_height_ / stride_height_) * stride_height_ == input_height_ ? (input_height_ / stride_height_)
: (input_height_ / stride_height_) + 1) -
1) *
stride_height_ +
window_height_ - input_height_);
pad_width_ = std::max<int>(
0, (((input_width_ / stride_width_) * stride_width_ == input_width_ ? (input_width_ / stride_width_)
: (input_width_ / stride_width_) + 1) -
1) *
stride_width_ +
window_width_ - input_width_);
pad_top_ = pad_height_ / 2;
pad_left_ = pad_width_ / 2;
}
std::string pad_mode_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int n_;
int c_;
int input_height_;
int input_width_;
int window_height_;
int window_width_;
int pad_height_;
int pad_width_;
int pad_top_;
int pad_left_;
int stride_height_;
int stride_width_;
int output_height_;
int output_width_;
size_t input_size_;
size_t output_size_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GPU_KERNEL_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(MaxPoolGradWithArgmax,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
MaxPoolWithArgmaxGradGpuKernel, float, int)
MS_REG_GPU_KERNEL_TWO(MaxPoolGradWithArgmax,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
MaxPoolWithArgmaxGradGpuKernel, half, int)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,168 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GRAD_GPU_KERNEL_H_
#include <algorithm>
#include <vector>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class MaxPoolWithArgmaxGradGpuKernel : public GpuKernel {
public:
MaxPoolWithArgmaxGradGpuKernel()
: n_(0),
c_(0),
x_height_(0),
x_width_(0),
dy_height_(0),
dy_width_(0),
x_size_(0),
dy_size_(0),
index_size_(0),
dx_size_(0) {}
~MaxPoolWithArgmaxGradGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
T *x_addr = GetDeviceAddress<T>(inputs, 0);
T *dy_addr = GetDeviceAddress<T>(inputs, 1);
S *index_addr = GetDeviceAddress<S>(inputs, 2);
T *dx_addr = GetDeviceAddress<T>(outputs, 0);
CalMaxPoolWithArgmaxGrad(x_addr, dy_addr, index_addr, n_, c_, x_height_, x_width_, dy_height_, dy_width_,
window_height_, window_width_, stride_height_, stride_width_, pad_top_, pad_left_, dx_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but MaxPoolGradWithArgmax needs 3 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but MaxPoolGradWithArgmax needs 1 output.";
return false;
}
auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto index_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
auto dx_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
x_size_ = sizeof(T);
for (auto x : x_shape) {
x_size_ *= x;
}
dy_size_ = sizeof(T);
for (auto x : dy_shape) {
dy_size_ *= x;
}
index_size_ = sizeof(S);
for (auto x : index_shape) {
index_size_ *= x;
}
dx_size_ = sizeof(T);
for (auto x : dx_shape) {
dx_size_ *= x;
}
n_ = SizeToInt(x_shape[0]);
c_ = SizeToInt(x_shape[1]);
x_height_ = SizeToInt(x_shape[2]);
x_width_ = SizeToInt(x_shape[3]);
dy_height_ = SizeToInt(dy_shape[2]);
dy_width_ = SizeToInt(dy_shape[3]);
auto window = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
window_height_ = window[1];
window_width_ = window[2];
auto stride = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
stride_height_ = stride[1];
stride_width_ = stride[2];
pad_mode_ = GetValue<std::string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding"));
pad_top_ = 0;
pad_left_ = 0;
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
SetPad();
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(x_size_);
input_size_list_.push_back(dy_size_);
input_size_list_.push_back(index_size_);
output_size_list_.push_back(dx_size_);
}
private:
void SetPad() {
pad_height_ = std::max<int>(
0, (((x_height_ / stride_height_) * stride_height_ == x_height_ ? (x_height_ / stride_height_)
: (x_height_ / stride_height_) + 1) -
1) *
stride_height_ +
window_height_ - x_height_);
pad_width_ =
std::max<int>(0, (((x_width_ / stride_width_) * stride_width_ == x_width_ ? (x_width_ / stride_width_)
: (x_width_ / stride_width_) + 1) -
1) *
stride_width_ +
window_width_ - x_width_);
pad_top_ = pad_height_ / 2;
pad_left_ = pad_width_ / 2;
}
std::string pad_mode_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int n_;
int c_;
int x_height_;
int x_width_;
int dy_height_;
int dy_width_;
int window_height_;
int window_width_;
int pad_height_;
int pad_width_;
int pad_top_;
int pad_left_;
int stride_height_;
int stride_width_;
size_t x_size_;
size_t dy_size_;
size_t index_size_;
size_t dx_size_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GRAD_GPU_KERNEL_H_

View File

@ -1181,6 +1181,7 @@ class MaxPoolWithArgmax(_Pool):
def __init__(self, ksize=1, strides=1, padding="valid"):
super(MaxPoolWithArgmax, self).__init__(ksize, strides, padding)
self.is_tbe = context.get_context("device_target") == "Ascend"
self.is_gpu = context.get_context("device_target") == "GPU"
def infer_shape(self, x_shape):
out_shape = _Pool.infer_shape(self, x_shape)
@ -1207,6 +1208,8 @@ class MaxPoolWithArgmax(_Pool):
out_dtype = x_dtype
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
argmax_dtype = mstype.uint16
if self.is_gpu:
argmax_dtype = mstype.int32
return out_dtype, argmax_dtype

View File

@ -38,7 +38,7 @@ if __name__ == '__main__':
if device_target == "Ascend":
context.set_context(device_id=cfg.device_id)
net = GoogleNet(num_classes=cfg.num_classes, platform=device_target)
net = GoogleNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)

View File

@ -41,5 +41,10 @@ mkdir ../train
cd ../train || exit
export CUDA_VISIBLE_DEVICES="$2"
if [ $1 -gt 1 ]
then
mpirun -n $1 --allow-run-as-root \
python3 ${BASEPATH}/../train.py > train.log 2>&1 &
else
python3 ${BASEPATH}/../train.py > train.log 2>&1 &
fi

View File

@ -56,35 +56,24 @@ class Inception(nn.Cell):
Inception Block
"""
def __init__(self, in_channels, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes, platform="Ascend"):
def __init__(self, in_channels, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
super(Inception, self).__init__()
self.platform = platform
self.b1 = Conv2dBlock(in_channels, n1x1, kernel_size=1)
self.b2 = nn.SequentialCell([Conv2dBlock(in_channels, n3x3red, kernel_size=1),
Conv2dBlock(n3x3red, n3x3, kernel_size=3, padding=0)])
self.b3 = nn.SequentialCell([Conv2dBlock(in_channels, n5x5red, kernel_size=1),
Conv2dBlock(n5x5red, n5x5, kernel_size=3, padding=0)])
if self.platform == "Ascend":
self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=1, padding="same")
else: # GPU
self.maxpool = P.MaxPool(ksize=3, strides=1, padding="same")
self.b4 = Conv2dBlock(in_channels, pool_planes, kernel_size=1)
self.concat = P.Concat(axis=1)
def construct(self, x):
'''
construct inception model
'''
branch1 = self.b1(x)
branch2 = self.b2(x)
branch3 = self.b3(x)
if self.platform == "Ascend":
cell, argmax = self.maxpool(x)
branch4 = self.b4(cell)
_ = argmax
else: # GPU
cell = self.maxpool(x)
branch4 = self.b4(cell)
return self.concat((branch1, branch2, branch3, branch4))
@ -93,82 +82,61 @@ class GoogleNet(nn.Cell):
Googlenet architecture
"""
def __init__(self, num_classes, platform="Ascend"):
def __init__(self, num_classes):
super(GoogleNet, self).__init__()
self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0)
self.platform = platform
self.maxpool1 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")
self.conv2 = Conv2dBlock(64, 64, kernel_size=1)
self.conv3 = Conv2dBlock(64, 192, kernel_size=3, padding=0)
self.block3a = Inception(192, 64, 96, 128, 16, 32, 32, platform=self.platform)
self.block3b = Inception(256, 128, 128, 192, 32, 96, 64, platform=self.platform)
self.block4a = Inception(480, 192, 96, 208, 16, 48, 64, platform=self.platform)
self.block4b = Inception(512, 160, 112, 224, 24, 64, 64, platform=self.platform)
self.block4c = Inception(512, 128, 128, 256, 24, 64, 64, platform=self.platform)
self.block4d = Inception(512, 112, 144, 288, 32, 64, 64, platform=self.platform)
self.block4e = Inception(528, 256, 160, 320, 32, 128, 128, platform=self.platform)
if self.platform == "Ascend":
self.maxpool1 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")
self.maxpool2 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")
self.block3a = Inception(192, 64, 96, 128, 16, 32, 32)
self.block3b = Inception(256, 128, 128, 192, 32, 96, 64)
self.maxpool3 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")
self.block4a = Inception(480, 192, 96, 208, 16, 48, 64)
self.block4b = Inception(512, 160, 112, 224, 24, 64, 64)
self.block4c = Inception(512, 128, 128, 256, 24, 64, 64)
self.block4d = Inception(512, 112, 144, 288, 32, 64, 64)
self.block4e = Inception(528, 256, 160, 320, 32, 128, 128)
self.maxpool4 = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="same")
else: # GPU
self.maxpool1 = P.MaxPool(ksize=3, strides=2, padding="same")
self.maxpool2 = P.MaxPool(ksize=3, strides=2, padding="same")
self.maxpool3 = P.MaxPool(ksize=3, strides=2, padding="same")
self.maxpool4 = P.MaxPool(ksize=2, strides=2, padding="same")
self.block5a = Inception(832, 256, 160, 320, 32, 128, 128, platform=self.platform)
self.block5b = Inception(832, 384, 192, 384, 48, 128, 128, platform=self.platform)
self.block5a = Inception(832, 256, 160, 320, 32, 128, 128)
self.block5b = Inception(832, 384, 192, 384, 48, 128, 128)
self.mean = P.ReduceMean(keep_dims=True)
self.dropout = nn.Dropout(keep_prob=0.8)
self.flatten = nn.Flatten()
self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
bias_init=weight_variable())
def construct(self, x):
'''
construct googlenet model
'''
x = self.conv1(x)
if self.platform == "Ascend":
x, argmax = self.maxpool1(x)
else: # GPU
x = self.maxpool1(x)
x = self.conv2(x)
x = self.conv3(x)
if self.platform == "Ascend":
x, argmax = self.maxpool2(x)
else: # GPU
x = self.maxpool2(x)
x = self.block3a(x)
x = self.block3b(x)
if self.platform == "Ascend":
x, argmax = self.maxpool3(x)
else: # GPU
x = self.maxpool3(x)
x = self.block4a(x)
x = self.block4b(x)
x = self.block4c(x)
x = self.block4d(x)
x = self.block4e(x)
if self.platform == "Ascend":
x, argmax = self.maxpool4(x)
x = self.block5a(x)
x = self.block5b(x)
x = self.mean(x, (2, 3))
x = self.flatten(x)
x = self.classifier(x)
_ = argmax
else: # GPU
x = self.maxpool4(x)
x = self.block5a(x)
x = self.block5b(x)
x = self.mean(x, (2, 3))
x = self.flatten(x)
x = self.classifier(x)
return x

View File

@ -93,7 +93,7 @@ if __name__ == '__main__':
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()
net = GoogleNet(num_classes=cfg.num_classes, platform=device_target)
net = GoogleNet(num_classes=cfg.num_classes)
# Continue training if set pre_trained to be True
if cfg.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path)

View File

@ -0,0 +1,147 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class Net_Pool(nn.Cell):
def __init__(self):
super(Net_Pool, self).__init__()
self.maxpool_fun = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="VALID")
def construct(self, x):
return self.maxpool_fun(x)
class Net_Pool2(nn.Cell):
def __init__(self):
super(Net_Pool2, self).__init__()
self.maxpool_fun = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="SAME")
def construct(self, x):
return self.maxpool_fun(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_maxpool_with_argmax_2d():
x = Tensor(np.array([[[
[0, 1, 2, 3, -4, -5],
[6, 7, 8, 9, -10, -11],
[12, 13, 14, -15, -16, -17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]
]]]).astype(np.float32))
expect_result = (np.array([[[
[7, 9, -4],
[19, 21, 23],
[31, 33, 35]
]]]))
expect_result2 = (np.array([[[
[14, 14, -4],
[26, 28, 29],
[32, 34, 35]
]]]))
expect_index_result = (np.array([[[
[7, 9, 4],
[19, 21, 23],
[31, 33, 35]
]]]))
expect__index_result2 = (np.array([[[
[14, 14, 4],
[26, 28, 29],
[32, 34, 35]
]]]))
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
maxpool2d = Net_Pool()
maxpool2d2 = Net_Pool2()
output2, index2 = maxpool2d2(x)
output, index = maxpool2d(x)
assert (output.asnumpy() == expect_result).all()
assert (output2.asnumpy() == expect_result2).all()
assert (index.asnumpy() == expect_index_result).all()
assert (index2.asnumpy() == expect__index_result2).all()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
maxpool2d = Net_Pool()
maxpool2d2 = Net_Pool2()
output2, index2 = maxpool2d2(x)
output, index = maxpool2d(x)
assert (output.asnumpy() == expect_result).all()
assert (output2.asnumpy() == expect_result2).all()
assert (index.asnumpy() == expect_index_result).all()
assert (index2.asnumpy() == expect__index_result2).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_maxpool_with_argmax_2d_fp16():
x = Tensor(np.array([[[
[0, 1, 2, 3, -4, -5],
[6, 7, 8, 9, -10, -11],
[12, 13, 14, -15, -16, -17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]
]]]).astype(np.float16))
expect_result = (np.array([[[
[7, 9, -4],
[19, 21, 23],
[31, 33, 35]
]]]))
expect_result2 = (np.array([[[
[14, 14, -4],
[26, 28, 29],
[32, 34, 35]
]]]))
expect_index_result = (np.array([[[
[7, 9, 4],
[19, 21, 23],
[31, 33, 35]
]]]))
expect__index_result2 = (np.array([[[
[14, 14, 4],
[26, 28, 29],
[32, 34, 35]
]]]))
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
maxpool2d = Net_Pool()
maxpool2d2 = Net_Pool2()
output2, index2 = maxpool2d2(x)
output, index = maxpool2d(x)
assert (output.asnumpy() == expect_result).all()
assert (output2.asnumpy() == expect_result2).all()
assert (index.asnumpy() == expect_index_result).all()
assert (index2.asnumpy() == expect__index_result2).all()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
maxpool2d = Net_Pool()
maxpool2d2 = Net_Pool2()
output2, index2 = maxpool2d2(x)
output, index = maxpool2d(x)
assert (output.asnumpy() == expect_result).all()
assert (output2.asnumpy() == expect_result2).all()
assert (index.asnumpy() == expect_index_result).all()
assert (index2.asnumpy() == expect__index_result2).all()

View File

@ -0,0 +1,115 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
class Net_Pool_Grad(nn.Cell):
def __init__(self):
super(Net_Pool_Grad, self).__init__()
self.maxpool_grad_fun = G.MaxPoolGradWithArgmax(padding="VALID", ksize=2, strides=2)
def construct(self, x, dy, index):
return self.maxpool_grad_fun(x, dy, index)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_maxpool2d_grad():
x = Tensor(np.array([[[
[0, 1, 2, 3, 4, 5],
[6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]
]]]).astype(np.float32))
dy = Tensor(np.array([[[
[0.7, 0.9, 0.11],
[0.19, 0.21, 0.23],
[0.31, 0.33, 0.35]
]]]).astype(np.float32))
index = Tensor(np.array([[[
[7, 9, 11],
[19, 21, 23],
[31, 33, 35]
]]]).astype(np.int32))
expect_result = (np.array([[[
[0., 0., 0., 0., 0., 0.],
[0., 0.7, 0., 0.9, 0., 0.11],
[0., 0., 0., 0., 0., 0.],
[0., 0.19, 0., 0.21, 0., 0.23],
[0., 0., 0., 0., 0., 0.],
[0., 0.31, 0., 0.33, 0., 0.35]
]]]))
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
maxpool2d_grad = Net_Pool_Grad()
output = maxpool2d_grad(x, dy, index)
assert np.allclose(expect_result, output.asnumpy())
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
maxpool2d_grad = Net_Pool_Grad()
output = maxpool2d_grad(x, dy, index)
assert np.allclose(expect_result, output.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_maxpool2d_grad_fp16():
x = Tensor(np.array([[[
[0, 1, 2, 3, 4, 5],
[6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]
]]]).astype(np.float16))
dy = Tensor(np.array([[[
[0.7, 0.9, 0.11],
[0.19, 0.21, 0.23],
[0.31, 0.33, 0.35]
]]]).astype(np.float16))
index = Tensor(np.array([[[
[7, 9, 11],
[19, 21, 23],
[31, 33, 35]
]]]).astype(np.int32))
expect_result = np.array([[[
[0., 0., 0., 0., 0., 0.],
[0., 0.7, 0., 0.9, 0., 0.11],
[0., 0., 0., 0., 0., 0.],
[0., 0.19, 0., 0.21, 0., 0.23],
[0., 0., 0., 0., 0., 0.],
[0., 0.31, 0., 0.33, 0., 0.35]
]]]).astype(np.float16)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
maxpool2d_grad = Net_Pool_Grad()
output = maxpool2d_grad(x, dy, index)
assert np.allclose(expect_result, output.asnumpy())
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
maxpool2d_grad = Net_Pool_Grad()
output = maxpool2d_grad(x, dy, index)
assert np.allclose(expect_result, output.asnumpy())