onednn dfx

zuochuanyong 2022-01-14 19:17:47 +08:00
parent 9bb2c9b9b0
commit 9a1f9d69f6
23 changed files with 486 additions and 207 deletions
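Every hunk below applies one pattern: direct oneDNN constructor calls and std::make_shared of primitives are routed through the CreateDesc/CreatePrimitive template wrappers added in mkl_cpu_kernel.h, which bracket each third-party call with begin/end DEBUG logs; an unpaired "begin" line then pinpoints the oneDNN call that crashed or hung. A minimal self-contained sketch of that idea, using std::clog and a stand-in type instead of MS_LOG and the real dnnl classes:

#include <iostream>
#include <typeinfo>
#include <utility>

// Stand-in for the wrapper this commit adds: log, forward to the real
// constructor, log again. Only the logging idea is illustrated here.
template <class T, class... Args>
T CreateDesc(Args &&... args) {
  std::clog << "begin to invoke constructor of " << typeid(T).name() << '\n';
  T obj(std::forward<Args>(args)...);  // the third-party call under watch
  std::clog << "end to invoke constructor of " << typeid(T).name() << '\n';
  return obj;
}

struct BinaryDesc {  // stand-in for a dnnl descriptor type
  BinaryDesc(int /*alg*/, int /*src0*/, int /*src1*/, int /*dst*/) {}
};

int main() {
  auto desc = CreateDesc<BinaryDesc>(0, 1, 2, 3);  // emits one begin/end pair
  (void)desc;
  return 0;
}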

View File

@ -63,9 +63,9 @@ void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc src0_mem_desc = GetDefaultMemDesc(src0_shape);
dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape);
dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape);
dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_mem_desc, src1_mem_desc, dst_mem_desc);
auto prim_desc = dnnl::binary::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::binary>(prim_desc);
auto desc = CreateDesc<dnnl::binary::desc>(dnnl::algorithm::binary_add, src0_mem_desc, src1_mem_desc, dst_mem_desc);
auto prim_desc = CreateDesc<dnnl::binary::primitive_desc>(desc, engine_);
primitive_ = CreatePrimitive<dnnl::binary>(prim_desc);
AddArgument(DNNL_ARG_SRC_0, src0_mem_desc);
AddArgument(DNNL_ARG_SRC_1, src1_mem_desc);
AddArgument(DNNL_ARG_DST, dst_mem_desc);

View File

@ -45,9 +45,9 @@ void AssignAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
dnnl::memory::desc src0_desc = GetDefaultMemDesc(src0_shape);
dnnl::memory::desc src1_desc = GetDefaultMemDesc(src1_shape);
dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_desc, src1_desc, src0_desc);
auto prim_desc = dnnl::binary::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::binary>(prim_desc);
auto desc = CreateDesc<dnnl::binary::desc>(dnnl::algorithm::binary_add, src0_desc, src1_desc, src0_desc);
auto prim_desc = CreateDesc<dnnl::binary::primitive_desc>(desc, engine_);
primitive_ = CreatePrimitive<dnnl::binary>(prim_desc);
AddArgument(DNNL_ARG_SRC_0, src0_desc);
AddArgument(DNNL_ARG_SRC_1, src1_desc);
AddArgument(DNNL_ARG_DST, src0_desc);

View File

@ -59,15 +59,17 @@ void BatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
prop_kind = dnnl::prop_kind::forward_training;
normalization_flags = dnnl::normalization_flags::use_scale_shift;
}
dnnl::batch_normalization_forward::desc desc =
dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, normalization_flags);
auto prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::batch_normalization_forward>(prim_desc);
auto desc = CreateDesc<dnnl::batch_normalization_forward::desc>(prop_kind, x_desc, epsilon, normalization_flags);
auto prim_desc = CreateDesc<dnnl::batch_normalization_forward::primitive_desc>(desc, engine_);
auto wksp_desc = GetWorkspaceDesc(prim_desc);
auto mean = GetMeanDesc(prim_desc);
auto variance = GetVarianceDesc(prim_desc);
primitive_ = CreatePrimitive<dnnl::batch_normalization_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, x_desc);
AddArgument(DNNL_ARG_MEAN, prim_desc.mean_desc());
AddArgument(DNNL_ARG_VARIANCE, prim_desc.variance_desc());
AddArgument(DNNL_ARG_MEAN, mean);
AddArgument(DNNL_ARG_VARIANCE, variance);
AddArgument(DNNL_ARG_SCALE_SHIFT, scale_bias_desc);
AddArgument(DNNL_ARG_WORKSPACE, prim_desc.workspace_desc());
AddArgument(DNNL_ARG_WORKSPACE, wksp_desc);
AddArgument(DNNL_ARG_DST, x_desc);
}

View File

@ -64,21 +64,23 @@ void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
// fused Batch Normalization forward description
dnnl::batch_normalization_forward::desc desc =
dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, normalization_flags);
auto forward_prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, engine_);
auto desc = CreateDesc<dnnl::batch_normalization_forward::desc>(prop_kind, x_desc, epsilon, normalization_flags);
auto forward_prim_desc = CreateDesc<dnnl::batch_normalization_forward::primitive_desc>(desc, engine_);
// fused Batch Normalization backward description
dnnl::batch_normalization_backward::desc backward_desc =
dnnl::batch_normalization_backward::desc(dnnl::prop_kind::backward, x_desc, x_desc, epsilon, normalization_flags);
auto backward_desc = CreateDesc<dnnl::batch_normalization_backward::desc>(dnnl::prop_kind::backward, x_desc, x_desc,
epsilon, normalization_flags);
auto backward_prim_desc =
dnnl::batch_normalization_backward::primitive_desc(backward_desc, engine_, forward_prim_desc);
primitive_ = std::make_shared<dnnl::batch_normalization_backward>(backward_prim_desc);
CreateDesc<dnnl::batch_normalization_backward::primitive_desc>(backward_desc, engine_, forward_prim_desc);
auto wksp_desc = GetWorkspaceDesc(forward_prim_desc);
auto mean = GetMeanDesc(forward_prim_desc);
auto variance = GetVarianceDesc(forward_prim_desc);
primitive_ = CreatePrimitive<dnnl::batch_normalization_backward>(backward_prim_desc);
AddArgument(DNNL_ARG_SRC, x_desc);
AddArgument(DNNL_ARG_MEAN, forward_prim_desc.mean_desc());
AddArgument(DNNL_ARG_VARIANCE, forward_prim_desc.variance_desc());
AddArgument(DNNL_ARG_MEAN, mean);
AddArgument(DNNL_ARG_VARIANCE, variance);
AddArgument(DNNL_ARG_SCALE_SHIFT, scale_bias_desc);
AddArgument(DNNL_ARG_WORKSPACE, forward_prim_desc.workspace_desc());
AddArgument(DNNL_ARG_WORKSPACE, wksp_desc);
AddArgument(DNNL_ARG_DST, x_desc);
AddArgument(DNNL_ARG_DIFF_DST, x_desc);
AddArgument(DNNL_ARG_DIFF_SRC, x_desc);
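Both batch-norm hunks share a second detail: prim_desc queries such as mean_desc(), variance_desc() and workspace_desc() are hoisted into locals through the logged getters (GetMeanDesc and friends from mkl_cpu_kernel.h), and AddArgument then consumes the cached local. A hedged sketch of the motivation, with a stand-in LoggedQuery() in place of the real wrappers:

#include <iostream>
#include <string>

struct Desc { std::string origin; };  // stand-in for dnnl::memory::desc

// Stand-in for the GetMeanDesc-style wrappers: each query into oneDNN
// produces exactly one begin/end pair in the DEBUG log.
Desc LoggedQuery(const std::string &what) {
  std::clog << "begin to invoke " << what << '\n';
  Desc d{what};  // stands in for prim_desc.mean_desc() etc.
  std::clog << "end to invoke " << what << '\n';
  return d;
}

int main() {
  auto mean = LoggedQuery("mean_desc()");  // queried (and logged) once
  // The cached local is then reused freely without re-querying oneDNN:
  std::cout << mean.origin << '\n';
  std::cout << mean.origin << '\n';
  return 0;
}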

View File

@ -81,18 +81,16 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]};
dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]};
dnnl::convolution_forward::desc forward_desc =
dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc,
weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, engine_);
dnnl::convolution_backward_weights::desc backward_desc = dnnl::convolution_backward_weights::desc(
auto forward_desc = CreateDesc<dnnl::convolution_forward::desc>(
dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides,
dilates, padding_l, padding_r);
auto forward_prim_desc = CreateDesc<dnnl::convolution_forward::primitive_desc>(forward_desc, engine_);
auto backward_desc = CreateDesc<dnnl::convolution_backward_weights::desc>(
dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
auto backward_prim_desc =
dnnl::convolution_backward_weights::primitive_desc(backward_desc, engine_, forward_prim_desc);
primitive_ = std::make_shared<dnnl::convolution_backward_weights>(backward_prim_desc);
CreateDesc<dnnl::convolution_backward_weights::primitive_desc>(backward_desc, engine_, forward_prim_desc);
primitive_ = CreatePrimitive<dnnl::convolution_backward_weights>(backward_prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_DIFF_DST, dst_desc);

View File

@ -94,17 +94,15 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]};
dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]};
dnnl::convolution_forward::desc forward_desc =
dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc,
weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, engine_);
dnnl::convolution_backward_data::desc backward_desc = dnnl::convolution_backward_data::desc(
auto forward_desc = CreateDesc<dnnl::convolution_forward::desc>(
dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides,
dilates, padding_l, padding_r);
auto forward_prim_desc = CreateDesc<dnnl::convolution_forward::primitive_desc>(forward_desc, engine_);
auto backward_desc = CreateDesc<dnnl::convolution_backward_data::desc>(
dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
auto backward_prim_desc = dnnl::convolution_backward_data::primitive_desc(backward_desc, engine_, forward_prim_desc);
primitive_ = std::make_shared<dnnl::convolution_backward_data>(backward_prim_desc);
auto backward_prim_desc =
CreateDesc<dnnl::convolution_backward_data::primitive_desc>(backward_desc, engine_, forward_prim_desc);
primitive_ = CreatePrimitive<dnnl::convolution_backward_data>(backward_prim_desc);
AddArgument(DNNL_ARG_DIFF_SRC, src_desc);
AddArgument(DNNL_ARG_DIFF_DST, dst_desc);

View File

@ -102,12 +102,12 @@ void ConvCPUKernel::InitKernel(const CNodePtr &kernel_node) {
(void)padding_l.emplace_back(int_padding_l[i]);
(void)padding_r.emplace_back(int_padding_r[i]);
}
dnnl::convolution_forward::desc desc =
dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc,
weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
auto desc = CreateDesc<dnnl::convolution_forward::desc>(dnnl::prop_kind::forward_training,
dnnl::algorithm::convolution_auto, src_desc, weights_desc,
dst_desc, strides, dilates, padding_l, padding_r);
auto prim_desc = dnnl::convolution_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::convolution_forward>(prim_desc);
auto prim_desc = CreateDesc<dnnl::convolution_forward::primitive_desc>(desc, engine_);
primitive_ = CreatePrimitive<dnnl::convolution_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_WEIGHTS, weights_desc);
AddArgument(DNNL_ARG_DST, dst_desc);

View File

@ -50,8 +50,9 @@ dnnl::eltwise_forward::desc EltWiseCPUKernel::GetForwardEltwiseDesc(const dnnl::
if (desc_pair == eltWiseOpDescMap.end()) {
MS_LOG(EXCEPTION) << "EltWiseCPUKernel does not support " << kernel_name_;
}
return dnnl::eltwise_forward::desc(dnnl_forward_, desc_pair->second.algorithm, src_desc, desc_pair->second.alpha,
desc_pair->second.beta);
auto desc = CreateDesc<dnnl::eltwise_forward::desc>(dnnl_forward_, desc_pair->second.algorithm, src_desc,
desc_pair->second.alpha, desc_pair->second.beta);
return desc;
}
void EltWiseCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@ -64,8 +65,8 @@ void EltWiseCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
auto desc = GetForwardEltwiseDesc(src_desc);
auto prim_desc = dnnl::eltwise_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::eltwise_forward>(prim_desc);
auto prim_desc = CreateDesc<dnnl::eltwise_forward::primitive_desc>(desc, engine_);
primitive_ = CreatePrimitive<dnnl::eltwise_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_DST, src_desc);
}

View File

@ -38,10 +38,9 @@ void LogSoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
axis += SizeToInt(src_shape.size());
}
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
dnnl::logsoftmax_forward::desc desc =
dnnl::logsoftmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis);
auto prim_desc = dnnl::logsoftmax_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::logsoftmax_forward>(prim_desc);
auto desc = CreateDesc<dnnl::logsoftmax_forward::desc>(dnnl::prop_kind::forward_training, src_desc, axis);
auto prim_desc = CreateDesc<dnnl::logsoftmax_forward::primitive_desc>(desc, engine_);
primitive_ = CreatePrimitive<dnnl::logsoftmax_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_DST, src_desc);
}

View File

@ -38,13 +38,12 @@ void LogSoftmaxGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
axis += SizeToInt(src_shape.size());
}
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
dnnl::logsoftmax_forward::desc desc =
dnnl::logsoftmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis);
auto prim_desc = dnnl::logsoftmax_forward::primitive_desc(desc, engine_);
auto desc = CreateDesc<dnnl::logsoftmax_forward::desc>(dnnl::prop_kind::forward_training, src_desc, axis);
auto prim_desc = CreateDesc<dnnl::logsoftmax_forward::primitive_desc>(desc, engine_);
// backward description
dnnl::logsoftmax_backward::desc backward_desc = dnnl::logsoftmax_backward::desc(src_desc, src_desc, axis);
auto backward_prim_desc = dnnl::logsoftmax_backward::primitive_desc(backward_desc, engine_, prim_desc);
primitive_ = std::make_shared<dnnl::logsoftmax_backward>(backward_prim_desc);
auto backward_desc = CreateDesc<dnnl::logsoftmax_backward::desc>(src_desc, src_desc, axis);
auto backward_prim_desc = CreateDesc<dnnl::logsoftmax_backward::primitive_desc>(backward_desc, engine_, prim_desc);
primitive_ = CreatePrimitive<dnnl::logsoftmax_backward>(backward_prim_desc);
AddArgument(DNNL_ARG_DST, src_desc);
AddArgument(DNNL_ARG_DIFF_SRC, src_desc);
AddArgument(DNNL_ARG_DIFF_DST, src_desc);

View File

@ -25,6 +25,8 @@ constexpr size_t kLstmInputsNum = 4;
constexpr size_t kLstmOutputsNum = 5;
constexpr int kMaxLSTMLayer = 100;
constexpr int kOutputWorkSpaceIndex = 3;
constexpr int kInputCIndex = 2;
constexpr int kInputWeightIndex = 3;
constexpr int kGateNum = 4;
using tag = dnnl::memory::format_tag;
@ -82,29 +84,46 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (!is_training) {
prop_kind = dnnl::prop_kind::forward_inference;
}
auto desc = std::make_shared<dnnl::lstm_forward::desc>(
prop_kind, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc);
prim_desc_ = dnnl::lstm_forward::primitive_desc(*desc, eng);
primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_);
auto weights_desc = formatted_md(weights_dims_, tag::any);
auto weights_h_desc = formatted_md(weights_h_dims_, tag::any);
auto desc =
CreatePrimitive<dnnl::lstm_forward::desc>(prop_kind, direction, src_desc, src_h_desc, src_c_desc, weights_desc,
weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
prim_desc_ = CreateDesc<dnnl::lstm_forward::primitive_desc>(*desc, eng);
primitive_ = CreatePrimitive<dnnl::lstm_forward>(prim_desc_);
if (is_training) {
reserve_size_ = static_cast<size_t>(prim_desc_.workspace_desc().get_size());
AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc());
auto wksp_desc = GetWorkspaceDesc(prim_desc_);
reserve_size_ = GetSize(wksp_desc);
AddArgument(DNNL_ARG_WORKSPACE, wksp_desc);
} else {
reserve_size_ = 1;
}
auto weights_layer = GetWeightsLayerDesc(prim_desc_);
auto weights_iter = GetWeightsIterDesc(prim_desc_);
bias_desc_ = GetBiasDesc(prim_desc_);
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc);
AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc());
AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc());
AddArgument(DNNL_ARG_WEIGHTS_LAYER, weights_layer);
AddArgument(DNNL_ARG_WEIGHTS_ITER, weights_iter);
AddArgument(DNNL_ARG_BIAS, bias_desc);
AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc);
using dt = dnnl::memory::data_type;
using tag = dnnl::memory::format_tag;
auto weights_dims_desc = CreateDesc<dnnl::memory::desc>(weights_dims_, dt::f32, tag::ldgoi);
auto weights_h_dims_desc = CreateDesc<dnnl::memory::desc>(weights_h_dims_, dt::f32, tag::ldgoi);
user_weights_memory_ = CreateDesc<dnnl::memory>(weights_dims_desc, eng);
user_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_dims_desc, eng);
weights_memory_ = CreateDesc<dnnl::memory>(weights_layer, eng);
weights_h_memory_ = CreateDesc<dnnl::memory>(weights_iter, eng);
bias_memory_ = CreateDesc<dnnl::memory>(bias_desc_, eng);
}
void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {
constexpr int kBidirectional = 2;
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
@ -117,7 +136,7 @@ void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {
seq_len_ = SizeToInt(src_shape[0]);
num_directions_ = 1;
if (bidirectional_) {
num_directions_ = 2;
num_directions_ = kBidirectional;
}
const int gate_size = kGateNum * hidden_size_;
if (num_layers_ <= 0) {
@ -142,33 +161,26 @@ void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {
bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
using dt = dnnl::memory::data_type;
using tag = dnnl::memory::format_tag;
auto eng = engine_;
auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng);
auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng);
user_weights_memory.set_data_handle(inputs[3]->addr);
user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
Reorder(&user_weights_memory, &weights_memory);
Reorder(&user_weights_h_memory, &weights_h_memory);
auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng);
SetDataHandle(user_weights_memory_, inputs[kInputWeightIndex]->addr);
SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_);
Reorder(&user_weights_memory_, &weights_memory_);
Reorder(&user_weights_h_memory_, &weights_h_memory_);
if (has_bias_) {
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
SetDataHandle(bias_memory_,
reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_ + weight_h_size_);
} else {
if (memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0,
prim_desc_.bias_desc().get_size())) {
auto size = GetSize(bias_desc_);
if (memset_s(GetDataHandle(bias_memory_), size, 0, size)) {
MS_LOG(EXCEPTION) << "Bias memset error";
}
}
// set handle
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[kInputCIndex]->addr);
SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, GetDataHandle(weights_memory_));
SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, GetDataHandle(weights_h_memory_));
SetArgumentHandle(DNNL_ARG_BIAS, GetDataHandle(bias_memory_));
SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr);
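Beyond logging, this file moves the dnnl::memory objects out of Launch and into InitKernel as class members (declared in lstm_cpu_kernel.h below), so each launch only rebinds data pointers with SetDataHandle instead of reconstructing the memories. A minimal sketch of that init-once/rebind-per-launch pattern, with stand-in types:

struct Memory {  // stand-in for dnnl::memory
  void *handle = nullptr;
};

class LstmKernelSketch {
 public:
  void InitKernel() {
    weights_memory_ = Memory{};  // constructed once, logged once
  }
  bool Launch(void *weights_addr) {
    weights_memory_.handle = weights_addr;  // cheap per-launch rebind
    return true;
  }

 private:
  Memory weights_memory_;  // member instead of a Launch() local
};

int main() {
  LstmKernelSketch kernel;
  kernel.InitKernel();
  float weights[8] = {};
  return kernel.Launch(weights) ? 0 : 1;
}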

View File

@ -55,6 +55,12 @@ class LstmCPUKernel : public MKLCPUKernel {
dnnl::memory::dims weights_h_dims_;
dnnl::memory::dims bias_dims_;
dnnl::lstm_forward::primitive_desc prim_desc_;
dnnl::memory::desc bias_desc_;
dnnl::memory user_weights_memory_;
dnnl::memory user_weights_h_memory_;
dnnl::memory weights_memory_;
dnnl::memory weights_h_memory_;
dnnl::memory bias_memory_;
};
MS_REG_CPU_KERNEL(LSTM,

View File

@ -27,6 +27,19 @@ constexpr size_t kLstmGradInputsNum = 11;
constexpr size_t kLstmGradOutputsNum = 4;
constexpr int kMaxLSTMLayer = 100;
constexpr int kInputWorkSpaceIndex = 10;
constexpr int kInputWeightIndex = 3;
constexpr int kOutputWeightIndex = 3;
constexpr int kSrcLayerIdx = 0;
constexpr int kSrcIterIdx = 1;
constexpr int kSrcIterCIdx = 2;
constexpr int kDstLayerIdx = 4;
constexpr int kDstIterIdx = 5;
constexpr int kDstIterCIdx = 6;
constexpr int kDiffDstLayerIdx = 7;
constexpr int kDiffDstIterIdx = 8;
constexpr int kDiffDstIterCIdx = 9;
constexpr int kWorkspaceIdx = 10;
using tag = dnnl::memory::format_tag;
using dim = dnnl::memory::dims;
@ -63,21 +76,45 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
auto forward_desc = std::make_shared<dnnl::lstm_forward::desc>(
dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc,
dst_c_desc);
auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(*forward_desc, eng);
auto backward_desc = std::make_shared<dnnl::lstm_backward::desc>(
dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc,
src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc,
dst_h_desc, dst_c_desc);
prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc);
primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_);
reserve_size_ = static_cast<size_t>(prim_forward_desc.workspace_desc().get_size());
AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc());
auto weights_desc = formatted_md(weights_dims_, tag::any);
auto weights_h_desc = formatted_md(weights_h_dims_, tag::any);
auto forward_desc = CreatePrimitive<dnnl::lstm_forward::desc>(dnnl::prop_kind::forward_training, direction, src_desc,
src_h_desc, src_c_desc, weights_desc, weights_h_desc,
bias_desc, dst_desc, dst_h_desc, dst_c_desc);
auto prim_forward_desc = CreateDesc<dnnl::lstm_forward::primitive_desc>(*forward_desc, eng);
auto backward_desc = CreatePrimitive<dnnl::lstm_backward::desc>(
dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, weights_desc, weights_h_desc, bias_desc,
dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, weights_desc, weights_h_desc, bias_desc,
dst_desc, dst_h_desc, dst_c_desc);
prim_backward_desc_ = CreateDesc<dnnl::lstm_backward::primitive_desc>(*backward_desc, eng, prim_forward_desc);
primitive_ = CreatePrimitive<dnnl::lstm_backward>(prim_backward_desc_);
auto wksp_desc = GetWorkspaceDesc(prim_forward_desc);
reserve_size_ = GetSize(wksp_desc);
AddArgument(DNNL_ARG_WORKSPACE, wksp_desc);
AddArgumentOp(src_desc, src_h_desc, src_c_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
// construct fw memory
weights_layer_desc_ = GetWeightsLayerDesc(prim_backward_desc_);
weights_iter_desc_ = GetWeightsIterDesc(prim_backward_desc_);
bias_desc_ = GetBiasDesc(prim_backward_desc_);
auto weights_mem_desc = CreateDesc<dnnl::memory::desc>(weights_dims_, dt::f32, tag::ldgoi);
auto weights_h_mem_desc = CreateDesc<dnnl::memory::desc>(weights_h_dims_, dt::f32, tag::ldgoi);
user_weights_memory_ = CreateDesc<dnnl::memory>(weights_mem_desc, eng);
user_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_mem_desc, eng);
weights_memory_ = CreateDesc<dnnl::memory>(weights_layer_desc_, eng);
weights_h_memory_ = CreateDesc<dnnl::memory>(weights_iter_desc_, eng);
bias_memory_ = CreateDesc<dnnl::memory>(bias_desc_, eng);
// construct bw memory
diff_weights_layer_desc_ = GetDiffWeightsLayerDesc(prim_backward_desc_);
diff_weights_iter_desc_ = GetDiffWeightsIterDesc(prim_backward_desc_);
diff_bias_desc_ = GetDiffBiasDesc(prim_backward_desc_);
diff_weights_memory_ = CreateDesc<dnnl::memory>(diff_weights_layer_desc_, eng);
diff_weights_h_memory_ = CreateDesc<dnnl::memory>(diff_weights_iter_desc_, eng);
diff_bias_memory_ = CreateDesc<dnnl::memory>(diff_bias_desc_, eng);
user_diff_weights_memory_ = CreateDesc<dnnl::memory>(weights_mem_desc, eng);
user_diff_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_mem_desc, eng);
}
void LSTMGradCPUKernel::AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc,
@ -87,8 +124,8 @@ void LSTMGradCPUKernel::AddArgumentOp(const dnnl::memory::desc &src_desc, const
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc);
AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_backward_desc_.weights_layer_desc());
AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_backward_desc_.weights_iter_desc());
AddArgument(DNNL_ARG_WEIGHTS_LAYER, weights_layer_desc_);
AddArgument(DNNL_ARG_WEIGHTS_ITER, weights_iter_desc_);
AddArgument(DNNL_ARG_BIAS, bias_desc);
AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
@ -96,8 +133,8 @@ void LSTMGradCPUKernel::AddArgumentOp(const dnnl::memory::desc &src_desc, const
AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc);
AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc);
AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc);
AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, prim_backward_desc_.diff_weights_layer_desc());
AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, prim_backward_desc_.diff_weights_iter_desc());
AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_layer_desc_);
AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_iter_desc_);
AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc);
AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc);
AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc);
@ -141,34 +178,33 @@ void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
}
void LSTMGradCPUKernel::SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs,
const dnnl::memory &weights_memory, const dnnl::memory &weights_h_memory,
const dnnl::memory &bias_memory, const dnnl::memory &diff_weights_memory,
const dnnl::memory &diff_weights_h_memory,
const dnnl::memory &diff_bias_memory) {
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr);
SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr);
const std::vector<kernel::AddressPtr> &outputs) {
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[kSrcLayerIdx]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[kSrcIterIdx]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[kSrcIterCIdx]->addr);
SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, GetDataHandle(weights_memory_));
SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, GetDataHandle(weights_h_memory_));
SetArgumentHandle(DNNL_ARG_BIAS, GetDataHandle(bias_memory_));
SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[kDstLayerIdx]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[kDstIterIdx]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[kDstIterCIdx]->addr);
SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[kWorkspaceIdx]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[kSrcLayerIdx]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[kSrcIterIdx]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[kSrcIterCIdx]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, GetDataHandle(diff_weights_memory_));
SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, GetDataHandle(diff_weights_h_memory_));
SetArgumentHandle(DNNL_ARG_DIFF_BIAS, GetDataHandle(diff_bias_memory_));
SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[kDiffDstLayerIdx]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[kDiffDstIterIdx]->addr);
SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[kDiffDstIterCIdx]->addr);
}
void LSTMGradCPUKernel::ResetMemory(const dnnl::memory &mem, const string name) const {
if (memset_s(mem.get_data_handle(), mem.get_desc().get_size(), 0, mem.get_desc().get_size())) {
auto dst_ptr = GetDataHandle(mem);
auto mem_desc = GetMemDesc(mem);
auto size = GetSize(mem_desc);
if (memset_s(dst_ptr, size, 0, size)) {
MS_LOG(EXCEPTION) << name << " memset error";
}
}
@ -177,49 +213,41 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, co
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kLstmGradInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kLstmGradOutputsNum, kernel_name_);
auto eng = engine_;
// construct fw memory
auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
auto weights_memory = dnnl::memory(prim_backward_desc_.weights_layer_desc(), eng);
auto weights_h_memory = dnnl::memory(prim_backward_desc_.weights_iter_desc(), eng);
auto bias_memory = dnnl::memory(prim_backward_desc_.bias_desc(), eng);
user_weights_memory.set_data_handle(inputs[3]->addr);
user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
Reorder(&user_weights_memory, &weights_memory);
Reorder(&user_weights_h_memory, &weights_h_memory);
SetDataHandle(user_weights_memory_, inputs[kInputWeightIndex]->addr);
SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_);
Reorder(&user_weights_memory_, &weights_memory_);
Reorder(&user_weights_h_memory_, &weights_h_memory_);
if (has_bias_) {
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
SetDataHandle(bias_memory_,
reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_ + weight_h_size_);
} else {
if (memset_s(bias_memory.get_data_handle(), prim_backward_desc_.bias_desc().get_size(), 0,
prim_backward_desc_.bias_desc().get_size())) {
auto dst_ptr = GetDataHandle(bias_memory_);
auto size = GetSize(bias_desc_);
if (memset_s(dst_ptr, size, 0, size)) {
MS_LOG(EXCEPTION) << "Bias memset error";
}
}
// construct bw memory
auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng);
auto diff_weights_h_memory = dnnl::memory(prim_backward_desc_.diff_weights_iter_desc(), eng);
auto diff_bias_memory = dnnl::memory(prim_backward_desc_.diff_bias_desc(), eng);
auto user_diff_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
user_diff_weights_memory.set_data_handle(outputs[3]->addr);
user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
ResetMemory(user_diff_weights_memory, "user weights grad");
ResetMemory(user_diff_weights_h_memory, "user weights iter grad");
ResetMemory(diff_weights_memory, "weights grad");
ResetMemory(diff_weights_h_memory, "weights iter grad");
SetDataHandle(user_diff_weights_memory_, outputs[kOutputWeightIndex]->addr);
SetDataHandle(user_diff_weights_h_memory_,
reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + weight_size_);
ResetMemory(user_diff_weights_memory_, "user weights grad");
ResetMemory(user_diff_weights_h_memory_, "user weights iter grad");
ResetMemory(diff_weights_memory_, "weights grad");
ResetMemory(diff_weights_h_memory_, "weights iter grad");
if (has_bias_) {
diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
SetDataHandle(diff_bias_memory_,
reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + weight_size_ + weight_h_size_);
}
if (memset_s(diff_bias_memory.get_data_handle(), prim_backward_desc_.diff_bias_desc().get_size(), 0,
prim_backward_desc_.diff_bias_desc().get_size())) {
auto dst_ptr = GetDataHandle(diff_bias_memory_);
auto size = GetSize(diff_bias_desc_);
if (memset_s(dst_ptr, size, 0, size)) {
MS_LOG(EXCEPTION) << "Bias grad memset error";
}
SetArgumentHandleOp(inputs, outputs, weights_memory, weights_h_memory, bias_memory, diff_weights_memory,
diff_weights_h_memory, diff_bias_memory);
SetArgumentHandleOp(inputs, outputs);
ExecutePrimitive();
Reorder(&diff_weights_memory, &user_diff_weights_memory);
Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory);
Reorder(&diff_weights_memory_, &user_diff_weights_memory_);
Reorder(&diff_weights_h_memory_, &user_diff_weights_h_memory_);
return true;
}
} // namespace kernel
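This file also retires magic operand indices: inputs[3], inputs[10] and friends become the named constants introduced at the top of the hunk, so the LSTMGrad operand layout is readable at each use site. For illustration only (the constant names are from the diff; the layout comments are my reading of them, not authoritative):

#include <cassert>
#include <vector>

constexpr int kSrcLayerIdx = 0;       // x input
constexpr int kInputWeightIndex = 3;  // packed weights: layer, iter, bias
constexpr int kWorkspaceIdx = 10;     // forward workspace reused by backward

int main() {
  std::vector<void *> inputs(11, nullptr);
  // A named index documents intent where a bare number would not:
  assert(inputs[kSrcLayerIdx] == inputs[0]);
  assert(inputs[kInputWeightIndex] == inputs[3]);
  assert(inputs[kWorkspaceIdx] == inputs[10]);
  return 0;
}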

View File

@ -42,10 +42,7 @@ class LSTMGradCPUKernel : public MKLCPUKernel {
const dnnl::memory::desc &dst_desc, const dnnl::memory::desc &dst_h_desc,
const dnnl::memory::desc &dst_c_desc);
void SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs, const dnnl::memory &weights_memory,
const dnnl::memory &weights_h_memory, const dnnl::memory &bias_memory,
const dnnl::memory &diff_weights_memory, const dnnl::memory &diff_weights_h_memory,
const dnnl::memory &diff_bias_memory);
const std::vector<kernel::AddressPtr> &outputs);
void ResetMemory(const dnnl::memory &mem, const string name) const;
void CheckParam(const CNodePtr &kernel_node);
@ -65,6 +62,23 @@ class LSTMGradCPUKernel : public MKLCPUKernel {
dnnl::memory::dims weights_h_dims_;
dnnl::memory::dims bias_dims_;
dnnl::lstm_backward::primitive_desc prim_backward_desc_;
dnnl::memory::desc weights_layer_desc_;
dnnl::memory::desc weights_iter_desc_;
dnnl::memory::desc bias_desc_;
dnnl::memory::desc diff_weights_layer_desc_;
dnnl::memory::desc diff_weights_iter_desc_;
dnnl::memory::desc diff_bias_desc_;
dnnl::memory user_weights_memory_;
dnnl::memory user_weights_h_memory_;
dnnl::memory weights_memory_;
dnnl::memory weights_h_memory_;
dnnl::memory bias_memory_;
dnnl::memory diff_weights_memory_;
dnnl::memory diff_weights_h_memory_;
dnnl::memory diff_bias_memory_;
dnnl::memory user_diff_weights_memory_;
dnnl::memory user_diff_weights_h_memory_;
};
MS_REG_CPU_KERNEL(LSTMGrad,

View File

@ -76,12 +76,12 @@ void MatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
o_strides = {dim_n, 1};
}
dnnl::memory::desc src_md(src_dims, dnnl::memory::data_type::f32, a_strides);
dnnl::memory::desc weights_md(weights_dims, dnnl::memory::data_type::f32, b_strides);
dnnl::memory::desc dst_md(dst_dims, dnnl::memory::data_type::f32, o_strides);
dnnl::matmul::desc matmul_desc(src_md, weights_md, dst_md);
dnnl::matmul::primitive_desc prim_desc(matmul_desc, engine_);
primitive_ = std::make_shared<dnnl::matmul>(prim_desc);
auto src_md = CreateDesc<dnnl::memory::desc>(src_dims, dnnl::memory::data_type::f32, a_strides);
auto weights_md = CreateDesc<dnnl::memory::desc>(weights_dims, dnnl::memory::data_type::f32, b_strides);
auto dst_md = CreateDesc<dnnl::memory::desc>(dst_dims, dnnl::memory::data_type::f32, o_strides);
auto matmul_desc = CreateDesc<dnnl::matmul::desc>(src_md, weights_md, dst_md);
auto prim_desc = CreateDesc<dnnl::matmul::primitive_desc>(matmul_desc, engine_);
primitive_ = CreatePrimitive<dnnl::matmul>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_md);
AddArgument(DNNL_ARG_WEIGHTS, weights_md);

View File

@ -134,7 +134,7 @@ dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector<size_t> &sh
(void)dims.insert(dims.end(), shape.begin(), shape.end());
}
dnnl::memory::format_tag mem_tag = GetDefaultFormatTag(dims);
dnnl::memory::desc mem_desc(dims, dnnl::memory::data_type::f32, mem_tag);
auto mem_desc = CreateDesc<dnnl::memory::desc>(dims, dnnl::memory::data_type::f32, mem_tag);
return mem_desc;
}
@ -149,7 +149,9 @@ void MKLCPUKernel::AddArgument(int arg_key, const dnnl::memory::desc &mem_desc,
void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) {
auto arg_iter = arguments_.find(arg_key);
if (arg_iter != arguments_.end()) {
MS_LOG(DEBUG) << "begin to invoke dnnl::memory::set_data_handle";
arg_iter->second.set_data_handle(ptr);
MS_LOG(DEBUG) << "end to invoke dnnl::memory::set_data_handle";
}
}
@ -169,7 +171,9 @@ void MKLCPUKernel::ExecutePrimitive() {
}
double start_time = GetTime();
mkl_pool->set_num_threads(current_thread_nums);
MS_LOG(DEBUG) << "begin to invoke primitive::execute";
primitive_->execute(stream_, arguments_);
MS_LOG(DEBUG) << "end to invoke primitive::execute";
double cost_time = GetTime() - start_time;
parallel_search_info_.tmp_sum_cost_time += cost_time;
parallel_search_info_.search_count++;
@ -184,16 +188,45 @@ void MKLCPUKernel::ExecutePrimitive() {
} else {
int best_thread_nums = static_cast<int>(std::pow(2.0f, parallel_search_info_.best_pow));
mkl_pool->set_num_threads(best_thread_nums);
MS_LOG(DEBUG) << "begin to invoke primitive::execute";
primitive_->execute(stream_, arguments_);
MS_LOG(DEBUG) << "end to invoke primitive::execute";
}
#else
MS_LOG(DEBUG) << "begin to invoke primitive::execute";
primitive_->execute(stream_, arguments_);
MS_LOG(DEBUG) << "end to invoke primitive::execute";
#endif
(void)stream_.wait();
}
void MKLCPUKernel::SetDataHandle(dnnl::memory mem, void *ptr) {
MS_LOG(DEBUG) << "begin to invoke dnnl::memory::set_data_handle";
mem.set_data_handle(ptr);
MS_LOG(DEBUG) << "end to invoke dnnl::memory::set_data_handle";
}
void *MKLCPUKernel::GetDataHandle(const dnnl::memory &mem) const {
MS_LOG(DEBUG) << "begin to invoke dnnl::memory::get_data_handle";
auto ptr = mem.get_data_handle();
MS_LOG(DEBUG) << "end to invoke dnnl::memory::get_data_handle";
return ptr;
}
size_t MKLCPUKernel::GetSize(const dnnl::memory::desc &desc) const {
MS_LOG(DEBUG) << "begin to invoke dnnl::memory::desc::get_size()";
auto size = desc.get_size();
MS_LOG(DEBUG) << "end to invoke dnnl::memory::desc::get_size()";
return size;
}
void MKLCPUKernel::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) {
dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem);
MS_LOG(DEBUG) << "begin to invoke constructor of dnnl::reorder";
auto desc = dnnl::reorder(*src_mem, *dst_mem);
MS_LOG(DEBUG) << "end to invoke constructor of dnnl::reorder";
MS_LOG(DEBUG) << "begin to invoke primitive::execute";
desc.execute(stream_, *src_mem, *dst_mem);
MS_LOG(DEBUG) << "begin to invoke primitive::execute";
}
} // namespace kernel
} // namespace mindspore

View File

@ -22,6 +22,7 @@
#include <algorithm>
#include <memory>
#include <vector>
#include <utility>
#include "dnnl.hpp"
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
@ -32,6 +33,102 @@
namespace mindspore {
namespace kernel {
template <class T, class... Args>
auto CreateDesc(Args &&... args) {
MS_LOG(DEBUG) << "begin to invoke constructor of " << demangle(typeid(T).name());
auto desc = T(std::forward<Args>(args)...);
MS_LOG(DEBUG) << "end to invoke constructor of " << demangle(typeid(T).name());
return desc;
}
template <class T, class... Args>
auto CreatePrimitive(Args &&... args) {
MS_LOG(DEBUG) << "begin to invoke constructor of " << demangle(typeid(T).name());
auto prim = std::make_shared<T>(std::forward<Args>(args)...);
MS_LOG(DEBUG) << "end to invoke constructor of " << demangle(typeid(T).name());
return prim;
}
template <class T>
auto GetWorkspaceDesc(const T &prim_desc) {
MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::workspace_desc()";
auto desc = prim_desc.workspace_desc();
MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::workspace_desc()";
return desc;
}
template <class T>
auto GetMeanDesc(const T &prim_desc) {
MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::mean_desc()";
auto desc = prim_desc.mean_desc();
MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::mean_desc()";
return desc;
}
template <class T>
auto GetVarianceDesc(const T &prim_desc) {
MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::variance_desc()";
auto desc = prim_desc.variance_desc();
MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::variance_desc()";
return desc;
}
template <class T>
auto GetWeightsLayerDesc(const T &prim_desc) {
MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::weights_layer_desc()";
auto desc = prim_desc.weights_layer_desc();
MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::weights_layer_desc()";
return desc;
}
template <class T>
auto GetWeightsIterDesc(const T &prim_desc) {
MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::weights_iter_desc()";
auto desc = prim_desc.weights_iter_desc();
MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::weights_iter_desc()";
return desc;
}
template <class T>
auto GetBiasDesc(const T &prim_desc) {
MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::bias_desc()";
auto desc = prim_desc.bias_desc();
MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::bias_desc()";
return desc;
}
template <class T>
auto GetDiffWeightsLayerDesc(const T &prim_desc) {
MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::diff_weights_layer_desc()";
auto desc = prim_desc.diff_weights_layer_desc();
MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::diff_weights_layer_desc()";
return desc;
}
template <class T>
auto GetDiffWeightsIterDesc(const T &prim_desc) {
MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::diff_weights_iter_desc()";
auto desc = prim_desc.diff_weights_iter_desc();
MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::diff_weights_iter_desc()";
return desc;
}
template <class T>
auto GetDiffBiasDesc(const T &prim_desc) {
MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::diff_bias_desc()";
auto desc = prim_desc.diff_bias_desc();
MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::diff_bias_desc()";
return desc;
}
template <class T>
auto GetMemDesc(const T &prim_desc) {
MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::get_desc()";
auto desc = prim_desc.get_desc();
MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::get_desc()";
return desc;
}
#ifdef USE_MS_THREADPOOL_FOR_DNNL
class mkl_threadpool : public dnnl::threadpool_interop::threadpool_iface {
private:
@ -62,7 +159,9 @@ class MKLCPUKernel : public CPUKernel {
MKLCPUKernel() : engine_(dnnl::engine::kind::cpu, 0) {
auto thread_pool = GetActorMgrInnerThreadPool();
mkl_threadpool_ = std::make_shared<mkl_threadpool>(thread_pool);
MS_LOG(DEBUG) << "begin to invoke dnnl::threadpool_interop::make_stream";
stream_ = dnnl::threadpool_interop::make_stream(engine_, mkl_threadpool_.get());
MS_LOG(DEBUG) << "end to invoke dnnl::threadpool_interop::make_stream";
}
#else
MKLCPUKernel() : engine_(dnnl::engine::kind::cpu, 0), stream_(engine_) {}
@ -81,10 +180,17 @@ class MKLCPUKernel : public CPUKernel {
dnnl::memory::desc GetDefaultMemDesc(const std::vector<size_t> &shape) const;
void ExecutePrimitive();
inline dnnl::memory::desc formatted_md(const dnnl::memory::dims &dimensions, dnnl::memory::format_tag layout) {
return dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout};
MS_LOG(DEBUG) << "begin to invoke constructor of dnnl::memory::desc";
auto desc = dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout};
MS_LOG(DEBUG) << "end to invoke constructor of dnnl::memory::desc";
return desc;
}
void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem);
size_t GetSize(const dnnl::memory::desc &desc) const;
void SetDataHandle(dnnl::memory mem, void *ptr);
void *GetDataHandle(const dnnl::memory &mem) const;
std::unordered_map<int, dnnl::memory> arguments_;
std::shared_ptr<dnnl::primitive> primitive_{nullptr};
dnnl::engine engine_;

View File

@ -65,17 +65,17 @@ void AvgPoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]};
// pooling_avg forward description
dnnl::pooling_forward::desc desc =
dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc, dst_desc,
strides_dims, kernels_dims, padding_l, padding_r);
auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, engine_);
auto desc =
CreateDesc<dnnl::pooling_forward::desc>(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc,
dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
auto prim_desc = CreateDesc<dnnl::pooling_forward::primitive_desc>(desc, engine_);
// pooling_avg backward description
dnnl::pooling_backward::desc backward_desc = dnnl::pooling_backward::desc(
dnnl::algorithm::pooling_avg, src_desc, dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
auto backward_prim_desc = dnnl::pooling_backward::primitive_desc(backward_desc, engine_, prim_desc);
auto backward_desc = CreateDesc<dnnl::pooling_backward::desc>(dnnl::algorithm::pooling_avg, src_desc, dst_desc,
strides_dims, kernels_dims, padding_l, padding_r);
auto backward_prim_desc = CreateDesc<dnnl::pooling_backward::primitive_desc>(backward_desc, engine_, prim_desc);
primitive_ = std::make_shared<dnnl::pooling_backward>(backward_prim_desc);
primitive_ = CreatePrimitive<dnnl::pooling_backward>(backward_prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_DST, dst_desc);
AddArgument(DNNL_ARG_DIFF_SRC, src_desc);

View File

@ -76,19 +76,22 @@ void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) {
(void)padding_l.emplace_back(int_padding_l[i]);
(void)padding_r.emplace_back(int_padding_r[i]);
}
dnnl::pooling_forward::desc desc =
dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_max, src_desc, dst_desc,
strides_dims, kernels_dims, padding_l, padding_r);
auto desc =
CreateDesc<dnnl::pooling_forward::desc>(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_max, src_desc,
dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
if (kernel_name_ == prim::kPrimAvgPool->name()) {
desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc,
dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
desc =
CreateDesc<dnnl::pooling_forward::desc>(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc,
dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
}
auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, engine_);
workspace_size_ = prim_desc.workspace_desc().get_size();
primitive_ = std::make_shared<dnnl::pooling_forward>(prim_desc);
auto prim_desc = CreateDesc<dnnl::pooling_forward::primitive_desc>(desc, engine_);
auto wksp_desc = GetWorkspaceDesc(prim_desc);
workspace_size_ = GetSize(wksp_desc);
primitive_ = CreatePrimitive<dnnl::pooling_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_DST, dst_desc);
AddArgument(DNNL_ARG_WORKSPACE, prim_desc.workspace_desc());
AddArgument(DNNL_ARG_WORKSPACE, wksp_desc);
}
bool PoolingCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,

View File

@ -45,9 +45,9 @@ void SoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
axis += SizeToInt(src_shape.size());
}
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis);
auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc);
auto desc = CreateDesc<dnnl::softmax_forward::desc>(dnnl::prop_kind::forward_training, src_desc, axis);
auto prim_desc = CreateDesc<dnnl::softmax_forward::primitive_desc>(desc, engine_);
primitive_ = CreatePrimitive<dnnl::softmax_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_DST, src_desc);
}

View File

@ -52,11 +52,11 @@ void SoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_n
if (batch_size_ == 0 || class_num_ == 0) {
MS_LOG(EXCEPTION) << "Invalid batch size or class num input!";
}
dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
auto mem_desc = CreateDesc<dnnl::memory::desc>(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1);
auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc);
auto desc = CreateDesc<dnnl::softmax_forward::desc>(dnnl::prop_kind::forward_training, mem_desc, 1);
auto prim_desc = CreateDesc<dnnl::softmax_forward::primitive_desc>(desc, engine_);
primitive_ = CreatePrimitive<dnnl::softmax_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, mem_desc);
AddArgument(DNNL_ARG_DST, mem_desc);

View File

@ -58,11 +58,11 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &ke
MS_LOG(EXCEPTION) << "Invalid batch size or class num input!";
}
is_grad_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, IS_GRAD);
dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
auto mem_desc = CreateDesc<dnnl::memory::desc>(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1);
auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc);
auto desc = CreateDesc<dnnl::softmax_forward::desc>(dnnl::prop_kind::forward_training, mem_desc, 1);
auto prim_desc = CreateDesc<dnnl::softmax_forward::primitive_desc>(desc, engine_);
primitive_ = CreatePrimitive<dnnl::softmax_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, mem_desc);
AddArgument(DNNL_ARG_DST, mem_desc);

View File

@ -0,0 +1,78 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import re
from collections import defaultdict
import pytest
import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_onednn_dfx_log():
"""
Feature: onednn dfx
Description: Run with 'python -O -m pytest -s xxx.py'; -O disables __debug__ so per-thread logs are dumped
Expectation: begin_cnt == end_cnt
"""
os.environ['GLOG_v'] = '0'
log_name = './onednn_dfx.log'
# must be run from this directory so test_maxpool_op.py is found
cmd = "pytest -s test_maxpool_op.py::test_max_pool3d_1 --disable-warnings > {} 2>&1".format(log_name)
out = os.popen(cmd)
out.read()
keyword_begin = "begin to invoke"
keyword_end = "end to invoke"
begin_cnt = 0
end_cnt = 0
log_pattern = re.compile(r'\[[A-Z]+\] [A-Z]+\([\d]+\,[a-f\d]+,')
# {tid0: [log_line0, log_line1], tid1: [log_line2, log_line3, log_line4], ...}
multi_dict = defaultdict(list)
with open(log_name, "r") as f:
for line in f.readlines():
if not log_pattern.match(line):
continue
# log example: '[DEBUG] KERNEL(48810,7f42b77fd700,python):2022-01-10-16:19:13.521.086 xxxx.'
tid = line.split(' ')[1].split(',')[1]
line = line.replace('\n', '').replace('\r', '')
multi_dict[tid].append(line)
assert multi_dict.keys()  # the filtered log must not be empty
for tid in multi_dict.keys():
if not __debug__:
f = open('./onednn_dfx-{}.log'.format(tid), "w")
for line in multi_dict[tid]:
if not __debug__:
f.write(line + '\n')
if keyword_begin in line:
begin_cnt += 1
last_begin_line = line
elif keyword_end in line:
end_cnt += 1
if begin_cnt != end_cnt:
print("\n=======The line below has no pairing end line======")
print(last_begin_line)
print("===================================================")
assert begin_cnt == end_cnt
if not __debug__:
f.close()