forked from mindspore-Ecosystem/mindspore
fix conv2d_grad_inpu infer strides.
This commit is contained in:
parent
bc6ad21278
commit
07b50f76ab
|
@ -34,6 +34,7 @@ const char KERNEL_SIZE[] = "kernel_size";
|
|||
const char STRIDE[] = "stride";
|
||||
const char STRIDES[] = "strides";
|
||||
const char DILATION[] = "dilation";
|
||||
const char FORMAT[] = "format";
|
||||
const char PAD[] = "pad";
|
||||
const char PAD_LIST[] = "pad_list";
|
||||
const char PAD_MODE[] = "pad_mode";
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h"
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
@ -22,6 +23,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
const std::map<std::string, size_t> kFormatIndexMap = {{"NCHW", 2}, {"HWCN", 0}, {"NHWC", 1}};
|
||||
void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<size_t> src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
|
@ -47,7 +49,16 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
std::vector<int> dilation_ori;
|
||||
auto stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDE);
|
||||
auto dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, DILATION);
|
||||
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_ori),
|
||||
auto format_me = AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
|
||||
auto iter = kFormatIndexMap.find(format_me);
|
||||
if (iter == kFormatIndexMap.end()) {
|
||||
MS_LOG(EXCEPTION) << "OriFormat is " << format_me << ", Please confirm that in {NCHW, HWCN, NHWC}.";
|
||||
}
|
||||
size_t h_index = iter->second;
|
||||
if (stride_me.size() < h_index + 2) {
|
||||
MS_LOG(EXCEPTION) << "Strides should greater than " << h_index + 1 << ", but got " << stride_me.size();
|
||||
}
|
||||
(void)std::transform(stride_me.begin() + h_index, stride_me.begin() + h_index + 2, std::back_inserter(stride_ori),
|
||||
[](const int64_t &value) { return static_cast<int>(value); });
|
||||
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori),
|
||||
[](const int64_t &value) { return static_cast<int>(value); });
|
||||
|
|
|
@ -2028,8 +2028,8 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
|||
[dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]]
|
||||
kernel_h = self.kernel_size[0]
|
||||
kernel_w = self.kernel_size[1]
|
||||
stride_h = self.stride[0]
|
||||
stride_w = self.stride[1]
|
||||
stride_h = self.stride[2]
|
||||
stride_w = self.stride[3]
|
||||
dilation_h = self.dilation[2]
|
||||
dilation_w = self.dilation[3]
|
||||
# default pad mode is valid
|
||||
|
|
Loading…
Reference in New Issue