fix conv2d_grad_inpu infer strides.

This commit is contained in:
linqingke 2021-03-10 16:14:12 +08:00
parent bc6ad21278
commit 07b50f76ab
3 changed files with 15 additions and 3 deletions

View File

@ -34,6 +34,7 @@ const char KERNEL_SIZE[] = "kernel_size";
const char STRIDE[] = "stride";
const char STRIDES[] = "strides";
const char DILATION[] = "dilation";
const char FORMAT[] = "format";
const char PAD[] = "pad";
const char PAD_LIST[] = "pad_list";
const char PAD_MODE[] = "pad_mode";

View File

@ -15,6 +15,7 @@
*/
#include "backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h"
#include <string>
#include <map>
#include <algorithm>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
@ -22,6 +23,7 @@
namespace mindspore {
namespace kernel {
const std::map<std::string, size_t> kFormatIndexMap = {{"NCHW", 2}, {"HWCN", 0}, {"NHWC", 1}};
void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
@ -47,7 +49,16 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
std::vector<int> dilation_ori;
auto stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDE);
auto dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, DILATION);
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_ori),
auto format_me = AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
auto iter = kFormatIndexMap.find(format_me);
if (iter == kFormatIndexMap.end()) {
MS_LOG(EXCEPTION) << "OriFormat is " << format_me << ", Please confirm that in {NCHW, HWCN, NHWC}.";
}
size_t h_index = iter->second;
if (stride_me.size() < h_index + 2) {
MS_LOG(EXCEPTION) << "Strides should greater than " << h_index + 1 << ", but got " << stride_me.size();
}
(void)std::transform(stride_me.begin() + h_index, stride_me.begin() + h_index + 2, std::back_inserter(stride_ori),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori),
[](const int64_t &value) { return static_cast<int>(value); });

View File

@ -2028,8 +2028,8 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
[dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]]
kernel_h = self.kernel_size[0]
kernel_w = self.kernel_size[1]
stride_h = self.stride[0]
stride_w = self.stride[1]
stride_h = self.stride[2]
stride_w = self.stride[3]
dilation_h = self.dilation[2]
dilation_w = self.dilation[3]
# default pad mode is valid