Refactor ConvGradFilter/ConvGradInput CpuKernelMod

hanhuifeng2020 2022-10-31 16:51:23 +08:00
parent 8247c11786
commit 0e1224addf
4 changed files with 97 additions and 42 deletions
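The commit migrates both kernels off the deprecated CNodePtr-based `InitKernel` onto the `BaseOperator`-based `Init`/`Resize` interface: `Init` runs once and caches the immutable primitive attributes, while `Resize` runs whenever shapes change and rebuilds the shape-dependent oneDNN state, so dynamic shapes no longer have to bail out early. A minimal sketch of that split, using simplified stand-in types rather than the real MindSpore interfaces:

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

enum { KRET_OK = 0, KRET_RESIZE_FAILED = 1 };

// Stand-in for KernelTensorPtr: only the shape matters for this sketch.
struct KernelTensor {
  std::vector<int64_t> shape;
};

class ConvGradKernelSketch {
 public:
  // Init: called once; caches attributes that never change between steps.
  bool Init(const std::string &format, int64_t group) {
    format_ = format;
    group_ = group;
    return true;
  }
  // Resize: called whenever input shapes change; rebuilds shape-dependent
  // state (the real kernels recreate their dnnl descriptors here).
  int Resize(const std::vector<KernelTensor> &inputs) {
    if (inputs.empty() || inputs[0].shape.empty()) {
      return KRET_RESIZE_FAILED;
    }
    src_shape_ = inputs[0].shape;
    return KRET_OK;
  }

 private:
  std::string format_;
  int64_t group_{1};
  std::vector<int64_t> src_shape_;
};

int main() {
  ConvGradKernelSketch kernel;
  kernel.Init("NCHW", 1);
  // Dynamic shapes: only Resize runs again when the shape changes.
  std::printf("%d\n", kernel.Resize({{{1, 3, 32, 32}}}));  // KRET_OK
  std::printf("%d\n", kernel.Resize({{{2, 3, 64, 64}}}));  // KRET_OK
}
```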

View File

@@ -29,35 +29,53 @@ constexpr auto kConv3DBackpropFilter = "Conv3DBackpropFilter";
constexpr size_t kConvGradFilterInputsMinNum = 2;
constexpr size_t kConvGradFilterOutputsNum = 1;
} // namespace
void ConvGradFilterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
bool ConvGradFilterCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
if (kernel_name_ == kConv2DBackpropFilterOpName) {
src_index_ = 1;
diff_dst_index_ = 0;
}
std::vector<int64_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, src_index_);
std::vector<int64_t> dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, diff_dst_index_);
std::vector<int64_t> weight_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
if (AnfAlgo::IsShapesDynamic({src_shape, weight_shape, dst_shape})) {
return;
auto prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
format_ = GetValue<std::string>(prim->GetAttr(FORMAT));
group_ = GetValue<int64_t>(prim->GetAttr(GROUP));
pad_mode_ = GetValue<std::string>(prim->GetAttr(PAD_MODE));
if (format_ != NCHW && format_ != NCDHW) {
MS_LOG(EXCEPTION) << kernel_name_ << " only supports " << NCHW << " or " << NCDHW << " format "
<< ", but got format: " << format_;
}
const auto stride_attr = format_ == NCHW ? STRIDE : STRIDES;
const auto dilation_attr = format_ == NCHW ? DILATION : DILATIONS;
strides_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(stride_attr));
dilation_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(dilation_attr));
return true;
}
int ConvGradFilterCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto src_shape = inputs[src_index_]->GetDeviceShapeAdaptively();
auto dst_shape = inputs[diff_dst_index_]->GetDeviceShapeAdaptively();
auto weight_shape = outputs[0]->GetDeviceShapeAdaptively();
size_t src_dim = src_shape.size();
if (src_dim != SHAPE_4D && src_dim != SHAPE_5D) {
MS_LOG(EXCEPTION) << "Conv Grad only supports 4D/5D input, but got " << src_dim << "D!";
}
const auto format = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
const auto &format = format_;
if (src_dim == SHAPE_4D && format != NCHW) {
MS_LOG(EXCEPTION) << kernel_name_ << " only supports 4D input with NCHW format, but got format " << format;
}
if (src_dim == SHAPE_5D && format != NCDHW) {
MS_LOG(EXCEPTION) << kernel_name_ << " only supports 5D input with NCDHW format, but got fornat " << format;
}
dnnl::memory::dims kernel_size(weight_shape.begin() + NC_LEN, weight_shape.end());
const size_t group = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
const auto group = group_;
if (group > 1) {
if (src_shape[1] % group != 0) {
MS_LOG(EXCEPTION) << kernel_name_ << " requires channels must be divided by group!";
@@ -65,15 +83,11 @@ void ConvGradFilterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
(void)weight_shape.insert(weight_shape.begin(), group);
weight_shape[1] = weight_shape[1] / group;
}
const dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
const dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape);
const dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
const auto stride_attr = src_dim == SHAPE_4D ? STRIDE : STRIDES;
const auto dilation_attr = src_dim == SHAPE_4D ? DILATION : DILATIONS;
const auto pad_mode = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE);
const auto strides_include_nc = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, stride_attr);
const auto dilation_include_nc = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, dilation_attr);
const auto &strides_include_nc = strides_include_nc_;
const auto &dilation_include_nc = dilation_include_nc_;
if (strides_include_nc.size() != src_dim) {
MS_LOG(EXCEPTION) << kernel_name_ << "requires strides must be " << src_dim << "D, but got "
<< strides_include_nc.size() << "D!";
@@ -89,9 +103,9 @@ void ConvGradFilterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::dims padding_r;
(void)std::transform(dilation.begin(), dilation.end(), std::back_inserter(dilates),
[](const int64_t &value) { return value - 1; });
const auto &pad_mode = pad_mode_;
PaddingInfo padding_info{pad_mode, kernel_size, strides, dilation, &padding_l, &padding_r};
GetPadding(kernel_node, src_shape, padding_info);
GetPadding(base_operator, src_shape, padding_info);
const auto forward_desc = CreateDesc<dnnl::convolution_forward::desc>(
dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides,
dilates, padding_l, padding_r);
@@ -104,6 +118,7 @@ void ConvGradFilterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_DIFF_DST, dst_desc);
AddArgument(DNNL_ARG_DIFF_WEIGHTS, weights_desc);
return KRET_OK;
}
bool ConvGradFilterCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
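One detail worth calling out in the hunk above: the `std::transform` that builds `dilates` subtracts 1 from each dilation value because oneDNN counts dilation from zero (0 means a dense kernel), while the framework attribute counts from one. A standalone sketch of that conversion (`ToDnnlDilates` is an illustrative name, not part of the diff):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <iterator>
#include <vector>

// Framework dilation (1 = dense) -> oneDNN dilates (0 = dense).
std::vector<int64_t> ToDnnlDilates(const std::vector<int64_t> &dilation) {
  std::vector<int64_t> dilates;
  (void)std::transform(dilation.begin(), dilation.end(), std::back_inserter(dilates),
                       [](const int64_t &value) { return value - 1; });
  return dilates;
}

int main() {
  // H/W dilation of 2, with the N/C entries already stripped as in Resize.
  for (int64_t d : ToDnnlDilates({2, 2})) {
    std::printf("%lld ", static_cast<long long>(d));  // prints: 1 1
  }
  std::printf("\n");
}
```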

View File

@@ -20,18 +20,25 @@
#include <vector>
#include <memory>
#include <string>
#include <map>
#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
namespace kernel {
class ConvGradFilterCpuKernelMod : public DeprecatedMKLCpuKernelMod {
class ConvGradFilterCpuKernelMod : public MKLCpuKernelMod {
public:
ConvGradFilterCpuKernelMod() = default;
explicit ConvGradFilterCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
~ConvGradFilterCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
@@ -42,6 +49,11 @@ class ConvGradFilterCpuKernelMod : public DeprecatedMKLCpuKernelMod {
size_t src_index_{0};
size_t diff_dst_index_{1};
std::string kernel_type_;
std::string format_;
int64_t group_{1};
std::string pad_mode_;
std::vector<int64_t> strides_include_nc_;
std::vector<int64_t> dilation_include_nc_;
};
} // namespace kernel
} // namespace mindspore

View File

@@ -30,26 +30,44 @@ constexpr size_t kConvGradInputInputsMinNum = 2;
constexpr size_t kConvGradInputOutputsNum = 1;
} // namespace
void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
bool ConvGradInputCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
if (kernel_name_ == kConv2DBackpropInputOpName) {
weight_index_ = 1;
diff_dst_index_ = 0;
}
std::vector<int64_t> src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
std::vector<int64_t> weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, weight_index_);
std::vector<int64_t> dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, diff_dst_index_);
if (AnfAlgo::IsShapesDynamic({src_shape, weight_shape, dst_shape})) {
return;
auto prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
group_ = GetValue<int64_t>(prim->GetAttr(GROUP));
format_ = GetValue<std::string>(prim->GetAttr(FORMAT));
if (format_ != NCHW && format_ != NCDHW) {
MS_LOG(EXCEPTION) << kernel_name_ << " only supports " << NCHW << " or " << NCDHW << " format "
<< ", but got format: " << format_;
}
const auto stride_attr = format_ == NCHW ? STRIDE : STRIDES;
const auto dilation_attr = format_ == NCHW ? DILATION : DILATIONS;
pad_mode_ = GetValue<std::string>(prim->GetAttr(PAD_MODE));
strides_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(stride_attr));
dilation_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(dilation_attr));
return true;
}
int ConvGradInputCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
std::vector<int64_t> src_shape = outputs[0]->GetDeviceShapeAdaptively();
std::vector<int64_t> weight_shape = inputs[weight_index_]->GetDeviceShapeAdaptively();
std::vector<int64_t> dst_shape = inputs[diff_dst_index_]->GetDeviceShapeAdaptively();
size_t src_dim = src_shape.size();
if (src_dim != SHAPE_4D && src_dim != SHAPE_5D) {
MS_LOG(EXCEPTION) << "Conv grad only supports 4D/5D input, but got " << src_dim << "D!";
}
const auto format = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
const auto &format = format_;
if (src_dim == SHAPE_4D && format != NCHW) {
MS_LOG(EXCEPTION) << kernel_name_ << " only supports 4D input with NCHW format, but got format" << format;
}
@@ -57,7 +75,7 @@ void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << kernel_name_ << " only supports 5D input with NCDHW format, but got format " << format;
}
dnnl::memory::dims kernel_size(weight_shape.begin() + NC_LEN, weight_shape.end());
const size_t group = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
const auto group = group_;
if (group > 1) {
if (src_shape[1] % group != 0) {
MS_LOG(EXCEPTION) << "Conv grad channels must be divided by group!";
@@ -65,19 +83,15 @@ void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
(void)weight_shape.insert(weight_shape.begin(), group);
weight_shape[1] = weight_shape[1] / group;
}
const dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
const dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape);
const dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
const auto stride_attr = src_dim == SHAPE_4D ? STRIDE : STRIDES;
const auto dilation_attr = src_dim == SHAPE_4D ? DILATION : DILATIONS;
const auto pad_mode = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE);
const auto strides_include_nc = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, stride_attr);
const auto dilation_include_nc = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, dilation_attr);
const auto &strides_include_nc = strides_include_nc_;
if (strides_include_nc.size() != src_dim) {
MS_LOG(EXCEPTION) << kernel_name_ << "requires strides must be " << src_dim << "D, but got "
<< strides_include_nc.size() << "D!";
}
const auto &dilation_include_nc = dilation_include_nc_;
if (dilation_include_nc.size() != src_dim) {
MS_LOG(EXCEPTION) << kernel_name_ << " requires dilation must be " << src_dim << "D, but got "
<< dilation_include_nc.size() << "D!";
@@ -89,8 +103,9 @@ void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::dims padding_r;
(void)std::transform(dilation.begin(), dilation.end(), std::back_inserter(dilates),
[](const int64_t &value) { return value - 1; });
const auto &pad_mode = pad_mode_;
PaddingInfo padding_info{pad_mode, kernel_size, strides, dilation, &padding_l, &padding_r};
GetPadding(kernel_node, src_shape, padding_info);
GetPadding(base_operator, src_shape, padding_info);
const auto forward_desc = CreateDesc<dnnl::convolution_forward::desc>(
dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides,
@@ -104,6 +119,7 @@ void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
AddArgument(DNNL_ARG_DIFF_SRC, src_desc);
AddArgument(DNNL_ARG_DIFF_DST, dst_desc);
AddArgument(DNNL_ARG_WEIGHTS, weights_desc);
return KRET_OK;
}
bool ConvGradInputCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
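As in the filter-grad kernel, the `group > 1` branch above reshapes the weight dims for oneDNN's grouped-weights layout (goihw/goidhw), which carries a leading groups dimension with per-group output channels: `[O, I/G, kH, kW]` becomes `[G, O/G, I/G, kH, kW]`. A standalone sketch of that reshape (`ToGroupedWeightDims` is an illustrative name):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the group > 1 branch: prepend the groups dim, then split the
// output channels (now at index 1) across the groups.
std::vector<int64_t> ToGroupedWeightDims(std::vector<int64_t> weight_shape, int64_t group) {
  if (group > 1) {
    (void)weight_shape.insert(weight_shape.begin(), group);
    weight_shape[1] = weight_shape[1] / group;  // output channels per group
  }
  return weight_shape;
}

int main() {
  // 8 output channels, 2 input channels per group, 3x3 kernel, 2 groups.
  for (int64_t d : ToGroupedWeightDims({8, 2, 3, 3}, 2)) {
    std::printf("%lld ", static_cast<long long>(d));  // prints: 2 4 2 3 3
  }
  std::printf("\n");
}
```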

View File

@@ -20,18 +20,25 @@
#include <vector>
#include <memory>
#include <string>
#include <map>
#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
namespace kernel {
class ConvGradInputCpuKernelMod : public DeprecatedMKLCpuKernelMod {
class ConvGradInputCpuKernelMod : public MKLCpuKernelMod {
public:
ConvGradInputCpuKernelMod() = default;
explicit ConvGradInputCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
~ConvGradInputCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
@@ -42,6 +49,11 @@ class ConvGradInputCpuKernelMod : public DeprecatedMKLCpuKernelMod {
size_t weight_index_{0};
size_t diff_dst_index_{1};
std::string kernel_type_;
std::string format_;
std::string pad_mode_;
int64_t group_{1};
std::vector<int64_t> strides_include_nc_;
std::vector<int64_t> dilation_include_nc_;
};
} // namespace kernel
} // namespace mindspore
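Both kernels now call `GetPadding(base_operator, src_shape, padding_info)`, passing the operator instead of the node, but the function body itself is outside this commit. As a general illustration only, not the framework's actual implementation: "same" padding is commonly derived from the effective (dilated) kernel size, with any odd remainder placed on the trailing side:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <utility>

// One common 'same'-padding convention; returns {pad_l, pad_r}.
std::pair<int64_t, int64_t> SamePadding(int64_t in, int64_t kernel, int64_t stride,
                                        int64_t dilation /* 1 = dense */) {
  const int64_t effective_k = (kernel - 1) * dilation + 1;
  const int64_t out = (in + stride - 1) / stride;  // ceil(in / stride)
  const int64_t total = std::max<int64_t>(0, (out - 1) * stride + effective_k - in);
  return {total / 2, total - total / 2};
}

int main() {
  auto [l, r] = SamePadding(32, 3, 1, 1);
  std::printf("%lld %lld\n", static_cast<long long>(l),
              static_cast<long long>(r));  // prints: 1 1
}
```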