forked from mindspore-Ecosystem/mindspore
!3924 support Im2Col for resnet50 thor GPU
Merge pull request !3924 from mamba_ni/master
commit 677c193e96
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "identity_impl.cuh"
#include <iostream>
template <typename T>
__global__ void IdentityKernel(const size_t size, const size_t dim, T *output_addr) {
  for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
    size_t batchIdx = pointIdx / (dim * dim);
    size_t dst_x = (pointIdx - batchIdx * dim * dim) / dim;
    size_t dst_y = (pointIdx - batchIdx * dim * dim) % dim;
    if (dst_x == dst_y) {
      output_addr[pointIdx] = 1;
    } else {
      output_addr[pointIdx] = 0;
    }
  }
}

template <typename T>
void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream) {
  IdentityKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dim, output_addr);
  return;
}

template void Identity<float>(const size_t size, const size_t dim, float *output_addr, cudaStream_t cuda_stream);
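For reference, a minimal NumPy sketch of what IdentityKernel computes (illustrative only, not part of the commit; the helper name is made up): the flat output buffer of length batch * dim * dim is filled with a batch of dim x dim identity matrices.

import numpy as np

def identity_reference(batch, dim, dtype=np.float32):
    # Each dim x dim block gets 1 on its diagonal and 0 elsewhere,
    # mirroring the dst_x == dst_y test in IdentityKernel.
    out = np.zeros((batch, dim, dim), dtype=dtype)
    idx = np.arange(dim)
    out[:, idx, idx] = 1
    return out.reshape(-1)  # flat layout, indexed like output_addr[pointIdx]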
@@ -0,0 +1,24 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_

#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_
@@ -0,0 +1,72 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "matrix_combine_impl.cuh"
#include <iostream>
template <typename T>
__global__ void MatrixCombineKernel(const size_t size, const size_t src_height, const size_t src_width,
                                    const size_t dst_width, T *input_addr, T *output_addr) {
  for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
    size_t batchIdx = pointIdx / (src_height * src_width);
    size_t src_h = (pointIdx - batchIdx * src_height * src_width) / src_width;
    size_t src_w = (pointIdx - batchIdx * src_height * src_width) % src_width;
    size_t dst_h = src_height * batchIdx + src_h;
    size_t dst_w = src_width * batchIdx + src_w;
    output_addr[dst_h * dst_width + dst_w] = input_addr[pointIdx];
  }
}

template <typename T>
__global__ void MatrixCombineKernel(const size_t size, const size_t src_height, const size_t src_width,
                                    const size_t dst_width, const size_t res_width, const size_t batch, T *input_addr,
                                    T *output_addr) {
  for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
    size_t batchIdx = pointIdx / (src_height * src_width);
    if (batchIdx != (batch - 1)) {
      size_t src_h = (pointIdx - batchIdx * src_height * src_width) / src_width;
      size_t src_w = (pointIdx - batchIdx * src_height * src_width) % src_width;
      size_t dst_h = src_height * batchIdx + src_h;
      size_t dst_w = src_width * batchIdx + src_w;
      output_addr[dst_h * dst_width + dst_w] = input_addr[pointIdx];
    } else {
      size_t src_h = (pointIdx - (batch - 1) * src_height * src_width) / res_width;
      size_t src_w = (pointIdx - (batch - 1) * src_height * src_width) % res_width;
      size_t src_coordinate = (batch - 1) * src_height * src_width + src_h * src_width + src_w;
      size_t dst_h = src_height * (batch - 1) + src_h;
      size_t dst_w = src_width * (batch - 1) + src_w;
      output_addr[dst_h * dst_width + dst_w] = input_addr[src_coordinate];
    }
  }
}

template <typename T>
void MatrixCombine(const size_t size, const size_t src_height, const size_t src_width, const size_t dst_width,
                   const size_t residual, const size_t res_width, const size_t batch, T *input_addr, T *output_addr,
                   cudaStream_t cuda_stream) {
  if (residual == 0) {
    MatrixCombineKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, src_height, src_width, dst_width,
                                                                           input_addr, output_addr);
  } else {
    MatrixCombineKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, src_height, src_width, dst_width,
                                                                           res_width, batch, input_addr, output_addr);
  }
  return;
}

template void MatrixCombine<float>(const size_t size, const size_t src_height, const size_t src_width,
                                   const size_t dst_width, const size_t residual, const size_t res_width,
                                   const size_t batch, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
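A rough NumPy equivalent of MatrixCombine (illustrative only, not part of the commit; assumes a zero-initialized output): the split blocks are written back onto the diagonal of one large matrix, and when a residual block exists only its top-left res_width x res_width corner is copied, matching the second kernel above.

import numpy as np

def matrix_combine_reference(blocks, dst_dim, res_width=None):
    # blocks: (batch, src_height, src_width) diagonal blocks; dst_dim: side length of the output matrix.
    batch, src_h, src_w = blocks.shape
    out = np.zeros((dst_dim, dst_dim), dtype=blocks.dtype)
    for b in range(batch):
        h, w = src_h, src_w
        if res_width is not None and b == batch - 1:
            h = w = res_width  # residual block: only the valid corner is copied
        out[b * src_h:b * src_h + h, b * src_w:b * src_w + w] = blocks[b, :h, :w]
    return out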
@@ -0,0 +1,27 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_

#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void MatrixCombine(const size_t size, const size_t src_height, const size_t src_width, const size_t dst_width,
                   const size_t residual, const size_t res_width, const size_t batch, T *input_addr, T *output_addr,
                   cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_
@@ -0,0 +1,70 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "matrix_split_impl.cuh"
#include <iostream>
template <typename T>
__global__ void MatrixSplitKernel(const size_t size, const size_t split_dim, const size_t dim, T *input_addr,
                                  T *output_addr) {
  for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
    size_t batchIdx = pointIdx / (split_dim * split_dim);
    size_t dst_x = (pointIdx - batchIdx * split_dim * split_dim) / split_dim;
    size_t dst_y = (pointIdx - batchIdx * split_dim * split_dim) % split_dim;
    size_t src_coordinate = (batchIdx * split_dim + dst_x) * dim + batchIdx * split_dim + dst_y;
    output_addr[pointIdx] = input_addr[src_coordinate];
  }
}

template <typename T>
__global__ void MatrixSplitKernel(const size_t size, const size_t split_dim, const size_t dim, const size_t res_dim,
                                  T *input_addr, T *output_addr) {
  for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
    size_t batchIdx = pointIdx / (split_dim * split_dim);
    size_t dst_x = (pointIdx - batchIdx * split_dim * split_dim) / split_dim;
    size_t dst_y = (pointIdx - batchIdx * split_dim * split_dim) % split_dim;
    size_t src_coordinate = (batchIdx * split_dim + dst_x) * dim + batchIdx * split_dim + dst_y;
    size_t batch_lower = dim / split_dim;
    if (batchIdx < batch_lower) {
      output_addr[pointIdx] = input_addr[src_coordinate];
    } else {
      if (dst_x < res_dim && dst_y < res_dim) {
        output_addr[pointIdx] = input_addr[src_coordinate];
      } else if (dst_x == dst_y) {
        output_addr[pointIdx] = 1;
      } else {
        output_addr[pointIdx] = 0;
      }
    }
  }
}

template <typename T>
void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, T *input_addr, T *output_addr,
                 cudaStream_t cuda_stream) {
  size_t batch = dim / split_dim;
  size_t res_dim = dim - batch * split_dim;
  if (res_dim == 0) {
    MatrixSplitKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, split_dim, dim, input_addr, output_addr);
  } else {
    MatrixSplitKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, split_dim, dim, res_dim, input_addr,
                                                                         output_addr);
  }
  return;
}

template void MatrixSplit<float>(const size_t size, const size_t split_dim, const size_t dim, float *input_addr,
                                 float *output_addr, cudaStream_t cuda_stream);
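Likewise, a rough NumPy equivalent of MatrixSplit (illustrative only, not part of the commit): split_dim x split_dim blocks are cut from the diagonal of a dim x dim matrix, and when dim is not a multiple of split_dim the trailing block is padded out to an identity matrix, as in the residual kernel above.

import numpy as np

def matrix_split_reference(mat, split_dim):
    # mat: square (dim, dim) matrix; returns (n_blocks, split_dim, split_dim) diagonal blocks.
    dim = mat.shape[0]
    batch = dim // split_dim
    res_dim = dim - batch * split_dim
    n_blocks = batch if res_dim == 0 else batch + 1
    out = np.zeros((n_blocks, split_dim, split_dim), dtype=mat.dtype)
    for b in range(n_blocks):
        start = b * split_dim
        block = mat[start:start + split_dim, start:start + split_dim]
        if block.shape[0] < split_dim:  # residual block: pad out with identity
            padded = np.eye(split_dim, dtype=mat.dtype)
            padded[:res_dim, :res_dim] = block
            block = padded
        out[b] = block
    return out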
@@ -0,0 +1,25 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_

#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, T *input_addr, T *output_addr,
                 cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_
@@ -0,0 +1,26 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Im2Col, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      Im2ColGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(Im2Col, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      Im2ColGpuFwdKernel, half)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,269 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_IM2COLGPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_IM2COLGPUKERNEL_H_

#include <vector>
#include <string>
#include <algorithm>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"

namespace mindspore {
namespace kernel {
template <typename T>
class Im2ColGpuFwdKernel : public GpuKernel {
 public:
  Im2ColGpuFwdKernel()
      : cudnn_handle_(nullptr),
        input_desc_(nullptr),
        output_desc_(nullptr),
        filter_desc_(nullptr),
        conv_desc_(nullptr),
        padded_desc_(nullptr),
        cudnn_data_type_(CUDNN_DATA_FLOAT),
        old_height_(0),
        old_width_(0),
        pad_height_(0),
        pad_width_(0),
        pad_top_(0),
        pad_left_(0),
        n_(0),
        c_(0),
        is_null_input_(false),
        input_size_(0),
        output_size_(0),
        padded_size_(0),
        workspace_size_(0),
        use_pad_(true) {}
  ~Im2ColGpuFwdKernel() override { DestroyResource(); }
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    T *input_addr = GetDeviceAddress<T>(inputs, 0);
    T *output_addr = GetDeviceAddress<T>(outputs, 0);
    if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
      T *padded_addr = GetDeviceAddress<T>(workspace, 0);
      CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
             old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr,
             reinterpret_cast<cudaStream_t>(stream_ptr));
      CHECK_CUDNN_RET_WITH_EXCEPT(
        cudnnIm2Col(cudnn_handle_, padded_desc_, padded_addr, filter_desc_, conv_desc_, output_addr),
        "cudnnIm2ColForward failed");
    } else {
      CHECK_CUDNN_RET_WITH_EXCEPT(
        cudnnIm2Col(cudnn_handle_, input_desc_, input_addr, filter_desc_, conv_desc_, output_addr),
        "cudnnIm2ColForward failed");
    }

    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    InitResource();
    if (!CheckParam(kernel_node)) {
      return false;
    }
    auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    auto filter_shape = GetAttr<std::vector<int>>(kernel_node, "kernel_size");
    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
    is_null_input_ = CHECK_NULL_INPUT(in_shape);
    if (is_null_input_) {
      MS_LOG(WARNING) << "cudnnIm2ColForward input is null.";
      InitSizeLists();
      return true;
    }
    Set4DDesc(in_shape, filter_shape, output_shape);
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, 1), "cudnnSetConvGroupCount failed");
    pad_height_ = GetAttr<int>(kernel_node, "pad");
    pad_width_ = pad_height_;
    pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
    SetStrideAndDilation(kernel_node);
    if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
      SetPad(in_shape, kernel_node);
    } else {
      if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
        pad_height_ = 0;
        pad_width_ = 0;
      }
      CHECK_CUDNN_RET_WITH_EXCEPT(
        cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[2], stride_[3], dilation_[2],
                                        dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
        "cudnnSetConvolution2dDescriptor failed");
    }
    if (cudnn_data_type_ == CUDNN_DATA_HALF) {
      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH),
                                  "cudnnSetConvolutionMathType failed.")
    }
    InitSizeLists();
    return true;
  }

 protected:
  void InitResource() override {
    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_desc_), "cudnnCreateTensorDescriptor failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_desc_), "cudnnCreateTensorDescriptor failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_desc_), "cudnnCreateTensorDescriptor failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&filter_desc_), "cudnnCreateTensorDescriptor failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_),
                                "cudnnCreateConvolutionDescriptor failed");
  }

  void InitSizeLists() override {
    if (!is_null_input_) {
      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_desc_, reinterpret_cast<size_t *>(&input_size_)),
                                  "cudnnGetTensorSizeInBytes failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(output_desc_, reinterpret_cast<size_t *>(&output_size_)),
                                  "cudnnGetTensorSizeInBytes failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_desc_, reinterpret_cast<size_t *>(&padded_size_)),
                                  "cudnnGetTensorSizeInBytes failed");
    }
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(output_size_);
    if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) {
      workspace_size_list_.push_back(padded_size_);
    }
    return;
  }

 private:
  void DestroyResource() noexcept {
    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_),
                               "cudnnDestroyConvolutionDescriptor failed");
    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(filter_desc_), "cudnnDestroyTensorDescriptor failed");
    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_desc_), "cudnnDestroyTensorDescriptor failed");
    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_desc_), "cudnnDestroyTensorDescriptor failed");
    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed");
  }
  bool CheckParam(const CNodePtr &kernel_node) {
    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 1) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but Im2Col needs 1 input.";
      return false;
    }

    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but Im2Col needs 1 output.";
      return false;
    }
    return true;
  }
  void SetPad(const std::vector<size_t> &in_shape, const CNodePtr &kernel_node) {
    auto pad_list = GetAttr<std::vector<int>>(kernel_node, "pad_list");

    n_ = SizeToInt(in_shape[0]);
    c_ = SizeToInt(in_shape[1]);
    old_height_ = SizeToInt(in_shape[2]);
    old_width_ = SizeToInt(in_shape[3]);
    pad_height_ = pad_list[0] + pad_list[1];
    pad_width_ = pad_list[2] + pad_list[3];
    pad_top_ = pad_list[0];
    pad_left_ = pad_list[2];

    // If use_pad_ is true, apply zero padding in advance; otherwise use the default cudnn padding.
    if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
      use_pad_ = false;
    }

    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_,
                                                           old_height_ + pad_height_, old_width_ + pad_width_),
                                "cudnnSetTensor4dDescriptor failed");
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor(
                                  conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[2], stride_[3],
                                  dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
                                "cudnnSetConvolution2dDescriptor failed");
  }

  void Set4DDesc(const std::vector<size_t> &in_shape, const std::vector<int> &filter_shape,
                 const std::vector<size_t> &output_shape) {
    CHECK_CUDNN_RET_WITH_EXCEPT(
      cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]),
                                 SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])),
      "cudnnSetTensor4dDescriptor failed");

    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilter4dDescriptor(filter_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 1,
                                                           SizeToInt(in_shape[1]), filter_shape[0], filter_shape[1]),
                                "cudnnSetFilter4dDescriptor failed");

    auto out_H = output_shape[0] * output_shape[1] * output_shape[2];
    auto out_W = output_shape[3] * output_shape[4] * output_shape[5];
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
                                                           SizeToInt(out_H), SizeToInt(out_W), 1, 1),
                                "cudnnSetTensor4dDescriptor failed");
  }

  void SetStrideAndDilation(const CNodePtr &kernel_node) {
    stride_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
    dilation_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
    if (stride_.size() != 4) {
      MS_LOG(EXCEPTION) << "Im2Col's stride must be 4d!";
    }
    if (stride_[0] != 1 || stride_[1] != 1) {
      MS_LOG(EXCEPTION) << "Im2Col's stride only supports 1 in N axis and C axis!";
    }
    if (dilation_.size() != 4) {
      MS_LOG(EXCEPTION) << "Im2Col's dilation must be 4d!";
    }
    if (dilation_[0] != 1 || dilation_[1] != 1) {
      MS_LOG(EXCEPTION) << "Im2Col's dilation only supports 1 in N axis and C axis!";
    }
  }

  cudnnHandle_t cudnn_handle_;
  cudnnTensorDescriptor_t input_desc_;
  cudnnTensorDescriptor_t output_desc_;
  cudnnFilterDescriptor_t filter_desc_;
  cudnnConvolutionFwdAlgo_t conv_algorithm_;
  cudnnConvolutionDescriptor_t conv_desc_;
  cudnnTensorDescriptor_t padded_desc_;
  std::string pad_mode_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
  const float pad_value_ = 0.0;
  cudnnDataType_t cudnn_data_type_;
  int old_height_;
  int old_width_;
  int pad_height_;
  int pad_width_;
  int pad_top_;
  int pad_left_;
  int n_;
  int c_;
  std::vector<int> stride_;
  std::vector<int> dilation_;
  bool is_null_input_;
  size_t input_size_;
  size_t output_size_;
  size_t padded_size_;
  size_t workspace_size_;
  bool use_pad_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_IM2COLGPUKERNEL_H_
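For orientation, a NumPy reference for an im2col transform that produces the (C, k_h, k_w, N, out_h, out_w) layout reported by the Im2Col primitive's infer_shape further below (illustrative only, not part of the commit; assumes dilation 1 and symmetric padding, and does not claim that cudnnIm2Col orders elements exactly this way).

import numpy as np

def im2col_reference(x, kernel_size, stride, pad):
    # x: (N, C, H, W); kernel_size: (k_h, k_w); stride: (s_h, s_w); pad: symmetric padding.
    n, c, h, w = x.shape
    k_h, k_w = kernel_size
    s_h, s_w = stride
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    out_h = (h + 2 * pad - k_h) // s_h + 1
    out_w = (w + 2 * pad - k_w) // s_w + 1
    col = np.zeros((c, k_h, k_w, n, out_h, out_w), dtype=x.dtype)
    for i in range(k_h):
        for j in range(k_w):
            patch = xp[:, :, i:i + s_h * out_h:s_h, j:j + s_w * out_w:s_w]  # (N, C, out_h, out_w)
            col[:, i, j] = patch.transpose(1, 0, 2, 3)
    return col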
@@ -83,7 +83,10 @@ from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
                        CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
from .thor_ops import *
from .thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
                       CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
                       CusMatMulCubeDenseRight,
                       CusMatMulCubeFraczLeftCast, Im2Col)
from .sparse_ops import SparseToDense

__all__ = [
@@ -13,9 +13,12 @@
# limitations under the License.
# ============================================================================
"""thor_ops"""
import math

from ..primitive import prim_attr_register, PrimitiveWithInfer
from ...common import dtype as mstype

from ..._checkparam import Validator as validator
from ..._checkparam import Rel

__all__ = ["CusBatchMatMul",
           "CusCholeskyTrsm",
@@ -31,6 +34,37 @@ __all__ = ["CusBatchMatMul",
           ]


def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False):
    """
    Checks whether an argument is a positive int or a tuple with 2 or 4 (when allow_four is True) positive int elements.
    """

    def _raise_message():
        raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be a positive int number or a tuple of two "
                         f"{'or four ' if allow_four else ''}positive int numbers, but got {arg_value}")

    def _get_return_value():
        if isinstance(arg_value, int):
            ret = (1, 1, arg_value, arg_value) if ret_four else (arg_value, arg_value)
        elif len(arg_value) == 2:
            ret = (1, 1, arg_value[0], arg_value[1]) if ret_four else arg_value
        elif len(arg_value) == 4:
            if not allow_four:
                _raise_message()
            ret = arg_value if ret_four else (arg_value[2], arg_value[3])
        else:
            _raise_message()
        return ret

    validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
    ret_value = _get_return_value()
    for item in ret_value:
        if isinstance(item, int) and item > 0:
            continue
        _raise_message()
    return ret_value
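As a quick illustration of the helper above (values worked out from its logic, not taken from the commit):

# _check_positive_int_or_tuple('stride', 2, 'Im2Col', allow_four=True, ret_four=True)    -> (1, 1, 2, 2)
# _check_positive_int_or_tuple('kernel_size', (7, 7), 'Im2Col')                          -> (7, 7)
# _check_positive_int_or_tuple('stride', (1, 1, 2, 2), 'Im2Col', allow_four=True)        -> (2, 2)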
|
||||
|
||||
class CusBatchMatMul(PrimitiveWithInfer):
|
||||
"""
|
||||
Multiplies matrix `a` by matrix `b` in batch.
|
||||
|
@@ -360,6 +394,7 @@ class CusTranspose02314(PrimitiveWithInfer):
        """init CusTranspose02314"""
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        from mindspore.ops._op_impl._custom_op.transpose02314_impl import CusTranspose02314

    def get_bprop(self):
        def bprop(x, out, dout):
            return (C.zeros_like(x),)
@@ -446,3 +481,84 @@ class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer):

    def infer_dtype(self, data1_dtype, data2_dtype):
        return mstype.float16


class Im2Col(PrimitiveWithInfer):
    """
    Extracts image patches from the input image.

    The rank of input_x1 must be `4`, data_format is "NCHW".

    Inputs:
        - **input_x1** (Tensor) - The feature map.
          The shape of the tensor is :math:`(N, C, H, W)`.
    Outputs:
        Tensor.
    Examples:
        >>> input_x = Tensor(np.random.rand(32, 3, 224, 224).astype(np.float16))
        >>> img2col = P.Im2Col(kernel_size=7, pad_mode="pad", pad=3, stride=2)
        >>> output = img2col(input_x)
    """
    @prim_attr_register
    def __init__(self,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 stride=1,
                 dilation=1):
        """init Im2Col"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
        self.add_prim_attr('kernel_size', self.kernel_size)
        self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('stride', self.stride)
        self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('dilation', self.dilation)
        validator.check_value_type('pad', pad, (int,), self.name)
        self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
        self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
        if self.pad_mode == 'pad':
            validator.check_integer('pad', self.pad, 0, Rel.GE, self.name)
        self.add_prim_attr('data_format', "NCHW")

    def infer_shape(self, x_shape):
        validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
        kernel_size_h = self.kernel_size[0]
        kernel_size_w = self.kernel_size[1]
        stride_h = self.stride[2]
        stride_w = self.stride[3]
        dilation_h = self.dilation[2]
        dilation_w = self.dilation[3]
        if self.pad_mode == "valid":
            h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
            w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
            pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
        elif self.pad_mode == "same":
            h_out = math.ceil(x_shape[2] / stride_h)
            w_out = math.ceil(x_shape[3] / stride_w)
            pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
            pad_top = math.floor(pad_needed_h / 2)
            pad_bottom = pad_needed_h - pad_top
            pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3])
            pad_left = math.floor(pad_needed_w / 2)
            pad_right = pad_needed_w - pad_left
        elif self.pad_mode == 'pad':
            pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad
            h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h
            w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w
            h_out = math.floor(h_out)
            w_out = math.floor(w_out)
        self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
        self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
        batch_size = x_shape[0]
        channel = x_shape[1]
        k_h = kernel_size_h
        k_w = kernel_size_w
        out_shape = [channel, k_h, k_w, batch_size, h_out, w_out]
        return out_shape

    def infer_dtype(self, x_dtype):
        args = {'x': x_dtype}
        valid_types = [mstype.float16, mstype.float32]
        validator.check_tensor_type_same(args, valid_types, self.name)
        return x_dtype
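As a worked example of infer_shape above (assumed attributes kernel_size=7, stride=2, dilation=1, pad_mode='pad', pad=3, i.e. the ResNet-50 stem convolution; not part of the commit):

import math

x_shape = [32, 3, 224, 224]
k, s, d, pad = 7, 2, 1, 3
h_out = math.floor(1 + (x_shape[2] + 2 * pad - k - (k - 1) * (d - 1)) / s)  # 112
w_out = math.floor(1 + (x_shape[3] + 2 * pad - k - (k - 1) * (d - 1)) / s)  # 112
out_shape = [x_shape[1], k, k, x_shape[0], h_out, w_out]  # [3, 7, 7, 32, 112, 112]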