Refactor ConvGradFilter/ConvGradInput CpuKernelMod
This commit is contained in:
parent
8247c11786
commit
0e1224addf
|
@ -29,35 +29,53 @@ constexpr auto kConv3DBackpropFilter = "Conv3DBackpropFilter";
|
|||
constexpr size_t kConvGradFilterInputsMinNum = 2;
|
||||
constexpr size_t kConvGradFilterOutputsNum = 1;
|
||||
} // namespace
|
||||
void ConvGradFilterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
|
||||
bool ConvGradFilterCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
if (kernel_name_ == kConv2DBackpropFilterOpName) {
|
||||
src_index_ = 1;
|
||||
diff_dst_index_ = 0;
|
||||
}
|
||||
std::vector<int64_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, src_index_);
|
||||
std::vector<int64_t> dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, diff_dst_index_);
|
||||
std::vector<int64_t> weight_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
|
||||
if (AnfAlgo::IsShapesDynamic({src_shape, weight_shape, dst_shape})) {
|
||||
return;
|
||||
auto prim = base_operator->GetPrim();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
format_ = GetValue<std::string>(prim->GetAttr(FORMAT));
|
||||
group_ = GetValue<int64_t>(prim->GetAttr(GROUP));
|
||||
pad_mode_ = GetValue<std::string>(prim->GetAttr(PAD_MODE));
|
||||
if (format_ != NCHW && format_ != NCDHW) {
|
||||
MS_LOG(EXCEPTION) << kernel_name_ << " only supports " << NCHW << " or " << NCDHW << " format"
|
||||
<< ", but got format: " << format_;
|
||||
}
|
||||
const auto stride_attr = format_ == NCHW ? STRIDE : STRIDES;
|
||||
const auto dilation_attr = format_ == NCHW ? DILATION : DILATIONS;
|
||||
strides_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(stride_attr));
|
||||
dilation_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(dilation_attr));
|
||||
return true;
|
||||
}
|
||||
|
||||
int ConvGradFilterCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
auto src_shape = inputs[src_index_]->GetDeviceShapeAdaptively();
|
||||
auto dst_shape = inputs[diff_dst_index_]->GetDeviceShapeAdaptively();
|
||||
auto weight_shape = outputs[0]->GetDeviceShapeAdaptively();
|
||||
size_t src_dim = src_shape.size();
|
||||
if (src_dim != SHAPE_4D && src_dim != SHAPE_5D) {
|
||||
MS_LOG(EXCEPTION) << "Conv Grad only supports 4D/5D input, but got " << src_dim << "D!";
|
||||
}
|
||||
const auto format = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
|
||||
const auto &format = format_;
|
||||
if (src_dim == SHAPE_4D && format != NCHW) {
|
||||
MS_LOG(EXCEPTION) << kernel_name_ << " only supports 4D input with NCHW format, but got format " << format;
|
||||
}
|
||||
if (src_dim == SHAPE_5D && format != NCDHW) {
|
||||
MS_LOG(EXCEPTION) << kernel_name_ << " only supports 5D input with NCDHW format, but got format " << format;
|
||||
}
|
||||
|
||||
dnnl::memory::dims kernel_size(weight_shape.begin() + NC_LEN, weight_shape.end());
|
||||
const size_t group = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
|
||||
const auto group = group_;
|
||||
if (group > 1) {
|
||||
if (src_shape[1] % group != 0) {
|
||||
MS_LOG(EXCEPTION) << kernel_name_ << " requires channels must be divided by group!";
|
||||
|
@ -65,15 +83,11 @@ void ConvGradFilterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
(void)weight_shape.insert(weight_shape.begin(), group);
|
||||
weight_shape[1] = weight_shape[1] / group;
|
||||
}
|
||||
|
||||
const dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
|
||||
const dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape);
|
||||
const dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
|
||||
const auto stride_attr = src_dim == SHAPE_4D ? STRIDE : STRIDES;
|
||||
const auto dilation_attr = src_dim == SHAPE_4D ? DILATION : DILATIONS;
|
||||
const auto pad_mode = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE);
|
||||
const auto strides_include_nc = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, stride_attr);
|
||||
const auto dilation_include_nc = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, dilation_attr);
|
||||
const auto &strides_include_nc = strides_include_nc_;
|
||||
const auto &dilation_include_nc = dilation_include_nc_;
|
||||
if (strides_include_nc.size() != src_dim) {
|
||||
MS_LOG(EXCEPTION) << kernel_name_ << " requires strides must be " << src_dim << "D, but got "
|
||||
<< strides_include_nc.size() << "D!";
|
||||
|
@ -89,9 +103,9 @@ void ConvGradFilterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
dnnl::memory::dims padding_r;
|
||||
(void)std::transform(dilation.begin(), dilation.end(), std::back_inserter(dilates),
|
||||
[](const int64_t &value) { return value - 1; });
|
||||
const auto &pad_mode = pad_mode_;
|
||||
PaddingInfo padding_info{pad_mode, kernel_size, strides, dilation, &padding_l, &padding_r};
|
||||
GetPadding(kernel_node, src_shape, padding_info);
|
||||
|
||||
GetPadding(base_operator, src_shape, padding_info);
|
||||
const auto forward_desc = CreateDesc<dnnl::convolution_forward::desc>(
|
||||
dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides,
|
||||
dilates, padding_l, padding_r);
|
||||
|
@ -104,6 +118,7 @@ void ConvGradFilterCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
AddArgument(DNNL_ARG_SRC, src_desc);
|
||||
AddArgument(DNNL_ARG_DIFF_DST, dst_desc);
|
||||
AddArgument(DNNL_ARG_DIFF_WEIGHTS, weights_desc);
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
bool ConvGradFilterCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
|
|
|
@ -20,18 +20,25 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class ConvGradFilterCpuKernelMod : public DeprecatedMKLCpuKernelMod {
|
||||
class ConvGradFilterCpuKernelMod : public MKLCpuKernelMod {
|
||||
public:
|
||||
ConvGradFilterCpuKernelMod() = default;
|
||||
explicit ConvGradFilterCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
|
||||
~ConvGradFilterCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(
|
||||
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
@ -42,6 +49,11 @@ class ConvGradFilterCpuKernelMod : public DeprecatedMKLCpuKernelMod {
|
|||
size_t src_index_{0};
|
||||
size_t diff_dst_index_{1};
|
||||
std::string kernel_type_;
|
||||
std::string format_;
|
||||
int64_t group_;
|
||||
std::string pad_mode_;
|
||||
std::vector<int64_t> strides_include_nc_;
|
||||
std::vector<int64_t> dilation_include_nc_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,26 +30,44 @@ constexpr size_t kConvGradInputInputsMinNum = 2;
|
|||
constexpr size_t kConvGradInputOutputsNum = 1;
|
||||
} // namespace
|
||||
|
||||
void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
bool ConvGradInputCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
if (kernel_name_ == kConv2DBackpropInputOpName) {
|
||||
weight_index_ = 1;
|
||||
diff_dst_index_ = 0;
|
||||
}
|
||||
std::vector<int64_t> src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
std::vector<int64_t> weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, weight_index_);
|
||||
std::vector<int64_t> dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, diff_dst_index_);
|
||||
|
||||
if (AnfAlgo::IsShapesDynamic({src_shape, weight_shape, dst_shape})) {
|
||||
return;
|
||||
auto prim = base_operator->GetPrim();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
group_ = GetValue<int64_t>(prim->GetAttr(GROUP));
|
||||
format_ = GetValue<std::string>(prim->GetAttr(FORMAT));
|
||||
if (format_ != NCHW && format_ != NCDHW) {
|
||||
MS_LOG(EXCEPTION) << kernel_name_ << " only supports " << NCHW << " or " << NCDHW << " format"
|
||||
<< ", but got format: " << format_;
|
||||
}
|
||||
const auto stride_attr = format_ == NCHW ? STRIDE : STRIDES;
|
||||
const auto dilation_attr = format_ == NCHW ? DILATION : DILATIONS;
|
||||
pad_mode_ = GetValue<std::string>(prim->GetAttr(PAD_MODE));
|
||||
strides_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(stride_attr));
|
||||
dilation_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(dilation_attr));
|
||||
return true;
|
||||
}
|
||||
|
||||
int ConvGradInputCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
std::vector<int64_t> src_shape = outputs[0]->GetDeviceShapeAdaptively();
|
||||
std::vector<int64_t> weight_shape = inputs[weight_index_]->GetDeviceShapeAdaptively();
|
||||
std::vector<int64_t> dst_shape = inputs[diff_dst_index_]->GetDeviceShapeAdaptively();
|
||||
size_t src_dim = src_shape.size();
|
||||
if (src_dim != SHAPE_4D && src_dim != SHAPE_5D) {
|
||||
MS_LOG(EXCEPTION) << "Conv grad only supports 4D/5D input, but got " << src_dim << "D!";
|
||||
}
|
||||
const auto format = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
|
||||
const auto &format = format_;
|
||||
if (src_dim == SHAPE_4D && format != NCHW) {
|
||||
MS_LOG(EXCEPTION) << kernel_name_ << " only supports 4D input with NCHW format, but got format " << format;
|
||||
}
|
||||
|
@ -57,7 +75,7 @@ void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
MS_LOG(EXCEPTION) << kernel_name_ << " only supports 5D input with NCDHW format, but got format " << format;
|
||||
}
|
||||
dnnl::memory::dims kernel_size(weight_shape.begin() + NC_LEN, weight_shape.end());
|
||||
const size_t group = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
|
||||
const auto group = group_;
|
||||
if (group > 1) {
|
||||
if (src_shape[1] % group != 0) {
|
||||
MS_LOG(EXCEPTION) << "Conv grad channels must be divided by group!";
|
||||
|
@ -65,19 +83,15 @@ void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
(void)weight_shape.insert(weight_shape.begin(), group);
|
||||
weight_shape[1] = weight_shape[1] / group;
|
||||
}
|
||||
|
||||
const dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
|
||||
const dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape);
|
||||
const dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
|
||||
const auto stride_attr = src_dim == SHAPE_4D ? STRIDE : STRIDES;
|
||||
const auto dilation_attr = src_dim == SHAPE_4D ? DILATION : DILATIONS;
|
||||
const auto pad_mode = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE);
|
||||
const auto strides_include_nc = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, stride_attr);
|
||||
const auto dilation_include_nc = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, dilation_attr);
|
||||
const auto &strides_include_nc = strides_include_nc_;
|
||||
if (strides_include_nc.size() != src_dim) {
|
||||
MS_LOG(EXCEPTION) << kernel_name_ << " requires strides must be " << src_dim << "D, but got "
|
||||
<< strides_include_nc.size() << "D!";
|
||||
}
|
||||
const auto &dilation_include_nc = dilation_include_nc_;
|
||||
if (dilation_include_nc.size() != src_dim) {
|
||||
MS_LOG(EXCEPTION) << kernel_name_ << " requires dilation must be " << src_dim << "D, but got "
|
||||
<< dilation_include_nc.size() << "D!";
|
||||
|
@ -89,8 +103,9 @@ void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
dnnl::memory::dims padding_r;
|
||||
(void)std::transform(dilation.begin(), dilation.end(), std::back_inserter(dilates),
|
||||
[](const int64_t &value) { return value - 1; });
|
||||
const auto &pad_mode = pad_mode_;
|
||||
PaddingInfo padding_info{pad_mode, kernel_size, strides, dilation, &padding_l, &padding_r};
|
||||
GetPadding(kernel_node, src_shape, padding_info);
|
||||
GetPadding(base_operator, src_shape, padding_info);
|
||||
|
||||
const auto forward_desc = CreateDesc<dnnl::convolution_forward::desc>(
|
||||
dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides,
|
||||
|
@ -104,6 +119,7 @@ void ConvGradInputCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
AddArgument(DNNL_ARG_DIFF_SRC, src_desc);
|
||||
AddArgument(DNNL_ARG_DIFF_DST, dst_desc);
|
||||
AddArgument(DNNL_ARG_WEIGHTS, weights_desc);
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
bool ConvGradInputCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
|
|
|
@ -20,18 +20,25 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class ConvGradInputCpuKernelMod : public DeprecatedMKLCpuKernelMod {
|
||||
class ConvGradInputCpuKernelMod : public MKLCpuKernelMod {
|
||||
public:
|
||||
ConvGradInputCpuKernelMod() = default;
|
||||
explicit ConvGradInputCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
|
||||
~ConvGradInputCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(
|
||||
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
@ -42,6 +49,11 @@ class ConvGradInputCpuKernelMod : public DeprecatedMKLCpuKernelMod {
|
|||
size_t weight_index_{0};
|
||||
size_t diff_dst_index_{1};
|
||||
std::string kernel_type_;
|
||||
std::string format_;
|
||||
std::string pad_mode_;
|
||||
int64_t group_;
|
||||
std::vector<int64_t> strides_include_nc_;
|
||||
std::vector<int64_t> dilation_include_nc_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue