From 081249b53fb8dba39be021d5decb25d3cc4a38c4 Mon Sep 17 00:00:00 2001 From: danish Date: Wed, 5 Aug 2020 16:48:58 -0400 Subject: [PATCH] commit 1 - mirror pad commit 2 lint fix lint fix 2 updated backprop + st test test_file_fix test_file_fix_2 fixed header_guards comments addressed clangFormatFix --- .../gpu/cuda_impl/mirror_pad_impl.cu | 182 ++++++++++++++++++ .../gpu/cuda_impl/mirror_pad_impl.cuh | 31 +++ .../gpu/nn/mirror_pad_gpu_kernel.cc | 30 +++ .../gpu/nn/mirror_pad_gpu_kernel.h | 150 +++++++++++++++ .../gpu/nn/mirror_pad_grad_gpu_kernel.cc | 30 +++ .../gpu/nn/mirror_pad_grad_gpu_kernel.h | 150 +++++++++++++++ mindspore/ops/operations/nn_ops.py | 9 +- tests/st/ops/gpu/test_mirror_pad.py | 88 +++++++++ 8 files changed, 668 insertions(+), 2 deletions(-) create mode 100755 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cu create mode 100755 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_mirror_pad.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cu new file mode 100755 index 0000000000..c1117b85ff --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cu @@ -0,0 +1,182 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh" + +__inline__ __device__ bool range_check(int x, int y, int padded_width, int padded_height) { + // check for existence in current padded array + if (((x >= 0) && (x <= padded_width - 1)) && ((y >= 0) && (y <= padded_height - 1))) { + return true; + } + return false; +} + +template +__global__ void MirrorPad(const size_t size, const T *input, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int padd_dim, + const int *paddings, int mode, T *output) { + int padd_offset = 4 * (padd_dim - 2); + int pad_left_ = paddings[padd_offset + 4]; + int pad_top_ = paddings[padd_offset + 0]; + + // Create anchor points for old tensor positions inside new tensor + int ap1_x = pad_left_; + int ap1_y = pad_top_; + int ap2_x = pad_left_ + old_width - 1; + int ap2_y = pad_top_ + old_height - 1; + + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + int block_num = (pos / padded_width) / padded_height; + const int padded_x = pos % padded_width; + const int padded_y = (pos / padded_width) % padded_height; + + // distance to move from anchor point + int x_dist = 0; + int y_dist = 0; + + // x,y value to mirror in new tenspr + int matchval_x_index = padded_x; + int matchval_y_index = padded_y; + + if (padded_y - pad_top_ < 0 || padded_x - pad_left_ < 0 || padded_y - pad_top_ >= old_height || + padded_x - pad_left_ >= old_width) { + if ((padded_x < ap1_x) || (padded_x > ap2_x)) { + x_dist = (padded_x < ap1_x) ? (ap1_x - padded_x) : (padded_x - ap2_x); // GEN DIST + matchval_x_index = (padded_x < ap1_x) ? (ap1_x + x_dist - mode) : (ap2_x - x_dist + mode); + } + if ((padded_y < ap1_y) || (padded_y > ap2_y)) { + y_dist = (padded_y < ap1_y) ? (ap1_y - padded_y) : (padded_y - ap2_y); + matchval_y_index = (padded_y < ap1_y) ? (ap1_y + y_dist - mode) : (ap2_y - y_dist + mode); + } + output[pos] = + input[(block_num * old_height + matchval_y_index - pad_top_) * old_width + matchval_x_index - pad_left_]; + } else { + // existing values remain the same + output[pos] = input[(block_num * old_height + padded_y - pad_top_) * old_width + padded_x - pad_left_]; + } + } + return; +} + +template +__global__ void MirrorPadGrad(const size_t size, const T *dy, const int num, const int channels, + const int padded_height, const int padded_width, const int old_height, + const int old_width, const int padd_dim, const int *paddings, int mode, T *dx) { + int padd_offset = 4 * (padd_dim - 2); + int pad_left_ = paddings[padd_offset + 4]; + int pad_top_ = paddings[padd_offset + 0]; + + // Create anchor points for positions in the dy array + int ap1_x = pad_left_; + int ap1_y = pad_top_; + int ap2_x = pad_left_ + old_width - 1; + int ap2_y = pad_top_ + old_height - 1; + + int adjust = 0; // adjust dist from reflection axis for symmetric padding + if (mode == 1) { + adjust = 1; + } + + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + int block_num = (pos / old_width) / old_height; + + // refer to indices of original values inside padded array + const int padded_x = (pos % old_width) + pad_left_; + const int padded_y = ((pos / old_width) % old_height) + pad_top_; + + // copy positions own value into output + dx[pos] = dx[pos] + dy[(block_num * padded_height + padded_y) * padded_width + padded_x]; + + int x_dist_1 = (ap1_x - padded_x - adjust); + int y_dist_1 = (ap1_y - padded_y - adjust); + int x_dist_2 = (ap2_x - padded_x + adjust); + int y_dist_2 = (ap2_y - padded_y + adjust); + + int axis_dist[] = {x_dist_1, x_dist_2, y_dist_1, y_dist_2}; + int anch_point[] = {ap1_x, ap2_x, ap1_y, ap2_y}; + bool x_axis_check[] = {true, true, false, false}; // true - update X , false - update Y + + int temp_x = 0; + int temp_y = 0; + + // mirroring in axis lines + for (int x = 0; x < 4; x++) { + if (axis_dist[x] != 0) { + if (x_axis_check[x]) { + temp_y = padded_y; + temp_x = anch_point[x] + axis_dist[x]; + } else { + temp_x = padded_x; + temp_y = anch_point[x] + axis_dist[x]; + } + if (range_check(temp_x, temp_y, padded_width, padded_height)) { + dx[pos] = dx[pos] + dy[(block_num * padded_height + temp_y) * padded_width + temp_x]; + } + } + } + + // mirroring at corners + for (int x = 0; x < 2; x++) { + for (int y = 2; y < 4; y++) { + if ((axis_dist[x] != 0) && (axis_dist[y] != 0)) { + temp_x = anch_point[x] + axis_dist[x]; + temp_y = anch_point[y] + axis_dist[y]; + if (range_check(temp_x, temp_y, padded_width, padded_height)) { + dx[pos] = dx[pos] + dy[(block_num * padded_height + temp_y) * padded_width + temp_x]; + } + } + } + } + } + return; +} + +template +void CalMirrorPad(const size_t size, const T *input, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, int padd_num, + const int *paddings, const int mode, T *output, cudaStream_t cuda_stream) { + MirrorPad<<>>( + size, input, num, channels, old_height, old_width, padded_height, padded_width, padd_num, paddings, mode, output); + return; +} + +template +void CalMirrorPadGrad(const size_t size, const T *dy, const int num, const int channels, const int padded_height, + const int padded_width, const int old_height, const int old_width, const int padd_dim, + const int *paddings, int mode, T *dx, cudaStream_t cuda_stream) { + MirrorPadGrad<<>>(size, dy, num, channels, padded_height, padded_width, + old_height, old_width, padd_dim, paddings, mode, dx); + return; +} + +template void CalMirrorPad(const size_t size, const float *input, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, + const int padded_width, int padd_num, const int *paddings, int mode, float *output, + cudaStream_t cuda_stream); +template void CalMirrorPadGrad(const size_t size, const float *dy, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, + const int padded_width, const int padd_dim, const int *paddings, int mode, + float *dx, cudaStream_t cuda_stream); +template void CalMirrorPad(const size_t size, const half *input, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, + const int padded_width, int padd_num, const int *paddings, int mode, half *output, + cudaStream_t cuda_stream); +template void CalMirrorPadGrad(const size_t size, const half *dy, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, + const int padded_width, const int padd_dim, const int *paddings, int mode, + half *dx, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh new file mode 100755 index 0000000000..d2bf4dff11 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MIRROR_PAD_IMPL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MIRROR_PAD_IMPL_H_ +#include +#include "runtime/device/gpu/cuda_common.h" + +template +void CalMirrorPad(const size_t size, const T *input, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, int padd_num, + const int *paddings, int mode, T *output, cudaStream_t cuda_stream); +template +void CalMirrorPadGrad(const size_t size, const T *dy, const int num, const int channels, const int padded_height, + const int padded_width, const int old_height, const int old_width, const int padd_dim, + const int *paddings, int mode, T *dx, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MIRROR_PAD_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.cc new file mode 100644 index 0000000000..306a7b09ab --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + MirrorPad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + MirrorPadGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE( + MirrorPad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + MirrorPadGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.h new file mode 100644 index 0000000000..c7c4aefb76 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_gpu_kernel.h @@ -0,0 +1,150 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MIRROR_PAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MIRROR_PAD_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class MirrorPadGpuFwdKernel : public GpuKernel { + public: + MirrorPadGpuFwdKernel() + : num_input_(0), num_paddings_(0), mode_(0), input_size_(1), output_size_(1), workspace_size_(0) {} + ~MirrorPadGpuFwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + int *paddings = GetDeviceAddress(inputs, 1); + T *output = GetDeviceAddress(outputs, 0); + + size_t size = output_size_ / sizeof(T); + int dim_offset = output_shape_.size() - 2; + + CalMirrorPad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], + output_shape_[dim_offset + 0], output_shape_[dim_offset + 1], num_paddings_, paddings, mode_, output, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but MirrorPad needs 2 input."; + return false; + } + // check number of output -> should be 1 + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but Pad needs 1 output."; + return false; + } + + string mode = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("mode")); + + if (mode == "REFLECT") { + mode_ = 0; // reflected mirroring + } else { + mode_ = 1; // symmetric mirroring + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + // shape adjustement -> from 2d/3d to 4d to standardize + if (input_shape.size() == 4) { + } else if (input_shape.size() == 3) { + auto it = input_shape.begin(); + input_shape.insert(it, 1); // batch padding + } else if (input_shape.size() == 2) { + auto it = input_shape.begin(); + input_shape.insert(it, 2, 1); // channel padding + } + + for (auto in_shape : input_shape) { + input_size_ *= in_shape; + input_shape_.push_back(in_shape); + } + num_input_ = input_size_; + input_size_ *= sizeof(T); + + auto padding_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + num_paddings_ = padding_shape[0]; + input_size_ += 2 * num_paddings_ * sizeof(int); + + output_size_ = sizeof(T); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (auto x : output_shape) { + output_size_ *= x; + output_shape_.push_back(x); + } + + int max_width = input_shape_[3]; + int max_height = input_shape_[2]; + + // basic error check for padding value + if (mode_ == 1) { // symmetric + max_width = max_width + (2 * max_width); + max_height = max_height + (2 * max_height); + } else { // reflect + max_width = max_width + (2 * (max_width - 1)); + max_height = max_height + (2 * (max_height - 1)); + } + + if (output_shape_[(output_shape_.size() - 2) + 0] > max_width || + output_shape_[(output_shape_.size() - 2) + 1] > max_width) { + MS_LOG(ERROR) << "ERROR: Padding value too high for input Tensor on 1 or more dims"; + return false; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(num_input_ * sizeof(T)); + input_size_list_.push_back(2 * num_paddings_ * sizeof(int)); + output_size_list_.push_back(output_size_); + } + + private: + size_t num_input_; + int num_paddings_; + int mode_; + std::vector input_shape_; // dims of the input data + std::vector output_shape_; // dims of the output data + // default + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MIRROR_PAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.cc new file mode 100644 index 0000000000..599e51272b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + MirrorPadGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + MirrorPadGpuBackKernel, float) +MS_REG_GPU_KERNEL_ONE( + MirrorPadGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + MirrorPadGpuBackKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.h new file mode 100644 index 0000000000..2f793aba77 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/mirror_pad_grad_gpu_kernel.h @@ -0,0 +1,150 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MIRROR_PAD_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MIRROR_PAD_GRAD_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/mirror_pad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class MirrorPadGpuBackKernel : public GpuKernel { + public: + MirrorPadGpuBackKernel() + : num_input_(0), num_paddings_(0), mode_(0), input_size_(1), output_size_(1), workspace_size_(0) {} + ~MirrorPadGpuBackKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + int *paddings = GetDeviceAddress(inputs, 1); + T *output = GetDeviceAddress(outputs, 0); + + size_t size = output_size_ / sizeof(T); + int dim_offset = output_shape_.size() - 2; + + CalMirrorPadGrad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], + output_shape_[dim_offset + 0], output_shape_[dim_offset + 1], num_paddings_, paddings, mode_, + output, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but MirrorPadGrad needs 2 input."; + return false; + } + // check number of output -> should be 1 + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but MirrorPadGrad needs 1 output."; + return false; + } + + string mode = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("mode")); + + if (mode == "REFLECT") { + mode_ = 0; // reflected mirroring + } else { + mode_ = 1; // symmetric mirroring + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + // shape adjustement -> from 2d/3d to 4d to standardize + if (input_shape.size() == 4) { + } else if (input_shape.size() == 3) { + auto it = input_shape.begin(); + input_shape.insert(it, 1); // batch padding + } else if (input_shape.size() == 2) { + auto it = input_shape.begin(); + input_shape.insert(it, 2, 1); // channel padding + } + + for (auto in_shape : input_shape) { + input_size_ *= in_shape; + input_shape_.push_back(in_shape); + } + num_input_ = input_size_; + input_size_ *= sizeof(T); + + auto padding_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + num_paddings_ = padding_shape[0]; + input_size_ += +(2 * num_paddings_ * sizeof(int)); + + output_size_ = sizeof(T); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (auto x : output_shape) { + output_size_ *= x; + output_shape_.push_back(x); + } + + int max_width = input_shape_[3]; + int max_height = input_shape_[2]; + + // basic error check for padding value + if (mode_ == 1) { // symmetric + max_width = max_width + (2 * max_width); + max_height = max_height + (2 * max_height); + } else { // reflect + max_width = max_width + (2 * (max_width - 1)); + max_height = max_height + (2 * (max_height - 1)); + } + + if (output_shape_[(output_shape_.size() - 2) + 0] > max_width || + output_shape_[(output_shape_.size() - 2) + 1] > max_width) { + MS_LOG(ERROR) << "ERROR: Padding value too high for input Tensor on 1 or more DIMS"; + return false; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(num_input_ * sizeof(T)); + input_size_list_.push_back(2 * num_paddings_ * sizeof(int)); + output_size_list_.push_back(output_size_); + } + + private: + size_t num_input_; + int num_paddings_; + int mode_; + std::vector input_shape_; // dims of the input data + std::vector output_shape_; // dims of the output data + // default + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MIRROR_PAD_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 8ca6521e81..ac4941fa71 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2757,12 +2757,17 @@ class MirrorPad(PrimitiveWithInfer): paddings_value = paddings['value'].asnumpy() paddings_size = paddings_value.size validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ, self.name) - if not np.all(paddings_size >= 0): + if not np.all(paddings_value >= 0): raise ValueError('All elements of paddings must be >= 0.') + adjust = 0 + if self.mode == 'SYMMETRIC': + adjust = 1 + for i in range(0, int(paddings_size / 2)): + if (paddings_value[i, 0] >= x_shape[i] + adjust) or (paddings_value[i, 1] >= x_shape[i] + adjust): + raise ValueError('At least one dim has too high a padding value for this input and mode') y_shape = () for i in range(0, int(paddings_size / 2)): y_shape += ((x_shape[i] + paddings_value[i, 0] + paddings_value[i, 1]),) - return {'shape': y_shape, 'dtype': input_x['dtype'], 'value': None} diff --git a/tests/st/ops/gpu/test_mirror_pad.py b/tests/st/ops/gpu/test_mirror_pad.py new file mode 100644 index 0000000000..9e6613d744 --- /dev/null +++ b/tests/st/ops/gpu/test_mirror_pad.py @@ -0,0 +1,88 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +import numpy as np + +import mindspore +import mindspore.nn as nn +import mindspore.context as context + +from mindspore import Tensor +from mindspore.ops.composite import GradOperation + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_mirror_pad(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + test1_arr_in = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]] + test_1_paddings = ((0, 0), (0, 0), (1, 1), (2, 2)) + test1_arr_exp = [[[[6, 5, 4, 5, 6, 5, 4], [3, 2, 1, 2, 3, 2, 1], [6, 5, 4, 5, 6, 5, 4], + [9, 8, 7, 8, 9, 8, 7], [6, 5, 4, 5, 6, 5, 4]]]] + + test2_arr_in = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]] + test_2_paddings = ((0, 0), (0, 0), (1, 1), (2, 2)) + test2_arr_exp = [[[[2, 1, 1, 2, 3, 3, 2], [2, 1, 1, 2, 3, 3, 2], [5, 4, 4, 5, 6, 6, 5], + [8, 7, 7, 8, 9, 9, 8], [8, 7, 7, 8, 9, 9, 8]]]] + + reflectOp = nn.Pad(mode='REFLECT', paddings=test_1_paddings) + symmOp = nn.Pad(mode='SYMMETRIC', paddings=test_2_paddings) + + x_test_1 = Tensor(np.array(test1_arr_in), dtype=mindspore.float32) + x_test_2 = Tensor(np.array(test2_arr_in), dtype=mindspore.float32) + + y_test_1 = reflectOp(x_test_1).asnumpy() + y_test_2 = symmOp(x_test_2).asnumpy() + + print(np.array(test1_arr_in)) + print(y_test_1) + + np.testing.assert_equal(np.array(test1_arr_exp), y_test_1) + np.testing.assert_equal(np.array(test2_arr_exp), y_test_2) + + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) + self.network = network + def construct(self, input_, output_grad): + return self.grad(self.network)(input_, output_grad) + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.pad = nn.Pad(mode="REFLECT", paddings=((0, 0), (0, 0), (1, 0), (0, 2))) + def construct(self, x): + return self.pad(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_mirror_pad_backprop(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_arr_in = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]] # size -> 3*3 + test_arr_in = Tensor(test_arr_in, dtype=mindspore.float32) + dy = (np.ones((1, 1, 4, 5)) * 0.1).astype(np.float32) + expected_dx = np.array([[[[0.2, 0.2, 0.1], + [0.4, 0.4, 0.2], + [0.2, 0.2, 0.1]]]]) + net = Grad(Net()) + dx = net(test_arr_in, Tensor(dy)) + dx = dx[0].asnumpy() + np.testing.assert_array_almost_equal(dx, expected_dx)