diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_filter_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_filter_cpu_kernel.cc
index 2040d7c7df2..b7cacd964a5 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_filter_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_filter_cpu_kernel.cc
@@ -29,35 +29,53 @@
 constexpr auto kConv3DBackpropFilter = "Conv3DBackpropFilter";
 constexpr size_t kConvGradFilterInputsMinNum = 2;
 constexpr size_t kConvGradFilterOutputsNum = 1;
 }  // namespace
-void ConvGradFilterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
+
+bool ConvGradFilterCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                      const std::vector<KernelTensorPtr> &outputs) {
+  MS_EXCEPTION_IF_NULL(base_operator);
+  kernel_name_ = base_operator->name();
   if (kernel_name_ == kConv2DBackpropFilterOpName) {
     src_index_ = 1;
     diff_dst_index_ = 0;
   }
-  std::vector<int64_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, src_index_);
-  std::vector<int64_t> dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, diff_dst_index_);
-  std::vector<int64_t> weight_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
-
-  if (AnfAlgo::IsShapesDynamic({src_shape, weight_shape, dst_shape})) {
-    return;
+  auto prim = base_operator->GetPrim();
+  MS_EXCEPTION_IF_NULL(prim);
+  format_ = GetValue<std::string>(prim->GetAttr(FORMAT));
+  group_ = GetValue<int64_t>(prim->GetAttr(GROUP));
+  pad_mode_ = GetValue<std::string>(prim->GetAttr(PAD_MODE));
+  if (format_ != NCHW && format_ != NCDHW) {
+    MS_LOG(EXCEPTION) << kernel_name_ << " only supports " << NCHW << " or " << NCDHW
+                      << " format, but got format: " << format_;
   }
+  const auto stride_attr = format_ == NCHW ? STRIDE : STRIDES;
+  const auto dilation_attr = format_ == NCHW ? DILATION : DILATIONS;
+  strides_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(stride_attr));
+  dilation_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(dilation_attr));
+  return true;
+}
+
+int ConvGradFilterCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                       const std::vector<KernelTensorPtr> &outputs,
+                                       const std::map<uint32_t, tensor::TensorPtr> &) {
+  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
+    return ret;
+  }
+  auto src_shape = inputs[src_index_]->GetDeviceShapeAdaptively();
+  auto dst_shape = inputs[diff_dst_index_]->GetDeviceShapeAdaptively();
+  auto weight_shape = outputs[0]->GetDeviceShapeAdaptively();
   size_t src_dim = src_shape.size();
   if (src_dim != SHAPE_4D && src_dim != SHAPE_5D) {
     MS_LOG(EXCEPTION) << "Conv Grad only supports 4D/5D input, but got " << src_dim << "D!";
   }
-  const auto format = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
+  const auto &format = format_;
   if (src_dim == SHAPE_4D && format != NCHW) {
     MS_LOG(EXCEPTION) << kernel_name_ << " only supports 4D input with NCHW format, but got format " << format;
   }
   if (src_dim == SHAPE_5D && format != NCDHW) {
     MS_LOG(EXCEPTION) << kernel_name_ << " only supports 5D input with NCDHW format, but got format " << format;
   }
-
   dnnl::memory::dims kernel_size(weight_shape.begin() + NC_LEN, weight_shape.end());
-  const size_t group = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
+  const auto group = group_;
   if (group > 1) {
     if (src_shape[1] % group != 0) {
       MS_LOG(EXCEPTION) << kernel_name_ << " requires channels to be divisible by group!";
@@ -65,15 +83,11 @@ void ConvGradFilterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
     (void)weight_shape.insert(weight_shape.begin(), group);
     weight_shape[1] = weight_shape[1] / group;
   }
-
   const dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
   const dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape);
   const dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
-  const auto stride_attr = src_dim == SHAPE_4D ? STRIDE : STRIDES;
-  const auto dilation_attr = src_dim == SHAPE_4D ? DILATION : DILATIONS;
-  const auto pad_mode = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE);
-  const auto strides_include_nc = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, stride_attr);
-  const auto dilation_include_nc = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, dilation_attr);
+  const auto &strides_include_nc = strides_include_nc_;
+  const auto &dilation_include_nc = dilation_include_nc_;
   if (strides_include_nc.size() != src_dim) {
     MS_LOG(EXCEPTION) << kernel_name_ << " requires strides to be " << src_dim << "D, but got "
                       << strides_include_nc.size() << "D!";
@@ -89,9 +103,9 @@ void ConvGradFilterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   dnnl::memory::dims padding_r;
   (void)std::transform(dilation.begin(), dilation.end(), std::back_inserter(dilates),
                        [](const int64_t &value) { return value - 1; });
+  const auto &pad_mode = pad_mode_;
   PaddingInfo padding_info{pad_mode, kernel_size, strides, dilation, &padding_l, &padding_r};
-  GetPadding(kernel_node, src_shape, padding_info);
-
+  GetPadding(base_operator, src_shape, padding_info);
   const auto forward_desc = CreateDesc<dnnl::convolution_forward::desc>(
     dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides,
     dilates, padding_l, padding_r);
@@ -104,6 +118,7 @@ void ConvGradFilterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   AddArgument(DNNL_ARG_SRC, src_desc);
   AddArgument(DNNL_ARG_DIFF_DST, dst_desc);
   AddArgument(DNNL_ARG_DIFF_WEIGHTS, weights_desc);
+  return KRET_OK;
 }
 
 bool ConvGradFilterCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
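Note for reviewers: this patch moves the kernel from the deprecated InitKernel(CNodePtr) entry point to the Init/Resize/Launch contract, where Init runs once and caches shape-independent operator attributes, Resize runs on every (re)inferred shape and rebuilds the shape-dependent oneDNN state, and Launch runs per step. Below is a minimal, self-contained sketch of that contract; ToyKernelMod and the driver loop are hypothetical stand-ins for illustration, not MindSpore's actual KernelMod API.

// Init-once / Resize-per-shape / Launch-per-step, in miniature.
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int KRET_OK = 0;
constexpr int KRET_RESIZE_FAILED = 1;

struct ToyKernelMod {
  // Init: runs once; reads shape-independent attributes (format, group, ...).
  bool Init(int64_t group) {
    group_ = group;
    return true;
  }
  // Resize: runs on every shape change; validates shapes and rebuilds the
  // shape-dependent state (the real kernels rebuild dnnl descs/primitives here).
  int Resize(const std::vector<int64_t> &src_shape) {
    if (src_shape.size() != 4 && src_shape.size() != 5) return KRET_RESIZE_FAILED;
    if (group_ > 1 && src_shape[1] % group_ != 0) return KRET_RESIZE_FAILED;
    src_shape_ = src_shape;
    return KRET_OK;
  }
  // Launch: executes with whatever the last successful Resize prepared.
  void Launch() const {
    std::printf("launch: C=%lld, group=%lld\n", static_cast<long long>(src_shape_[1]),
                static_cast<long long>(group_));
  }
  int64_t group_{1};
  std::vector<int64_t> src_shape_;
};

int main() {
  ToyKernelMod mod;
  mod.Init(/*group=*/2);
  // Dynamic shapes: one Resize per new shape, then any number of Launches.
  const std::vector<std::vector<int64_t>> shapes = {{8, 4, 28, 28}, {8, 6, 32, 32}};
  for (const auto &shape : shapes) {
    if (mod.Resize(shape) != KRET_OK) continue;
    mod.Launch();
  }
  return 0;
}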
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_filter_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_filter_cpu_kernel.h
index 0fbf77c8d01..7e23e5ccbf0 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_filter_cpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_filter_cpu_kernel.h
@@ -20,18 +20,25 @@
 #include <string>
 #include <vector>
 #include <algorithm>
+#include <map>
 #include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
 
 namespace mindspore {
 namespace kernel {
-class ConvGradFilterCpuKernelMod : public DeprecatedMKLCpuKernelMod {
+class ConvGradFilterCpuKernelMod : public MKLCpuKernelMod {
  public:
   ConvGradFilterCpuKernelMod() = default;
   explicit ConvGradFilterCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
   ~ConvGradFilterCpuKernelMod() override = default;
 
-  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
+
+  int Resize(
+    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+    const std::vector<KernelTensorPtr> &outputs,
+    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
 
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
@@ -42,6 +49,11 @@ class ConvGradFilterCpuKernelMod : public DeprecatedMKLCpuKernelMod {
   size_t src_index_{0};
   size_t diff_dst_index_{1};
   std::string kernel_type_;
+  std::string format_;
+  int64_t group_;
+  std::string pad_mode_;
+  std::vector<int64_t> strides_include_nc_;
+  std::vector<int64_t> dilation_include_nc_;
 };
 }  // namespace kernel
 }  // namespace mindspore
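Both Resize implementations keep the grouped-convolution reshaping: when group > 1, a plain [O, C/G, kH, kW] filter shape is rewritten into oneDNN's grouped [G, O/G, C/G, kH, kW] layout by inserting the group count in front and dividing the output-channel dimension. A standalone demonstration of just that transformation (the shape values here are made up for illustration):

// Grouped-weight reshaping as done in Resize:
// [O, C/G, kH, kW] -> [G, O/G, C/G, kH, kW].
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t group = 2;
  std::vector<int64_t> weight_shape = {8, 3, 3, 3};  // O=8, C/G=3, kH=kW=3
  weight_shape.insert(weight_shape.begin(), group);  // -> [2, 8, 3, 3, 3]
  weight_shape[1] /= group;                          // -> [2, 4, 3, 3, 3]
  for (int64_t d : weight_shape) std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");  // prints: 2 4 3 3 3
  return 0;
}

This matches the insert/divide pair in both Resize bodies, and it is also why the channel count is checked for divisibility by group beforehand.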
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_input_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_input_cpu_kernel.cc
index 9a9a61e7074..be565412d2d 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_input_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_input_cpu_kernel.cc
@@ -30,26 +30,44 @@
 constexpr size_t kConvGradInputInputsMinNum = 2;
 constexpr size_t kConvGradInputOutputsNum = 1;
 }  // namespace
 
-void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
+bool ConvGradInputCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                     const std::vector<KernelTensorPtr> &outputs) {
+  MS_EXCEPTION_IF_NULL(base_operator);
+  kernel_name_ = base_operator->name();
   if (kernel_name_ == kConv2DBackpropInputOpName) {
     weight_index_ = 1;
     diff_dst_index_ = 0;
   }
-  std::vector<int64_t> src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
-  std::vector<int64_t> weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, weight_index_);
-  std::vector<int64_t> dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, diff_dst_index_);
-
-  if (AnfAlgo::IsShapesDynamic({src_shape, weight_shape, dst_shape})) {
-    return;
+  auto prim = base_operator->GetPrim();
+  MS_EXCEPTION_IF_NULL(prim);
+  group_ = GetValue<int64_t>(prim->GetAttr(GROUP));
+  format_ = GetValue<std::string>(prim->GetAttr(FORMAT));
+  if (format_ != NCHW && format_ != NCDHW) {
+    MS_LOG(EXCEPTION) << kernel_name_ << " only supports " << NCHW << " or " << NCDHW
+                      << " format, but got format: " << format_;
   }
+  const auto stride_attr = format_ == NCHW ? STRIDE : STRIDES;
+  const auto dilation_attr = format_ == NCHW ? DILATION : DILATIONS;
+  pad_mode_ = GetValue<std::string>(prim->GetAttr(PAD_MODE));
+  strides_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(stride_attr));
+  dilation_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(dilation_attr));
+  return true;
+}
+
+int ConvGradInputCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                      const std::vector<KernelTensorPtr> &outputs,
+                                      const std::map<uint32_t, tensor::TensorPtr> &) {
+  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
+    return ret;
+  }
+  std::vector<int64_t> src_shape = outputs[0]->GetDeviceShapeAdaptively();
+  std::vector<int64_t> weight_shape = inputs[weight_index_]->GetDeviceShapeAdaptively();
+  std::vector<int64_t> dst_shape = inputs[diff_dst_index_]->GetDeviceShapeAdaptively();
   size_t src_dim = src_shape.size();
   if (src_dim != SHAPE_4D && src_dim != SHAPE_5D) {
     MS_LOG(EXCEPTION) << "Conv grad only supports 4D/5D input, but got " << src_dim << "D!";
   }
-  const auto format = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
+  const auto &format = format_;
   if (src_dim == SHAPE_4D && format != NCHW) {
     MS_LOG(EXCEPTION) << kernel_name_ << " only supports 4D input with NCHW format, but got format " << format;
   }
@@ -57,7 +75,7 @@ void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
     MS_LOG(EXCEPTION) << kernel_name_ << " only supports 5D input with NCDHW format, but got format " << format;
   }
   dnnl::memory::dims kernel_size(weight_shape.begin() + NC_LEN, weight_shape.end());
-  const size_t group = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
+  const auto group = group_;
   if (group > 1) {
     if (src_shape[1] % group != 0) {
       MS_LOG(EXCEPTION) << "Conv grad requires channels to be divisible by group!";
@@ -65,19 +83,15 @@ void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
     (void)weight_shape.insert(weight_shape.begin(), group);
     weight_shape[1] = weight_shape[1] / group;
   }
-
   const dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
   const dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape);
   const dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
-  const auto stride_attr = src_dim == SHAPE_4D ? STRIDE : STRIDES;
-  const auto dilation_attr = src_dim == SHAPE_4D ? DILATION : DILATIONS;
-  const auto pad_mode = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE);
-  const auto strides_include_nc = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, stride_attr);
-  const auto dilation_include_nc = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, dilation_attr);
+  const auto &strides_include_nc = strides_include_nc_;
   if (strides_include_nc.size() != src_dim) {
     MS_LOG(EXCEPTION) << kernel_name_ << " requires strides to be " << src_dim << "D, but got "
                       << strides_include_nc.size() << "D!";
   }
+  const auto &dilation_include_nc = dilation_include_nc_;
   if (dilation_include_nc.size() != src_dim) {
     MS_LOG(EXCEPTION) << kernel_name_ << " requires dilation to be " << src_dim << "D, but got "
                       << dilation_include_nc.size() << "D!";
@@ -89,8 +103,9 @@ void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   dnnl::memory::dims padding_r;
   (void)std::transform(dilation.begin(), dilation.end(), std::back_inserter(dilates),
                        [](const int64_t &value) { return value - 1; });
+  const auto &pad_mode = pad_mode_;
   PaddingInfo padding_info{pad_mode, kernel_size, strides, dilation, &padding_l, &padding_r};
-  GetPadding(kernel_node, src_shape, padding_info);
+  GetPadding(base_operator, src_shape, padding_info);
 
   const auto forward_desc = CreateDesc<dnnl::convolution_forward::desc>(
     dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides,
@@ -104,6 +119,7 @@ void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   AddArgument(DNNL_ARG_DIFF_SRC, src_desc);
   AddArgument(DNNL_ARG_DIFF_DST, dst_desc);
   AddArgument(DNNL_ARG_WEIGHTS, weights_desc);
+  return KRET_OK;
 }
 
 bool ConvGradInputCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
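GetPadding itself is not part of this patch; it only changes signature, taking the BaseOperatorPtr instead of a CNodePtr while still filling padding_l/padding_r from the pad mode. As a rough illustration of what a SAME-mode padding computation typically looks like (an assumption for context, not MindSpore's actual GetPadding implementation):

// Hedged sketch: one common way to derive left/right padding for SAME mode,
// assuming output size ceil(in / stride). Illustration only.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t in = 28, kernel = 3, stride = 2, dilation = 1;
  const int64_t effective_k = dilation * (kernel - 1) + 1;  // dilated kernel extent
  const int64_t out = (in + stride - 1) / stride;           // ceil(in / stride)
  const int64_t total = std::max<int64_t>(0, (out - 1) * stride + effective_k - in);
  const int64_t pad_l = total / 2;
  const int64_t pad_r = total - pad_l;  // odd totals put the extra pad on the right
  std::printf("pad_l=%lld pad_r=%lld\n", static_cast<long long>(pad_l),
              static_cast<long long>(pad_r));  // prints: pad_l=0 pad_r=1
  return 0;
}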
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_input_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_input_cpu_kernel.h
index 48f9abce684..58d49615f74 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_input_cpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/conv_grad_input_cpu_kernel.h
@@ -20,18 +20,25 @@
 #include <string>
 #include <vector>
 #include <algorithm>
+#include <map>
 #include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
 
 namespace mindspore {
 namespace kernel {
-class ConvGradInputCpuKernelMod : public DeprecatedMKLCpuKernelMod {
+class ConvGradInputCpuKernelMod : public MKLCpuKernelMod {
  public:
   ConvGradInputCpuKernelMod() = default;
   explicit ConvGradInputCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
   ~ConvGradInputCpuKernelMod() override = default;
 
-  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
+
+  int Resize(
+    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+    const std::vector<KernelTensorPtr> &outputs,
+    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
 
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
@@ -42,6 +49,11 @@ class ConvGradInputCpuKernelMod : public DeprecatedMKLCpuKernelMod {
   size_t weight_index_{0};
   size_t diff_dst_index_{1};
   std::string kernel_type_;
+  std::string format_;
+  std::string pad_mode_;
+  int64_t group_;
+  std::vector<int64_t> strides_include_nc_;
+  std::vector<int64_t> dilation_include_nc_;
 };
 }  // namespace kernel
 }  // namespace mindspore
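For readers unfamiliar with the oneDNN side of these kernels: both Resize bodies first build a forward convolution descriptor, then use its primitive_desc as the hint when creating the backward descriptor. A hedged sketch against the oneDNN 2.x op-descriptor API (the same dnnl::convolution_*::desc style the kernels use; all dims are invented for illustration), showing the backward-weights case that ConvGradFilterCpuKernelMod sets up:

// Assumes oneDNN 2.x; the desc-based API shown here was removed in oneDNN 3.x.
#include "dnnl.hpp"

int main() {
  using tag = dnnl::memory::format_tag;
  using dt = dnnl::memory::data_type;
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);

  const dnnl::memory::dims src_dims = {8, 3, 28, 28};   // N, C, H, W
  const dnnl::memory::dims wei_dims = {16, 3, 3, 3};    // O, C, kH, kW
  const dnnl::memory::dims dst_dims = {8, 16, 26, 26};  // valid (unpadded) output
  const dnnl::memory::dims strides = {1, 1}, dilates = {0, 0};  // dilates = dilation - 1
  const dnnl::memory::dims pad_l = {0, 0}, pad_r = {0, 0};

  // tag::any lets oneDNN pick the layout, as GetDefaultMemDesc-style helpers often do.
  dnnl::memory::desc src_md(src_dims, dt::f32, tag::any);
  dnnl::memory::desc wei_md(wei_dims, dt::f32, tag::any);
  dnnl::memory::desc dst_md(dst_dims, dt::f32, tag::any);

  // Forward descriptor first; its primitive_desc serves as the backward hint.
  dnnl::convolution_forward::desc fwd_d(dnnl::prop_kind::forward_training,
                                        dnnl::algorithm::convolution_auto, src_md, wei_md,
                                        dst_md, strides, dilates, pad_l, pad_r);
  dnnl::convolution_forward::primitive_desc fwd_pd(fwd_d, eng);

  // Backward-weights descriptor + primitive, hinted by the forward primitive_desc.
  dnnl::convolution_backward_weights::desc bwd_d(dnnl::algorithm::convolution_auto, src_md,
                                                 wei_md, dst_md, strides, dilates, pad_l, pad_r);
  dnnl::convolution_backward_weights::primitive_desc bwd_pd(bwd_d, eng, fwd_pd);
  dnnl::convolution_backward_weights bwd(bwd_pd);  // ready to execute with src/diff_dst/diff_weights
  return 0;
}

ConvGradInputCpuKernelMod mirrors this with convolution_backward_data, swapping DNNL_ARG_DIFF_SRC for DNNL_ARG_SRC, as the AddArgument calls in the patch show.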