From 4fce4c7c34e4eb7f1656fb0683404ab917037663 Mon Sep 17 00:00:00 2001 From: mamba_ni Date: Mon, 3 Aug 2020 20:11:59 +0800 Subject: [PATCH] support img2col for resnet50_thor GPU primitive for im2col fix bug clang code format clang format fix fix pylint fix license delete useless code --- .../gpu/cuda_impl/identity_impl.cu | 40 +++ .../gpu/cuda_impl/identity_impl.cuh | 24 ++ .../gpu/cuda_impl/matrix_combine_impl.cu | 72 +++++ .../gpu/cuda_impl/matrix_combine_impl.cuh | 27 ++ .../gpu/cuda_impl/matrix_split_impl.cu | 70 +++++ .../gpu/cuda_impl/matrix_split_impl.cuh | 25 ++ .../gpu/nn/im2col_gpu_kernel.cc | 26 ++ .../gpu/nn/im2col_gpu_kernel.h | 269 ++++++++++++++++++ mindspore/ops/operations/__init__.py | 5 +- mindspore/ops/operations/thor_ops.py | 118 +++++++- 10 files changed, 674 insertions(+), 2 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cu new file mode 100644 index 00000000000..ecb44c45acf --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cu @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "identity_impl.cuh"
+#include <iostream>
+template <typename T>
+__global__ void IdentityKernel(const size_t size, const size_t dim, T *output_addr) {
+  for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
+    size_t batchIdx = pointIdx / (dim * dim);
+    size_t dst_x = (pointIdx - batchIdx * dim * dim) / dim;
+    size_t dst_y = (pointIdx - batchIdx * dim * dim) % dim;
+    if (dst_x == dst_y) {
+      output_addr[pointIdx] = 1;
+    } else {
+      output_addr[pointIdx] = 0;
+    }
+  }
+}
+
+template <typename T>
+void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream) {
+  IdentityKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dim, output_addr);
+  return;
+}
+
+template void Identity<float>(const size_t size, const size_t dim, float *output_addr, cudaStream_t cuda_stream);
+
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh
new file mode 100644
index 00000000000..b8fd4a0be3f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cu
new file mode 100644
index 00000000000..b1bd5fdb695
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cu
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "matrix_combine_impl.cuh"
+#include <iostream>
+template <typename T>
+__global__ void MatrixCombineKernel(const size_t size, const size_t src_height, const size_t src_width,
+                                    const size_t dst_width, T *input_addr, T *output_addr) {
+  for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
+    size_t batchIdx = pointIdx / (src_height * src_width);
+    size_t src_h = (pointIdx - batchIdx * src_height * src_width) / src_width;
+    size_t src_w = (pointIdx - batchIdx * src_height * src_width) % src_width;
+    size_t dst_h = src_height * batchIdx + src_h;
+    size_t dst_w = src_width * batchIdx + src_w;
+    output_addr[dst_h * dst_width + dst_w] = input_addr[pointIdx];
+  }
+}
+
+template <typename T>
+__global__ void MatrixCombineKernel(const size_t size, const size_t src_height, const size_t src_width,
+                                    const size_t dst_width, const size_t res_width, const size_t batch, T *input_addr,
+                                    T *output_addr) {
+  for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
+    size_t batchIdx = pointIdx / (src_height * src_width);
+    if (batchIdx != (batch - 1)) {
+      size_t src_h = (pointIdx - batchIdx * src_height * src_width) / src_width;
+      size_t src_w = (pointIdx - batchIdx * src_height * src_width) % src_width;
+      size_t dst_h = src_height * batchIdx + src_h;
+      size_t dst_w = src_width * batchIdx + src_w;
+      output_addr[dst_h * dst_width + dst_w] = input_addr[pointIdx];
+    } else {
+      size_t src_h = (pointIdx - (batch - 1) * src_height * src_width) / res_width;
+      size_t src_w = (pointIdx - (batch - 1) * src_height * src_width) % res_width;
+      size_t src_coordinate = (batch - 1) * src_height * src_width + src_h * src_width + src_w;
+      size_t dst_h = src_height * (batch - 1) + src_h;
+      size_t dst_w = src_width * (batch - 1) + src_w;
+      output_addr[dst_h * dst_width + dst_w] = input_addr[src_coordinate];
+    }
+  }
+}
+
+template <typename T>
+void MatrixCombine(const size_t size, const size_t src_height, const size_t src_width, const size_t dst_width,
+                   const size_t residual, const size_t res_width, const size_t batch, T *input_addr, T *output_addr,
+                   cudaStream_t cuda_stream) {
+  if (residual == 0) {
+    MatrixCombineKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, src_height, src_width, dst_width,
+                                                                           input_addr, output_addr);
+  } else {
+    MatrixCombineKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, src_height, src_width, dst_width,
+                                                                           res_width, batch, input_addr, output_addr);
+  }
+  return;
+}
+
+template void MatrixCombine<float>(const size_t size, const size_t src_height, const size_t src_width,
+                                   const size_t dst_width, const size_t residual, const size_t res_width,
+                                   const size_t batch, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
+
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cuh
new file mode 100644
index 00000000000..737ec133834
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_combine_impl.cuh
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void MatrixCombine(const size_t size, const size_t src_height, const size_t src_width, const size_t dst_width,
+                   const size_t residual, const size_t res_width, const size_t batch, T *input_addr, T *output_addr,
+                   cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_
+
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cu
new file mode 100644
index 00000000000..15013377fe5
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cu
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "matrix_split_impl.cuh"
+#include <iostream>
+template <typename T>
+__global__ void MatrixSplitKernel(const size_t size, const size_t split_dim, const size_t dim, T *input_addr,
+                                  T *output_addr) {
+  for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
+    size_t batchIdx = pointIdx / (split_dim * split_dim);
+    size_t dst_x = (pointIdx - batchIdx * split_dim * split_dim) / split_dim;
+    size_t dst_y = (pointIdx - batchIdx * split_dim * split_dim) % split_dim;
+    size_t src_coordinate = (batchIdx * split_dim + dst_x) * dim + batchIdx * split_dim + dst_y;
+    output_addr[pointIdx] = input_addr[src_coordinate];
+  }
+}
+
+template <typename T>
+__global__ void MatrixSplitKernel(const size_t size, const size_t split_dim, const size_t dim, const size_t res_dim,
+                                  T *input_addr, T *output_addr) {
+  for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
+    size_t batchIdx = pointIdx / (split_dim * split_dim);
+    size_t dst_x = (pointIdx - batchIdx * split_dim * split_dim) / split_dim;
+    size_t dst_y = (pointIdx - batchIdx * split_dim * split_dim) % split_dim;
+    size_t src_coordinate = (batchIdx * split_dim + dst_x) * dim + batchIdx * split_dim + dst_y;
+    size_t batch_lower = dim / split_dim;
+    if (batchIdx < batch_lower) {
+      output_addr[pointIdx] = input_addr[src_coordinate];
+    } else {
+      if (dst_x < res_dim && dst_y < res_dim) {
+        output_addr[pointIdx] = input_addr[src_coordinate];
+      } else if (dst_x == dst_y) {
+        output_addr[pointIdx] = 1;
+      } else {
+        output_addr[pointIdx] = 0;
+      }
+    }
+  }
+}
+
+template <typename T>
+void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, T *input_addr, T *output_addr,
+                 cudaStream_t cuda_stream) {
+  size_t batch = dim / split_dim;
+  size_t res_dim = dim - batch * split_dim;
+  if (res_dim == 0) {
+    MatrixSplitKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, split_dim, dim, input_addr, output_addr);
+  } else {
+    MatrixSplitKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, split_dim, dim, res_dim, input_addr,
+                                                                         output_addr);
+  }
+  return;
+}
+
+template void MatrixSplit<float>(const size_t size, const size_t split_dim, const size_t dim, float *input_addr,
+                                 float *output_addr, cudaStream_t cuda_stream);
+
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh
new file mode 100644
index 00000000000..edae55c14da
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, T *input_addr, T *output_addr,
+                 cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.cc
new file mode 100644
index 00000000000..b7a71308cea
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Im2Col, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      Im2ColGpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(Im2Col, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      Im2ColGpuFwdKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h
new file mode 100644
index 00000000000..53a711775ab
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h
@@ -0,0 +1,269 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_IM2COLGPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_IM2COLGPUKERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class Im2ColGpuFwdKernel : public GpuKernel { + public: + Im2ColGpuFwdKernel() + : cudnn_handle_(nullptr), + input_desc_(nullptr), + output_desc_(nullptr), + filter_desc_(nullptr), + conv_desc_(nullptr), + padded_desc_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + old_height_(0), + old_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + n_(0), + c_(0), + is_null_input_(false), + input_size_(0), + output_size_(0), + padded_size_(0), + workspace_size_(0), + use_pad_(true) {} + ~Im2ColGpuFwdKernel() override { DestroyResource(); } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { + T *padded_addr = GetDeviceAddress(workspace, 0); + CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, + reinterpret_cast(stream_ptr)); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnIm2Col(cudnn_handle_, padded_desc_, padded_addr, filter_desc_, conv_desc_, output_addr), + "cudnnIm2ColForward failed"); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnIm2Col(cudnn_handle_, input_desc_, input_addr, filter_desc_, conv_desc_, output_addr), + "cudnnIm2ColForward failed"); + } + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto filter_shape = GetAttr>(kernel_node, "kernel_size"); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(in_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "cudnnIm2ColForward input is null."; + InitSizeLists(); + return true; + } + Set4DDesc(in_shape, filter_shape, output_shape); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, 1), "cudnnSetConvGroupCount failed"); + pad_height_ = GetAttr(kernel_node, "pad"); + pad_width_ = pad_height_; + pad_mode_ = GetAttr(kernel_node, "pad_mode"); + SetStrideAndDilation(kernel_node); + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == 
kSamePadModeLowerCase) { + SetPad(in_shape, kernel_node); + } else { + if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { + pad_height_ = 0; + pad_width_ = 0; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[2], stride_[3], dilation_[2], + dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), + "cudnnSetConvolutionMathType failed.") + } + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&filter_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), + "cudnnCreateConvolutionDescriptor failed"); + } + + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_desc_, reinterpret_cast(&input_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(output_desc_, reinterpret_cast(&output_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_desc_, reinterpret_cast(&padded_size_)), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { + workspace_size_list_.push_back(padded_size_); + } + return; + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), + "cudnnDestroyConvolutionDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(filter_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed"); + } + bool CheckParam(const CNodePtr &kernel_node) { + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but Im2Col needs 1 inputs."; + return false; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but Im2Col needs 1 output."; + return false; + } + return true; + } + void SetPad(const std::vector &in_shape, const CNodePtr &kernel_node) { + auto pad_list = GetAttr>(kernel_node, "pad_list"); + + n_ = SizeToInt(in_shape[0]); + c_ = SizeToInt(in_shape[1]); + old_height_ = SizeToInt(in_shape[2]); + 
old_width_ = SizeToInt(in_shape[3]); + pad_height_ = pad_list[0] + pad_list[1]; + pad_width_ = pad_list[2] + pad_list[3]; + pad_top_ = pad_list[0]; + pad_left_ = pad_list[2]; + + // if use_pad_ == true, using zero padding in advance, else using the default cudnn pad. + if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { + use_pad_ = false; + } + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_, + old_height_ + pad_height_, old_width_ + pad_width_), + "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( + conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[2], stride_[3], + dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + } + + void Set4DDesc(const std::vector &in_shape, const std::vector &filter_shape, + const std::vector &output_shape) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]), + SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])), + "cudnnSetTensor4dDescriptor failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilter4dDescriptor(filter_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 1, + SizeToInt(in_shape[1]), filter_shape[0], filter_shape[1]), + "cudnnSetFilter4dDescriptor failed"); + + auto out_H = output_shape[0] * output_shape[1] * output_shape[2]; + auto out_W = output_shape[3] * output_shape[4] * output_shape[5]; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, + SizeToInt(out_H), SizeToInt(out_W), 1, 1), + "cudnnSetTensor4dDescriptor failed"); + } + + void SetStrideAndDilation(const CNodePtr &kernel_node) { + stride_ = AnfAlgo::GetNodeAttr>(kernel_node, "stride"); + dilation_ = AnfAlgo::GetNodeAttr>(kernel_node, "dilation"); + if (stride_.size() != 4) { + MS_LOG(EXCEPTION) << "Im2Col's stride must be 4d!"; + } + if (stride_[0] != 1 || stride_[1] != 1) { + MS_LOG(EXCEPTION) << "Im2Col's stride only support 1 in N axis and C axis!"; + } + if (dilation_.size() != 4) { + MS_LOG(EXCEPTION) << "Im2Col's dilation must be 4d!"; + } + if (dilation_[0] != 1 || dilation_[1] != 1) { + MS_LOG(EXCEPTION) << "Im2Col's dilation only support 1 in N axis and C axis!"; + } + } + + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t input_desc_; + cudnnTensorDescriptor_t output_desc_; + cudnnFilterDescriptor_t filter_desc_; + cudnnConvolutionFwdAlgo_t conv_algorithm_; + cudnnConvolutionDescriptor_t conv_desc_; + cudnnTensorDescriptor_t padded_desc_; + std::string pad_mode_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + const float pad_value_ = 0.0; + cudnnDataType_t cudnn_data_type_; + int old_height_; + int old_width_; + int pad_height_; + int pad_width_; + int pad_top_; + int pad_left_; + int n_; + int c_; + std::vector stride_; + std::vector dilation_; + bool is_null_input_; + size_t input_size_; + size_t output_size_; + size_t padded_size_; + size_t workspace_size_; + bool use_pad_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_IM2COLGPUKERNEL_H_ diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 49d4ad1f080..2aeab197285 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -83,7 +83,10 @@ from . 
import _quant_ops from ._quant_ops import * from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull) -from .thor_ops import * +from .thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, + CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, + CusMatMulCubeDenseRight, + CusMatMulCubeFraczLeftCast, Im2Col) from .sparse_ops import SparseToDense __all__ = [ diff --git a/mindspore/ops/operations/thor_ops.py b/mindspore/ops/operations/thor_ops.py index d2de0190a63..cc02202e2bb 100644 --- a/mindspore/ops/operations/thor_ops.py +++ b/mindspore/ops/operations/thor_ops.py @@ -13,9 +13,12 @@ # limitations under the License. # ============================================================================ """thor_ops""" +import math + from ..primitive import prim_attr_register, PrimitiveWithInfer from ...common import dtype as mstype - +from ..._checkparam import Validator as validator +from ..._checkparam import Rel __all__ = ["CusBatchMatMul", "CusCholeskyTrsm", @@ -31,6 +34,37 @@ __all__ = ["CusBatchMatMul", ] +def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False): + """ + Checks whether an argument is a positive int or tuple with 2 or 4(when allow_four is True) positive int elements. + """ + + def _raise_message(): + raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two " + f"{'or four ' if allow_four else ''}positive int numbers, but got {arg_value}") + + def _get_return_value(): + if isinstance(arg_value, int): + ret = (1, 1, arg_value, arg_value) if ret_four else (arg_value, arg_value) + elif len(arg_value) == 2: + ret = (1, 1, arg_value[0], arg_value[1]) if ret_four else arg_value + elif len(arg_value) == 4: + if not allow_four: + _raise_message() + ret = arg_value if ret_four else (arg_value[2], arg_value[3]) + else: + _raise_message() + return ret + + validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) + ret_value = _get_return_value() + for item in ret_value: + if isinstance(item, int) and item > 0: + continue + _raise_message() + return ret_value + + class CusBatchMatMul(PrimitiveWithInfer): """ Multiplies matrix `a` by matrix `b` in batch. @@ -360,6 +394,7 @@ class CusTranspose02314(PrimitiveWithInfer): """init CusTranspose02314""" self.init_prim_io_names(inputs=['x1'], outputs=['y']) from mindspore.ops._op_impl._custom_op.transpose02314_impl import CusTranspose02314 + def get_bprop(self): def bprop(x, out, dout): return (C.zeros_like(x),) @@ -446,3 +481,84 @@ class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer): def infer_dtype(self, data1_dtype, data2_dtype): return mstype.float16 + + +class Im2Col(PrimitiveWithInfer): + """ + extract image pathes from image. + + The rank of input_x1 must be `4`, data_format is "NCHW". + + Inputs: + - **input_x1** (Tensor) - The feature map. + The shape of the tensor is :math:`(N, C, H, W)`. + Outputs: + Tensor. 
+    Examples:
+        >>> input_x = Tensor(np.random.rand(32, 3, 224, 224).astype(np.float16))
+        >>> img2col = P.Im2Col(kernel_size=7, pad_mode="pad", pad=3, stride=2)
+        >>> output = img2col(input_x)
+    """
+    @prim_attr_register
+    def __init__(self,
+                 kernel_size,
+                 pad_mode="valid",
+                 pad=0,
+                 stride=1,
+                 dilation=1):
+        """init Im2Col"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+        self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
+        self.add_prim_attr('kernel_size', self.kernel_size)
+        self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
+        self.add_prim_attr('stride', self.stride)
+        self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
+        self.add_prim_attr('dilation', self.dilation)
+        validator.check_value_type('pad', pad, (int,), self.name)
+        self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
+        self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
+        if self.pad_mode == 'pad':
+            validator.check_integer('pad', self.pad, 0, Rel.GE, self.name)
+        self.add_prim_attr('data_format', "NCHW")
+
+    def infer_shape(self, x_shape):
+        validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
+        kernel_size_h = self.kernel_size[0]
+        kernel_size_w = self.kernel_size[1]
+        stride_h = self.stride[2]
+        stride_w = self.stride[3]
+        dilation_h = self.dilation[2]
+        dilation_w = self.dilation[3]
+        if self.pad_mode == "valid":
+            h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
+            w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
+            pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
+        elif self.pad_mode == "same":
+            h_out = math.ceil(x_shape[2] / stride_h)
+            w_out = math.ceil(x_shape[3] / stride_w)
+            pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
+            pad_top = math.floor(pad_needed_h / 2)
+            pad_bottom = pad_needed_h - pad_top
+            pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3])
+            pad_left = math.floor(pad_needed_w / 2)
+            pad_right = pad_needed_w - pad_left
+        elif self.pad_mode == 'pad':
+            pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad
+            h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h
+            w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w
+            h_out = math.floor(h_out)
+            w_out = math.floor(w_out)
+        self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
+        self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
+        batch_size = x_shape[0]
+        channel = x_shape[1]
+        k_h = kernel_size_h
+        k_w = kernel_size_w
+        out_shape = [channel, k_h, k_w, batch_size, h_out, w_out]
+        return out_shape

+    def infer_dtype(self, x_dtype):
+        args = {'x': x_dtype}
+        valid_types = [mstype.float16, mstype.float32]
+        validator.check_tensor_type_same(args, valid_types, self.name)
+        return x_dtype
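
The Identity, MatrixSplit and MatrixCombine kernels above implement the block-diagonal handling used by THOR: Identity fills a batch of dim x dim identity matrices, MatrixSplit cuts a dim x dim matrix into split_dim x split_dim diagonal blocks (padding the trailing residual block with an identity outside its res_dim x res_dim corner), and MatrixCombine writes the blocks back onto the diagonal. The host-side NumPy sketch below is illustrative only (the helper names matrix_split/matrix_combine are not part of the patch) and mirrors the net effect of those kernels for square blocks:

    import numpy as np

    def matrix_split(x, split_dim):
        """Mirror of MatrixSplitKernel: dim x dim -> ceil(dim / split_dim) diagonal blocks."""
        dim = x.shape[0]
        batch = (dim + split_dim - 1) // split_dim
        blocks = np.zeros((batch, split_dim, split_dim), dtype=x.dtype)
        for b in range(batch):
            blk = x[b * split_dim:(b + 1) * split_dim, b * split_dim:(b + 1) * split_dim]
            n = blk.shape[0]
            blocks[b, :n, :n] = blk
            for i in range(n, split_dim):      # residual block: identity padding outside the corner
                blocks[b, i, i] = 1
        return blocks

    def matrix_combine(blocks, dim):
        """Mirror of MatrixCombineKernel: place the blocks back on the diagonal of a dim x dim matrix."""
        split_dim = blocks.shape[1]
        out = np.zeros((dim, dim), dtype=blocks.dtype)
        for b in range(blocks.shape[0]):
            n = min(split_dim, dim - b * split_dim)
            out[b * split_dim:b * split_dim + n, b * split_dim:b * split_dim + n] = blocks[b, :n, :n]
        return out

    # Example: a 73 x 73 matrix with split_dim=16 yields 5 blocks of 16 x 16; the last block
    # keeps its 9 x 9 corner and is identity elsewhere.
    a = np.random.rand(73, 73).astype(np.float32)
    assert matrix_combine(matrix_split(a, 16), 73).shape == (73, 73)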
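In thor_ops.py, _check_positive_int_or_tuple normalizes Im2Col's scalar attributes into the 2- and 4-element tuples that the GPU kernel later reads via stride_[2]/stride_[3] and dilation_[2]/dilation_[3]. A quick sketch of that normalization, assuming the module path introduced by this patch:

    from mindspore.ops.operations.thor_ops import _check_positive_int_or_tuple

    # kernel_size keeps only (k_h, k_w); stride and dilation are padded to NCHW form (1, 1, h, w).
    assert _check_positive_int_or_tuple('kernel_size', 7, 'Im2Col') == (7, 7)
    assert _check_positive_int_or_tuple('stride', 2, 'Im2Col', allow_four=True, ret_four=True) == (1, 1, 2, 2)
    assert _check_positive_int_or_tuple('dilation', 1, 'Im2Col', allow_four=True, ret_four=True) == (1, 1, 1, 1)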
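For the ResNet-50/THOR case this patch targets (the 7x7, stride-2, pad-3 stem convolution), Im2Col.infer_shape returns [C, k_h, k_w, N, h_out, w_out]. The usage sketch below is an assumption-laden illustration, not shipped code: it presumes a MindSpore build that includes this patch, a GPU device target, and that the primitive can be invoked directly in PyNative mode.

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    # Assumption: GPU backend with the Im2Col kernel registered by this patch.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    im2col = P.Im2Col(kernel_size=7, pad_mode="pad", pad=3, stride=2)
    x = Tensor(np.random.rand(32, 3, 224, 224).astype(np.float32))
    patches = im2col(x)

    # pad_mode="pad": h_out = floor(1 + (224 + 2*3 - 7) / 2) = 112, likewise for w_out,
    # so the expected shape is (3, 7, 7, 32, 112, 112). The cuDNN output descriptor views
    # this as a (C*k_h*k_w) x (N*h_out*w_out) = 147 x 401408 matrix.
    print(patches.shape)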