!16022 GPU Conv3d grad op support

From: @tom__chen
Reviewed-by: @robingrosman, @mikef
Signed-off-by:
mindspore-ci-bot 2021-05-08 07:02:03 +08:00 committed by Gitee
commit 1dc0efbab5
8 changed files with 1011 additions and 5 deletions
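
The new kernels back the Conv3DBackpropInput and Conv3DBackpropFilter primitives on GPU, so nn.Conv3d gradients can now be computed on the GPU backend. A minimal usage sketch (shapes, values and the GradNet wrapper are illustrative; it mirrors the test added at the end of this change set):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class GradNet(nn.Cell):
    def __init__(self, network):
        super(GradNet, self).__init__()
        # get_all -> gradient w.r.t. inputs (Conv3DBackpropInput on GPU)
        # get_by_list -> gradient w.r.t. parameters (Conv3DBackpropFilter on GPU)
        self.grad = C.GradOperation(get_all=True, sens_param=True, get_by_list=True)
        self.network = network
        self.params = ParameterTuple(network.trainable_params())

    def construct(self, x, dy):
        return self.grad(self.network, self.params)(x, dy)

conv = nn.Conv3d(in_channels=1, out_channels=2, kernel_size=(2, 2, 2),
                 pad_mode='valid', data_format='NCDHW')
x = Tensor(np.random.randn(1, 1, 3, 3, 3).astype(np.float32))
dy = Tensor(np.random.randn(1, 2, 2, 2, 2).astype(np.float32))
grads = GradNet(conv)(x, dy)
dx = grads[0][0]   # d(loss)/dx      -> Conv3DBackpropInput
dw = grads[1][0]   # d(loss)/dweight -> Conv3DBackpropFilter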

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -129,7 +129,7 @@ __global__ void Pad3d(const size_t size, const T* input, const int num, const in
const int pos_w = pos % padded_width;
const int block_num = pos / padded_dhw;
if (pos_d - pad_head < 0 || pos_h - pad_top < 0 || pos_w - pad_left < 0 || pos_h - pad_head >= old_depth ||
if (pos_d - pad_head < 0 || pos_h - pad_top < 0 || pos_w - pad_left < 0 || pos_d - pad_head >= old_depth ||
pos_h - pad_top >= old_height || pos_w - pad_left >= old_width) {
output[pos] = pad_value_;
} else {
@@ -140,6 +140,23 @@ __global__ void Pad3d(const size_t size, const T* input, const int num, const in
return;
}
template <typename T>
__global__ void PadGrad3d(const size_t size, const T* dy, const int num, const int channels, const int old_depth,
const int old_height, const int old_width, const int old_dhw, const int old_hw,
const int padded_depth, const int padded_height, const int padded_width,
const int padded_dhw, const int padded_hw, const int pad_head, const int pad_top,
const int pad_left, T* dx) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
const int block_num = pos / old_dhw;
const int pos_d = pos / old_hw % old_depth + pad_head;
const int pos_h = pos / old_width % old_height + pad_top;
const int pos_w = pos % old_width + pad_left;
const int index = block_num * padded_dhw + pos_d * padded_hw + pos_h * padded_width + pos_w;
dx[pos] = dy[index];
}
return;
}
template <typename T>
void CalPad(const size_t size, const T* input, const int num, const int channels, const int old_height,
const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left,
@@ -204,6 +221,22 @@ void CalPad3d(const size_t size, const T* input, const int num, const int channe
return;
}
template <typename T>
void CalPadGrad3d(const size_t size, const T* dy, const int num, const int channels, const int old_depth,
const int old_height, const int old_width, const int padded_depth, const int padded_height,
const int padded_width, const int pad_head, const int pad_top, const int pad_left, T* dx,
cudaStream_t cuda_stream) {
const int old_hw = old_height * old_width;
const int old_dhw = old_depth * old_hw;
const int padded_hw = padded_height * padded_width;
const int padded_dhw = padded_depth * padded_hw;
PadGrad3d<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, num, channels, old_depth, old_height,
old_width, old_dhw, old_hw, padded_depth, padded_height,
padded_width, padded_dhw, padded_hw, pad_head, pad_top,
pad_left, dx);
return;
}
template void CalPad<float>(const size_t size, const float* input, const int num, const int channels,
const int old_height, const int old_width, const int padded_height, const int padded_width,
const int pad_top, const int pad_left, float pad_value, float* output,
@@ -259,3 +292,13 @@ template void CalPad3d<half>(const size_t size, const half* input, const int num
const int old_depth, const int old_height, const int old_width, const int padded_depth,
const int padded_height, const int padded_width, const int pad_head, const int pad_top,
const int pad_left, const float pad_value, half* output, cudaStream_t cuda_stream);
template void CalPadGrad3d<float>(const size_t size, const float* dy, const int num, const int channels,
const int old_depth, const int old_height, const int old_width,
const int padded_depth, const int padded_height, const int padded_width,
const int pad_head, const int pad_top, const int pad_left, float* dx,
cudaStream_t cuda_stream);
template void CalPadGrad3d<half>(const size_t size, const half* dy, const int num, const int channels,
const int old_depth, const int old_height, const int old_width,
const int padded_depth, const int padded_height, const int padded_width,
const int pad_head, const int pad_top, const int pad_left, half* dx,
cudaStream_t cuda_stream);
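
For reference, the new PadGrad3d / CalPadGrad3d pair above computes the gradient of the constant 3-D pad: since padding only copies values, the gradient with respect to the unpadded input is just the interior crop of the padded gradient. A NumPy sketch of the same index mapping (function and argument names are illustrative):

import numpy as np

def pad_grad_3d(dy_padded, pad_head, pad_top, pad_left, old_d, old_h, old_w):
    # dy_padded: (N, C, padded_d, padded_h, padded_w) gradient of the padded tensor.
    # Each unpadded position (d, h, w) reads dy_padded at
    # (d + pad_head, h + pad_top, w + pad_left), the index computed per thread in PadGrad3d.
    return dy_padded[:, :,
                     pad_head:pad_head + old_d,
                     pad_top:pad_top + old_h,
                     pad_left:pad_left + old_w].copy()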

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -45,4 +45,9 @@ void CalPad3d(const size_t size, const T* input, const int num, const int channe
const int old_height, const int old_width, const int padded_depth, const int padded_height,
const int padded_width, const int pad_head, const int pad_top, const int pad_left, const float pad_value,
T* output, cudaStream_t cuda_stream);
template <typename T>
void CalPadGrad3d(const size_t size, const T* dy, const int num, const int channels, const int old_depth,
const int old_height, const int old_width, const int padded_depth, const int padded_height,
const int padded_width, const int pad_head, const int pad_top, const int pad_left, T* dx,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_

View File

@@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/conv3d_grad_filter_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
Conv3DBackpropFilter,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Conv3dGradFilterGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Conv3DBackpropFilter,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
Conv3dGradFilterGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,418 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_FILTER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_FILTER_GPU_KERNEL_H_
#include <algorithm>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/cast_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class Conv3dGradFilterGpuKernel : public GpuKernel {
public:
Conv3dGradFilterGpuKernel() { ResetResource(); }
~Conv3dGradFilterGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *x = GetDeviceAddress<T>(inputs, 0);
T *dy = GetDeviceAddress<T>(inputs, 1);
T *work_space = nullptr;
if (workspace_size_ != 0) {
work_space = GetDeviceAddress<T>(workspace, 0);
}
T *dw = nullptr;
float *dw_float32 = nullptr;
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
dw = GetDeviceAddress<T>(workspace, use_pad_ ? 2 : 1);
dw_float32 = GetDeviceAddress<float>(outputs, 0);
} else {
dw = GetDeviceAddress<T>(outputs, 0);
}
const float alpha = 1;
const float beta = 0;
if (use_pad_) {
T *padded = GetDeviceAddress<T>(workspace, 1);
CalPad3d(padded_size_ / sizeof(T), x, n_, c_, old_depth_, old_height_, old_width_, old_depth_ + pad_depth_,
old_height_ + pad_height_, old_width_ + pad_width_, pad_head_, pad_top_, pad_left_, pad_value_, padded,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, padded_descriptor_, padded, dy_desc_, dy, conv_desc_,
algo_, work_space, workspace_size_, &beta, dw_desc_, dw),
"ConvolutionBackwardFilter failed");
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
Cast(num_output_elements_, dw, dw_float32, reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, x_desc_, x, dy_desc_, dy, conv_desc_, algo_, work_space,
workspace_size_, &beta, dw_desc_, dw),
"ConvolutionBackwardFilter failed");
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
Cast(num_output_elements_, dw, dw_float32, reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
InitResource();
if (!CheckParam(kernel_node)) {
return false;
}
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
is_null_input_ = CHECK_NULL_INPUT(dy_shape) || CHECK_NULL_INPUT(in_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "Conv3dGradFilterGpuKernel input is null.";
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(in_shape);
data_format_ = kOpFormat_NCDHW;
std::vector<size_t> filter_shape;
GetFilterShape(kernel_node, &filter_shape);
num_output_elements_ = 1;
for (auto x : filter_shape) {
num_output_elements_ *= x;
}
compute_format_ = CUDNN_TENSOR_NCHW;
n_ = SizeToInt(in_shape[0]);
c_ = SizeToInt(in_shape[1]);
old_depth_ = SizeToInt(in_shape[2]);
old_height_ = SizeToInt(in_shape[3]);
old_width_ = SizeToInt(in_shape[4]);
SetNDDesc(dy_shape, filter_shape, in_shape);
group_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "group"));
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionGroupCount(conv_desc_, group_),
"cudnnSetConvGroupCount failed");
std::vector<int> pad_list;
std::vector<int64_t> pad_list_me = GetAttr<std::vector<int64_t>>(kernel_node, "pad_list");
(void)std::transform(pad_list_me.begin(), pad_list_me.end(), std::back_inserter(pad_list),
[](const int64_t &value) { return static_cast<int>(value); });
pad_depth_ = pad_list[0];
pad_height_ = pad_list[2];
pad_width_ = pad_list[4];
use_pad_ = !((pad_depth_ == pad_list[1]) && (pad_height_ == pad_list[3]) && (pad_width_ == pad_list[5]));
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
SetStrideAndDilation(kernel_node);
cudnnTensorDescriptor_t x_desc_real = nullptr;
const int kNumDims = 5;
const int kConvDims = 3;
int padA[kConvDims];
int strideA[kConvDims] = {stride_[2], stride_[3], stride_[4]};
int dilaA[kConvDims] = {dilation_[2], dilation_[3], dilation_[4]};
if (use_pad_) {
pad_depth_ = pad_list[0] + pad_list[1];
pad_height_ = pad_list[2] + pad_list[3];
pad_width_ = pad_list[4] + pad_list[5];
pad_head_ = pad_list[0];
pad_top_ = pad_list[2];
pad_left_ = pad_list[4];
int dimA[kNumDims];
int strideApadded[kNumDims];
if (data_format_ == kOpFormat_NCDHW) {
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_depth_ + pad_depth_),
IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_)};
SetDimA(padded_shape, dimA, kNumDims, data_format_);
SetStrideA(padded_shape, strideApadded, kNumDims, data_format_);
} else {
MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel only support NCDHW format right now.";
}
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, kNumDims, dimA, strideApadded),
"cudnnSetTensor4dDescriptor failed");
padA[0] = 0;
padA[1] = 0;
padA[2] = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA,
CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolutionNdDescriptor failed");
x_desc_real = padded_descriptor_;
} else {
if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
pad_depth_ = 0;
pad_height_ = 0;
pad_width_ = 0;
}
padA[0] = pad_depth_;
padA[1] = pad_height_;
padA[2] = pad_width_;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA,
CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolutionNdDescriptor failed");
x_desc_real = x_desc_;
}
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH),
"cudnnSetConvolutionMathType failed.")
}
SelectAlgorithm(x_desc_real);
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
dw_desc_ = nullptr;
conv_desc_ = nullptr;
dy_desc_ = nullptr;
x_desc_ = nullptr;
padded_descriptor_ = nullptr;
cudnn_data_type_ = CUDNN_DATA_FLOAT;
compute_format_ = CUDNN_TENSOR_NCHW;
old_depth_ = 0;
old_height_ = 0;
old_width_ = 0;
pad_depth_ = 0;
pad_height_ = 0;
pad_width_ = 0;
pad_head_ = 0;
pad_top_ = 0;
pad_left_ = 0;
n_ = 0;
c_ = 0;
group_ = 1;
is_null_input_ = false;
input_size_ = 0;
dy_size_ = 0;
output_size_ = 0;
padded_size_ = 0;
workspace_size_ = 0;
use_pad_ = true;
num_output_elements_ = 1;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyConvolutionDescriptor(conv_desc_),
"cudnnDestroyConvolutionDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyFilterDescriptor(dw_desc_),
"cudnnDestroyFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(padded_descriptor_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_),
"cudnnDestroyTensorDescriptor failed");
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&padded_descriptor_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateFilterDescriptor(&dw_desc_),
"cudnnCreateFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateConvolutionDescriptor(&conv_desc_),
"cudnnCreateConvolutionDescriptor failed");
}
void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnGetTensorSizeInBytes(dy_desc_, reinterpret_cast<size_t *>(&dy_size_)),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnGetTensorSizeInBytes(x_desc_, reinterpret_cast<size_t *>(&input_size_)),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnGetFilterSizeInBytes(dw_desc_, reinterpret_cast<size_t *>(&output_size_)),
"cudnnGetFilterSizeInBytes failed");
}
input_size_list_.push_back(dy_size_);
input_size_list_.push_back(input_size_);
if (use_pad_ && !is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast<size_t *>(&padded_size_)),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, padded_descriptor_, dy_desc_, conv_desc_,
dw_desc_, algo_, reinterpret_cast<size_t *>(&workspace_size_)),
"cudnnGetConvolutionBackwardFilterWorkspaceSize failed");
workspace_size_list_.push_back(padded_size_);
} else {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, x_desc_, dy_desc_, conv_desc_, dw_desc_, algo_,
reinterpret_cast<size_t *>(&workspace_size_)),
"cudnnGetConvolutionBackwardFilterWorkspaceSize failed");
}
}
(void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_);
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
workspace_size_list_.push_back(output_size_);
output_size_list_.push_back(num_output_elements_ * sizeof(float));
} else {
output_size_list_.push_back(output_size_);
}
}
private:
bool CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but Conv3dGradFilterGpuKernel needs 2 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but Conv3dGradFilterGpuKernel needs 1 output.";
return false;
}
return true;
}
void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) {
const int requested_algo_count = 1;
int returned_algo_count = 0;
cudnnConvolutionBwdFilterAlgoPerf_t perf_results;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_,
requested_algo_count, &returned_algo_count, &perf_results),
"GetConvolutionBackwardFilterAlgorithm failed");
algo_ = perf_results.algo;
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
}
}
void GetFilterShape(const CNodePtr &kernel_node, std::vector<size_t> *filter_shape) {
auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_size")->cast<ValueTuplePtr>()->value();
(void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*filter_shape),
[](const ValuePtr &e) -> size_t { return static_cast<int>(e->cast<Int64ImmPtr>()->value()); });
}
void SetNDDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &filter_shape,
const std::vector<size_t> &in_shape) {
const int kDims = 5;
int dimA[kDims];
int strideAin[kDims];
int dimAdy[kDims];
int strideAdy[kDims];
int filterDimA[kDims];
SetDimA(in_shape, dimA, kDims, data_format_);
SetStrideA(in_shape, strideAin, kDims, data_format_);
SetDimA(dy_shape, dimAdy, kDims, data_format_);
SetStrideA(dy_shape, strideAdy, kDims, data_format_);
SetDimA(filter_shape, filterDimA, kDims, data_format_);
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, kDims, dimAdy, strideAdy),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetFilterNdDescriptor(dw_desc_, cudnn_data_type_, compute_format_, kDims, filterDimA),
"cudnnSetFilterNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptor(x_desc_, cudnn_data_type_, kDims, dimA, strideAin),
"cudnnSetTensorNdDescriptor failed");
}
void SetStrideAndDilation(const CNodePtr &kernel_node) {
std::vector<int64_t> stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "strides");
std::vector<int64_t> dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "dilations");
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_),
[](const int64_t &value) { return static_cast<int>(value); });
if (stride_.size() != 5) {
MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel stride must be 5d, but got " << stride_.size();
}
if (stride_[0] != 1 || stride_[1] != 1) {
MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel stride only support 1 in N axis and C axis!";
}
if (dilation_.size() != 5) {
MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel dilation must be 5d!";
}
if (dilation_[0] != 1 || dilation_[1] != 1) {
MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel dilation only support 1 in N axis and C axis!";
}
}
cudnnHandle_t cudnn_handle_;
cudnnFilterDescriptor_t dw_desc_;
cudnnConvolutionDescriptor_t conv_desc_;
cudnnTensorDescriptor_t dy_desc_;
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t padded_descriptor_;
cudnnConvolutionBwdFilterAlgo_t algo_;
std::string pad_mode_;
std::string data_format_ = kOpFormat_NCDHW;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
const float pad_value_ = 0.0;
cudnnDataType_t cudnn_data_type_;
cudnnTensorFormat_t compute_format_;
int old_depth_;
int old_height_;
int old_width_;
int pad_depth_;
int pad_height_;
int pad_width_;
int pad_head_;
int pad_top_;
int pad_left_;
int n_;
int c_;
std::vector<int> stride_;
std::vector<int> dilation_;
int group_;
bool is_null_input_;
size_t input_size_;
size_t dy_size_;
size_t output_size_;
size_t padded_size_;
size_t workspace_size_;
bool use_pad_;
size_t num_output_elements_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_FILTER_GPU_KERNEL_H_
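
One detail worth calling out from Init(): cuDNN only accepts a single (symmetric) pad value per spatial dimension, so when pad_list is asymmetric the kernel materializes an explicitly padded input with CalPad3d and passes zero padding to cuDNN; otherwise the symmetric values go straight into cudnnSetConvolutionNdDescriptor. A Python sketch of that decision (variable names are illustrative):

def plan_padding(pad_list, in_shape):
    # pad_list = [head, tail, top, bottom, left, right]; in_shape = (N, C, D, H, W)
    use_pad = not (pad_list[0] == pad_list[1] and
                   pad_list[2] == pad_list[3] and
                   pad_list[4] == pad_list[5])
    if use_pad:
        n, c, d, h, w = in_shape
        padded_shape = (n, c,
                        d + pad_list[0] + pad_list[1],
                        h + pad_list[2] + pad_list[3],
                        w + pad_list[4] + pad_list[5])
        return use_pad, padded_shape, (0, 0, 0)   # cuDNN sees no padding
    # symmetric case: cuDNN pads internally ('valid' pad_mode further zeroes these)
    return use_pad, in_shape, (pad_list[0], pad_list[2], pad_list[4])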

View File

@@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/conv3d_grad_input_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
Conv3DBackpropInput,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Conv3dGradInputGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Conv3DBackpropInput,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
Conv3dGradInputGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,397 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_INPUT_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_INPUT_GPU_KERNEL_H_
#include <algorithm>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
template <typename T>
class Conv3dGradInputGpuKernel : public GpuKernel {
public:
Conv3dGradInputGpuKernel() { ResetResource(); }
~Conv3dGradInputGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *w = GetDeviceAddress<T>(inputs, 0);
T *dy = GetDeviceAddress<T>(inputs, 1);
T *dx = GetDeviceAddress<T>(outputs, 0);
T *work_space = nullptr;
if (workspace_size_ != 0) {
work_space = GetDeviceAddress<T>(workspace, 0);
}
const float alpha = 1;
if (use_pad_) {
T *padded = GetDeviceAddress<T>(workspace, 1);
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
workspace_size_, &beta_, padded_descriptor_, padded),
"ConvolutionBackwardData failed");
CalPadGrad3d(output_size_ / sizeof(T), padded, n_, c_, old_depth_, old_height_, old_width_,
old_depth_ + pad_depth_, old_height_ + pad_height_, old_width_ + pad_width_, pad_head_, pad_top_,
pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
workspace_size_, &beta_, dx_desc_, dx),
"ConvolutionBackwardData failed");
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
InitResource();
if (!CheckParam(kernel_node)) {
return false;
}
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
data_format_ = kOpFormat_NCDHW;
auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
is_null_input_ = CHECK_NULL_INPUT(dy_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "Conv3dGradInputGpuKernel input is null.";
InitSizeLists();
return true;
}
std::vector<size_t> input_shape;
GetInputShape(kernel_node, &input_shape);
compute_format_ = CUDNN_TENSOR_NCHW;
CHECK_TENSOR_SIZE(input_shape);
n_ = SizeToInt(input_shape[0]);
c_ = SizeToInt(input_shape[1]);
old_depth_ = SizeToInt(input_shape[2]);
old_height_ = SizeToInt(input_shape[3]);
old_width_ = SizeToInt(input_shape[4]);
SetNDDesc(dy_shape, input_shape, filter_shape);
group_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "group"));
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionGroupCount(conv_desc_, group_),
"cudnnSetConvGroupCount failed");
std::vector<int> pad_list;
std::vector<int64_t> pad_list_me = GetAttr<std::vector<int64_t>>(kernel_node, "pad_list");
(void)std::transform(pad_list_me.begin(), pad_list_me.end(), std::back_inserter(pad_list),
[](const int64_t &value) { return static_cast<int>(value); });
pad_depth_ = pad_list[0];
pad_height_ = pad_list[2];
pad_width_ = pad_list[4];
use_pad_ = !((pad_depth_ == pad_list[1]) && (pad_height_ == pad_list[3]) && (pad_width_ == pad_list[5]));
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
SetStrideAndDilation(kernel_node);
cudnnTensorDescriptor_t dx_desc_real = nullptr;
const int kNumDims = 5;
const int kConvDims = 3;
int padA[kConvDims];
int strideA[kConvDims] = {stride_[2], stride_[3], stride_[4]};
int dilaA[kConvDims] = {dilation_[2], dilation_[3], dilation_[4]};
if (use_pad_) {
pad_depth_ = pad_list[0] + pad_list[1];
pad_height_ = pad_list[2] + pad_list[3];
pad_width_ = pad_list[4] + pad_list[5];
pad_head_ = pad_list[0];
pad_top_ = pad_list[2];
pad_left_ = pad_list[4];
int dimA[kNumDims];
int strideApadded[kNumDims];
if (data_format_ == kOpFormat_NCDHW) {
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_depth_ + pad_depth_),
IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_)};
SetDimA(padded_shape, dimA, kNumDims, data_format_);
SetStrideA(padded_shape, strideApadded, kNumDims, data_format_);
} else {
MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel only support NCDHW format right now.";
}
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, kNumDims, dimA, strideApadded),
"cudnnSetTensorNdDescriptor failed");
padA[0] = 0;
padA[1] = 0;
padA[2] = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA,
CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolutionNdDescriptor failed");
dx_desc_real = padded_descriptor_;
} else {
if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
pad_depth_ = 0;
pad_height_ = 0;
pad_width_ = 0;
}
padA[0] = pad_depth_;
padA[1] = pad_height_;
padA[2] = pad_width_;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA,
CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolution2dDescriptor failed");
dx_desc_real = dx_desc_;
}
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH),
"cudnnSetConvolutionMathType failed.")
}
SelectAlgorithm(dx_desc_real);
beta_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
w_desc_ = nullptr;
conv_desc_ = nullptr;
dy_desc_ = nullptr;
dx_desc_ = nullptr;
padded_descriptor_ = nullptr;
cudnn_data_type_ = CUDNN_DATA_FLOAT;
compute_format_ = CUDNN_TENSOR_NCHW;
old_depth_ = 0;
old_height_ = 0;
old_width_ = 0;
pad_depth_ = 0;
pad_height_ = 0;
pad_width_ = 0;
pad_head_ = 0;
pad_top_ = 0;
pad_left_ = 0;
n_ = 0;
c_ = 0;
group_ = 1;
is_null_input_ = false;
dy_size_ = 0;
w_size_ = 0;
output_size_ = 0;
padded_size_ = 0;
workspace_size_ = 0;
use_pad_ = true;
beta_ = 0;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyConvolutionDescriptor(conv_desc_),
"cudnnDestroyConvolutionDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyFilterDescriptor(w_desc_),
"cudnnDestroyFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(padded_descriptor_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_),
"cudnnDestroyTensorDescriptor failed");
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&padded_descriptor_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateFilterDescriptor(&w_desc_),
"cudnnCreateFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateConvolutionDescriptor(&conv_desc_),
"cudnnCreateConvolutionDescriptor failed");
}
void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dy_desc_, &dy_size_),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetFilterSizeInBytes(w_desc_, &w_size_),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dx_desc_, &output_size_),
"cudnnGetTensorSizeInBytes failed");
}
input_size_list_.push_back(dy_size_);
input_size_list_.push_back(w_size_);
output_size_list_.push_back(output_size_);
if (use_pad_ && !is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, padded_descriptor_,
algo_, &workspace_size_),
"cudnnGetConvolutionBackwardDataWorkspaceSize failed");
workspace_size_list_.push_back(padded_size_);
} else {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnGetConvolutionBackwardDataWorkspaceSize(
cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_, algo_, &workspace_size_),
"cudnnGetConvolutionBackwardDataWorkspaceSize failed");
}
}
(void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_);
}
private:
bool CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but Conv3dGradInputGpuKernel needs 2 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but Conv3dGradInputGpuKernel needs 1 output.";
return false;
}
return true;
}
void SetPad(const std::vector<int> &input_shape, const CNodePtr &kernel_node) {
std::vector<int> pad_list;
std::vector<int64_t> pad_list_me = GetAttr<std::vector<int64_t>>(kernel_node, "pad_list");
(void)std::transform(pad_list_me.begin(), pad_list_me.end(), std::back_inserter(pad_list),
[](const int64_t &value) { return static_cast<int>(value); });
}
void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) {
const int requested_algo_count = 1;
int returned_algo_count = 0;
cudnnConvolutionBwdDataAlgoPerf_t perf_results;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real,
requested_algo_count, &returned_algo_count, &perf_results),
"cudnnGetConvolutionBackwardDataAlgorithm_v7 failed");
algo_ = perf_results.algo;
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
}
void GetInputShape(const CNodePtr &kernel_node, std::vector<size_t> *input_shape) {
auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_size")->cast<ValueTuplePtr>()->value();
(void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*input_shape),
[](const ValuePtr &e) -> size_t { return static_cast<int>(e->cast<Int64ImmPtr>()->value()); });
}
void SetNDDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &input_shape,
const std::vector<size_t> &filter_shape) {
const int kDims = 5;
int dimA[kDims];
int strideAin[kDims];
int dimAdy[kDims];
int strideAdy[kDims];
int filterDimA[kDims];
SetDimA(input_shape, dimA, kDims, data_format_);
SetStrideA(input_shape, strideAin, kDims, data_format_);
SetDimA(dy_shape, dimAdy, kDims, data_format_);
SetStrideA(dy_shape, strideAdy, kDims, data_format_);
SetDimA(filter_shape, filterDimA, kDims, data_format_);
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, kDims, dimAdy, strideAdy),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, compute_format_, kDims, filterDimA),
"cudnnSetFilterNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptor(dx_desc_, cudnn_data_type_, kDims, dimA, strideAin),
"cudnnSetTensorNdDescriptor failed");
}
void SetStrideAndDilation(const CNodePtr &kernel_node) {
std::vector<int64_t> stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "strides");
std::vector<int64_t> dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "dilations");
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_),
[](const int64_t &value) { return static_cast<int>(value); });
if (stride_.size() != 5) {
MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel stride must be 5d, but got " << stride_.size();
}
if (stride_[0] != 1 || stride_[1] != 1) {
MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel stride only support 1 in N axis and C axis!";
}
if (dilation_.size() != 5) {
MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel dilation must be 5d!";
}
if (dilation_[0] != 1 || dilation_[1] != 1) {
MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel dilation only support 1 in N axis and C axis!";
}
}
cudnnHandle_t cudnn_handle_;
cudnnFilterDescriptor_t w_desc_;
cudnnConvolutionDescriptor_t conv_desc_;
cudnnTensorDescriptor_t dy_desc_;
cudnnTensorDescriptor_t dx_desc_;
cudnnTensorDescriptor_t padded_descriptor_;
cudnnConvolutionBwdDataAlgo_t algo_;
std::string pad_mode_;
std::string data_format_ = kOpFormat_NCDHW;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
cudnnDataType_t cudnn_data_type_;
cudnnTensorFormat_t compute_format_;
int old_depth_;
int old_height_;
int old_width_;
int pad_depth_;
int pad_height_;
int pad_width_;
int pad_head_;
int pad_top_;
int pad_left_;
int n_;
int c_;
std::vector<int> stride_;
std::vector<int> dilation_;
int group_;
bool is_null_input_;
size_t dy_size_;
size_t w_size_;
size_t output_size_;
size_t padded_size_;
size_t workspace_size_;
bool use_pad_;
float beta_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_INPUT_GPU_KERNEL_H_
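
Putting the data-gradient pieces together: in the symmetric case cudnnConvolutionBackwardData writes dx directly, with beta_ taken from the "inplace_algo" attribute (0 for "cover", 1 to accumulate); in the asymmetric case it writes into the padded buffer and CalPadGrad3d crops the result back to the input shape. NumPy-style pseudocode with assumed helper names (the conv3d_backward_data callable stands in for the cuDNN call):

def conv3d_grad_input(conv3d_backward_data, w, dy, dx, beta, use_pad, pads, x_shape):
    # pads = (pad_head, pad_top, pad_left); x_shape = (N, C, D, H, W)
    if use_pad:
        # gradient is produced at the padded input shape, then cropped back
        # (this is what CalPadGrad3d does on the GPU)
        dx_padded = conv3d_backward_data(w, dy)
        n, c, d, h, w_ = x_shape
        dx[...] = dx_padded[:, :,
                            pads[0]:pads[0] + d,
                            pads[1]:pads[1] + h,
                            pads[2]:pads[2] + w_]
    else:
        # cuDNN convention: dx = alpha * grad + beta * dx
        dx[...] = conv3d_backward_data(w, dy) + beta * dx
    return dx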

View File

@@ -58,8 +58,8 @@ class _Conv(Cell):
self.format = Validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.cls_name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
if context.get_context("device_target") != "Ascend" and self.format == "NCDHW":
raise ValueError("NCDHW format only support in Ascend target.")
if context.get_context("device_target") == "CPU" and self.format == "NCDHW":
raise ValueError("NCDHW format only support in Ascend and GPU targets.")
if isinstance(padding, int):
Validator.check_non_negative_int(padding, 'padding', self.cls_name)
self.padding = padding
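
With the relaxed check above, NCDHW Conv3d layers are now accepted on GPU as well as Ascend and only rejected on CPU; an illustrative snippet:

import mindspore.context as context
import mindspore.nn as nn

context.set_context(device_target="GPU")
conv = nn.Conv3d(in_channels=4, out_channels=8, kernel_size=2, data_format='NCDHW')  # no longer raises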

View File

@@ -19,7 +19,9 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C
class NetConv3d(nn.Cell):
@@ -71,3 +73,84 @@ def test_conv3d():
net = NetConv3d()
output = net(x, w)
assert (output.asnumpy() == expect).all()
class MSConv3dNet(nn.Cell):
def __init__(self, in_channels, out_channels, kernel_size, pad_mode='pad', padding=0, stride=1, dilation=1,
has_bias=False, weight_init='normal'):
super(MSConv3dNet, self).__init__()
self.cv1 = nn.Conv3d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
pad_mode=pad_mode,
padding=padding,
stride=stride,
dilation=dilation,
group=1,
has_bias=has_bias,
weight_init=weight_init,
data_format='NCDHW')
def construct(self, x):
x = self.cv1(x)
return x
class MSGradNet(nn.Cell):
def __init__(self, network):
super(MSGradNet, self).__init__()
self.grad = C.GradOperation(get_all=True, sens_param=True, get_by_list=True)
self.network = network
self.params = ParameterTuple(network.trainable_params())
def construct(self, x, dy):
grad_op = self.grad(self.network, self.params)
output = grad_op(x, dy)
return output
def test_conv3d_grad():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
dtype = np.float32
out_c = 2
kernel_size = (2, 2, 2)
x = Tensor(np.array([[[[[1.6924546, 0.05080776, -0.6369957],
[0.19091548, 2.1002553, 0.12015896],
[0.6172031, 0.30017033, -0.35224986]],
[[-1.1425182, -0.34934273, -0.20889424],
[0.5866232, 0.8389834, 0.9311021],
[0.2855873, 0.8851412, -0.7543979]],
[[1.2528682, 0.5129298, -0.29809284],
[0.48851815, -0.07557172, 1.1316293],
[1.5198169, 2.1855755, -1.3964963]]]]]).astype(dtype))
dy = Tensor(np.array([[[[[-1.4441139, -0.5044659],
[0.16003707, 0.8761689]],
[[0.31563494, -2.0222013],
[-0.30620402, 0.8279746]]],
[[[0.23009473, 0.7620112],
[-0.22232814, -0.20075807]],
[[0.18656139, 0.41005164],
[0.19829972, 0.11900865]]]]]).astype(dtype))
w = Tensor(np.array([[[[[-0.9358, -0.2679],
[0.5304, -0.6917]],
[[-0.3968, -0.6872],
[-0.8452, -0.6712]]]],
[[[[-0.0127, -1.1173],
[0.2344, 1.6598]],
[[0.7420, -0.1918],
[-0.8876, -0.7472]]]]]).astype(dtype))
w_exp = np.array([[[[[-0.9384, -0.2830],
[0.5487, -0.6330]],
[[-0.4148, -0.7200],
[-0.8572, -0.6079]]]],
[[[[-0.0109, -1.1089],
[0.2138, 1.6478]],
[[0.7450, -0.1866],
[-0.8992, -0.7629]]]]]).astype(dtype)
net = MSConv3dNet(x.shape[1], out_c, kernel_size, weight_init=w)
grad_net = MSGradNet(net)
optimizer = nn.SGD(net.trainable_params(), learning_rate=0.01, momentum=0.9)
grad_net.set_train(True)
output = grad_net(x, dy)
optimizer(output[1])
assert np.allclose(net.cv1.weight.asnumpy(), w_exp, atol=1.0e-4)