forked from mindspore-Ecosystem/mindspore

commit 9a1f9d69f6 (parent 9bb2c9b9b0)

onednn dfx
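The commit message is terse, so a summary of what the diff below does: "dfx" here means diagnosability. Every direct oneDNN (DNNL) call in the CPU kernels — descriptor construction, primitive construction, and the accessor methods on primitive descriptors — is routed through small template helpers (CreateDesc, CreatePrimitive, GetWorkspaceDesc, and friends, added to mkl_cpu_kernel.h below) that emit paired "begin to invoke" / "end to invoke" DEBUG logs. If oneDNN hangs or crashes, the last "begin" line without a matching "end" identifies the exact call. A minimal standalone sketch of the pattern (not the real helpers — it uses std::clog in place of MS_LOG and std::string as a stand-in for a dnnl type):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <typeinfo>
#include <utility>

// Forward any constructor call through paired begin/end logs; a hang or
// crash inside the wrapped constructor leaves a "begin" line with no "end".
template <class T, class... Args>
T CreateDesc(Args &&... args) {
  std::clog << "begin to invoke constructor of " << typeid(T).name() << '\n';
  T desc(std::forward<Args>(args)...);
  std::clog << "end to invoke constructor of " << typeid(T).name() << '\n';
  return desc;
}

// Same idea, but heap-allocates and returns a shared_ptr, mirroring the
// std::make_shared calls this commit replaces.
template <class T, class... Args>
std::shared_ptr<T> CreatePrimitive(Args &&... args) {
  std::clog << "begin to invoke constructor of " << typeid(T).name() << '\n';
  auto prim = std::make_shared<T>(std::forward<Args>(args)...);
  std::clog << "end to invoke constructor of " << typeid(T).name() << '\n';
  return prim;
}

int main() {
  auto s = CreateDesc<std::string>(3, 'x');    // logs one begin/end pair
  auto p = CreatePrimitive<std::string>("y");  // logs one begin/end pair
  std::cout << s << ' ' << *p << '\n';         // prints: xxx y
  return 0;
}
```

The real helpers additionally demangle the type name so the log is readable, and the new test at the bottom of the diff asserts that every "begin" line has a matching "end" line per thread.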
@@ -63,9 +63,9 @@ void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   dnnl::memory::desc src0_mem_desc = GetDefaultMemDesc(src0_shape);
   dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape);
   dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape);
-  dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_mem_desc, src1_mem_desc, dst_mem_desc);
-  auto prim_desc = dnnl::binary::primitive_desc(desc, engine_);
-  primitive_ = std::make_shared<dnnl::binary>(prim_desc);
+  auto desc = CreateDesc<dnnl::binary::desc>(dnnl::algorithm::binary_add, src0_mem_desc, src1_mem_desc, dst_mem_desc);
+  auto prim_desc = CreateDesc<dnnl::binary::primitive_desc>(desc, engine_);
+  primitive_ = CreatePrimitive<dnnl::binary>(prim_desc);
   AddArgument(DNNL_ARG_SRC_0, src0_mem_desc);
   AddArgument(DNNL_ARG_SRC_1, src1_mem_desc);
   AddArgument(DNNL_ARG_DST, dst_mem_desc);
@@ -45,9 +45,9 @@ void AssignAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   }
   dnnl::memory::desc src0_desc = GetDefaultMemDesc(src0_shape);
   dnnl::memory::desc src1_desc = GetDefaultMemDesc(src1_shape);
-  dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_desc, src1_desc, src0_desc);
-  auto prim_desc = dnnl::binary::primitive_desc(desc, engine_);
-  primitive_ = std::make_shared<dnnl::binary>(prim_desc);
+  auto desc = CreateDesc<dnnl::binary::desc>(dnnl::algorithm::binary_add, src0_desc, src1_desc, src0_desc);
+  auto prim_desc = CreateDesc<dnnl::binary::primitive_desc>(desc, engine_);
+  primitive_ = CreatePrimitive<dnnl::binary>(prim_desc);
   AddArgument(DNNL_ARG_SRC_0, src0_desc);
   AddArgument(DNNL_ARG_SRC_1, src1_desc);
   AddArgument(DNNL_ARG_DST, src0_desc);
@@ -59,15 +59,17 @@ void BatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     prop_kind = dnnl::prop_kind::forward_training;
     normalization_flags = dnnl::normalization_flags::use_scale_shift;
   }
-  dnnl::batch_normalization_forward::desc desc =
-    dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, normalization_flags);
-  auto prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, engine_);
-  primitive_ = std::make_shared<dnnl::batch_normalization_forward>(prim_desc);
+  auto desc = CreateDesc<dnnl::batch_normalization_forward::desc>(prop_kind, x_desc, epsilon, normalization_flags);
+  auto prim_desc = CreateDesc<dnnl::batch_normalization_forward::primitive_desc>(desc, engine_);
+  auto wksp_desc = GetWorkspaceDesc(prim_desc);
+  auto mean = GetMeanDesc(prim_desc);
+  auto variance = GetVarianceDesc(prim_desc);
+  primitive_ = CreatePrimitive<dnnl::batch_normalization_forward>(prim_desc);
   AddArgument(DNNL_ARG_SRC, x_desc);
-  AddArgument(DNNL_ARG_MEAN, prim_desc.mean_desc());
-  AddArgument(DNNL_ARG_VARIANCE, prim_desc.variance_desc());
+  AddArgument(DNNL_ARG_MEAN, mean);
+  AddArgument(DNNL_ARG_VARIANCE, variance);
   AddArgument(DNNL_ARG_SCALE_SHIFT, scale_bias_desc);
-  AddArgument(DNNL_ARG_WORKSPACE, prim_desc.workspace_desc());
+  AddArgument(DNNL_ARG_WORKSPACE, wksp_desc);
   AddArgument(DNNL_ARG_DST, x_desc);
 }

@@ -64,21 +64,23 @@ void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   }

   // fused Batch Normalization forward description
-  dnnl::batch_normalization_forward::desc desc =
-    dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, normalization_flags);
-  auto forward_prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, engine_);
+  auto desc = CreateDesc<dnnl::batch_normalization_forward::desc>(prop_kind, x_desc, epsilon, normalization_flags);
+  auto forward_prim_desc = CreateDesc<dnnl::batch_normalization_forward::primitive_desc>(desc, engine_);

   // fused Batch Normalization backward description
-  dnnl::batch_normalization_backward::desc backward_desc =
-    dnnl::batch_normalization_backward::desc(dnnl::prop_kind::backward, x_desc, x_desc, epsilon, normalization_flags);
+  auto backward_desc = CreateDesc<dnnl::batch_normalization_backward::desc>(dnnl::prop_kind::backward, x_desc, x_desc,
+                                                                            epsilon, normalization_flags);
   auto backward_prim_desc =
-    dnnl::batch_normalization_backward::primitive_desc(backward_desc, engine_, forward_prim_desc);
-  primitive_ = std::make_shared<dnnl::batch_normalization_backward>(backward_prim_desc);
+    CreateDesc<dnnl::batch_normalization_backward::primitive_desc>(backward_desc, engine_, forward_prim_desc);
+  auto wksp_desc = GetWorkspaceDesc(forward_prim_desc);
+  auto mean = GetMeanDesc(forward_prim_desc);
+  auto variance = GetVarianceDesc(forward_prim_desc);
+  primitive_ = CreatePrimitive<dnnl::batch_normalization_backward>(backward_prim_desc);
   AddArgument(DNNL_ARG_SRC, x_desc);
-  AddArgument(DNNL_ARG_MEAN, forward_prim_desc.mean_desc());
-  AddArgument(DNNL_ARG_VARIANCE, forward_prim_desc.variance_desc());
+  AddArgument(DNNL_ARG_MEAN, mean);
+  AddArgument(DNNL_ARG_VARIANCE, variance);
   AddArgument(DNNL_ARG_SCALE_SHIFT, scale_bias_desc);
-  AddArgument(DNNL_ARG_WORKSPACE, forward_prim_desc.workspace_desc());
+  AddArgument(DNNL_ARG_WORKSPACE, wksp_desc);
   AddArgument(DNNL_ARG_DST, x_desc);
   AddArgument(DNNL_ARG_DIFF_DST, x_desc);
   AddArgument(DNNL_ARG_DIFF_SRC, x_desc);
@@ -81,18 +81,16 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   }
   dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]};
   dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]};
-  dnnl::convolution_forward::desc forward_desc =
-    dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc,
-                                    weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
-
-  auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, engine_);
-
-  dnnl::convolution_backward_weights::desc backward_desc = dnnl::convolution_backward_weights::desc(
+  auto forward_desc = CreateDesc<dnnl::convolution_forward::desc>(
+    dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides,
+    dilates, padding_l, padding_r);
+  auto forward_prim_desc = CreateDesc<dnnl::convolution_forward::primitive_desc>(forward_desc, engine_);
+  auto backward_desc = CreateDesc<dnnl::convolution_backward_weights::desc>(
     dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r);

   auto backward_prim_desc =
-    dnnl::convolution_backward_weights::primitive_desc(backward_desc, engine_, forward_prim_desc);
-  primitive_ = std::make_shared<dnnl::convolution_backward_weights>(backward_prim_desc);
+    CreateDesc<dnnl::convolution_backward_weights::primitive_desc>(backward_desc, engine_, forward_prim_desc);
+  primitive_ = CreatePrimitive<dnnl::convolution_backward_weights>(backward_prim_desc);

   AddArgument(DNNL_ARG_SRC, src_desc);
   AddArgument(DNNL_ARG_DIFF_DST, dst_desc);
@@ -94,17 +94,15 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   }
   dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]};
   dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]};
-  dnnl::convolution_forward::desc forward_desc =
-    dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc,
-                                    weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
-
-  auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, engine_);
-
-  dnnl::convolution_backward_data::desc backward_desc = dnnl::convolution_backward_data::desc(
+  auto forward_desc = CreateDesc<dnnl::convolution_forward::desc>(
+    dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides,
+    dilates, padding_l, padding_r);
+  auto forward_prim_desc = CreateDesc<dnnl::convolution_forward::primitive_desc>(forward_desc, engine_);
+  auto backward_desc = CreateDesc<dnnl::convolution_backward_data::desc>(
     dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
-
-  auto backward_prim_desc = dnnl::convolution_backward_data::primitive_desc(backward_desc, engine_, forward_prim_desc);
-  primitive_ = std::make_shared<dnnl::convolution_backward_data>(backward_prim_desc);
+  auto backward_prim_desc =
+    CreateDesc<dnnl::convolution_backward_data::primitive_desc>(backward_desc, engine_, forward_prim_desc);
+  primitive_ = CreatePrimitive<dnnl::convolution_backward_data>(backward_prim_desc);

   AddArgument(DNNL_ARG_DIFF_SRC, src_desc);
   AddArgument(DNNL_ARG_DIFF_DST, dst_desc);
@@ -102,12 +102,12 @@ void ConvCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     (void)padding_l.emplace_back(int_padding_l[i]);
     (void)padding_r.emplace_back(int_padding_r[i]);
   }
-  dnnl::convolution_forward::desc desc =
-    dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc,
-                                    weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
+  auto desc = CreateDesc<dnnl::convolution_forward::desc>(dnnl::prop_kind::forward_training,
+                                                          dnnl::algorithm::convolution_auto, src_desc, weights_desc,
+                                                          dst_desc, strides, dilates, padding_l, padding_r);

-  auto prim_desc = dnnl::convolution_forward::primitive_desc(desc, engine_);
-  primitive_ = std::make_shared<dnnl::convolution_forward>(prim_desc);
+  auto prim_desc = CreateDesc<dnnl::convolution_forward::primitive_desc>(desc, engine_);
+  primitive_ = CreatePrimitive<dnnl::convolution_forward>(prim_desc);
   AddArgument(DNNL_ARG_SRC, src_desc);
   AddArgument(DNNL_ARG_WEIGHTS, weights_desc);
   AddArgument(DNNL_ARG_DST, dst_desc);
@@ -50,8 +50,9 @@ dnnl::eltwise_forward::desc EltWiseCPUKernel::GetForwardEltwiseDesc(const dnnl::
   if (desc_pair == eltWiseOpDescMap.end()) {
     MS_LOG(EXCEPTION) << "EltWiseCPUKernel does not support " << kernel_name_;
   }
-  return dnnl::eltwise_forward::desc(dnnl_forward_, desc_pair->second.algorithm, src_desc, desc_pair->second.alpha,
-                                     desc_pair->second.beta);
+  auto desc = CreateDesc<dnnl::eltwise_forward::desc>(dnnl_forward_, desc_pair->second.algorithm, src_desc,
+                                                      desc_pair->second.alpha, desc_pair->second.beta);
+  return desc;
 }

 void EltWiseCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -64,8 +65,8 @@ void EltWiseCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);

   auto desc = GetForwardEltwiseDesc(src_desc);
-  auto prim_desc = dnnl::eltwise_forward::primitive_desc(desc, engine_);
-  primitive_ = std::make_shared<dnnl::eltwise_forward>(prim_desc);
+  auto prim_desc = CreateDesc<dnnl::eltwise_forward::primitive_desc>(desc, engine_);
+  primitive_ = CreatePrimitive<dnnl::eltwise_forward>(prim_desc);
   AddArgument(DNNL_ARG_SRC, src_desc);
   AddArgument(DNNL_ARG_DST, src_desc);
 }
@@ -38,10 +38,9 @@ void LogSoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     axis += SizeToInt(src_shape.size());
   }
   dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
-  dnnl::logsoftmax_forward::desc desc =
-    dnnl::logsoftmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis);
-  auto prim_desc = dnnl::logsoftmax_forward::primitive_desc(desc, engine_);
-  primitive_ = std::make_shared<dnnl::logsoftmax_forward>(prim_desc);
+  auto desc = CreateDesc<dnnl::logsoftmax_forward::desc>(dnnl::prop_kind::forward_training, src_desc, axis);
+  auto prim_desc = CreateDesc<dnnl::logsoftmax_forward::primitive_desc>(desc, engine_);
+  primitive_ = CreatePrimitive<dnnl::logsoftmax_forward>(prim_desc);
   AddArgument(DNNL_ARG_SRC, src_desc);
   AddArgument(DNNL_ARG_DST, src_desc);
 }
@@ -38,13 +38,12 @@ void LogSoftmaxGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     axis += SizeToInt(src_shape.size());
   }
   dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
-  dnnl::logsoftmax_forward::desc desc =
-    dnnl::logsoftmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis);
-  auto prim_desc = dnnl::logsoftmax_forward::primitive_desc(desc, engine_);
+  auto desc = CreateDesc<dnnl::logsoftmax_forward::desc>(dnnl::prop_kind::forward_training, src_desc, axis);
+  auto prim_desc = CreateDesc<dnnl::logsoftmax_forward::primitive_desc>(desc, engine_);
   // backward description
-  dnnl::logsoftmax_backward::desc backward_desc = dnnl::logsoftmax_backward::desc(src_desc, src_desc, axis);
-  auto backward_prim_desc = dnnl::logsoftmax_backward::primitive_desc(backward_desc, engine_, prim_desc);
-  primitive_ = std::make_shared<dnnl::logsoftmax_backward>(backward_prim_desc);
+  auto backward_desc = CreateDesc<dnnl::logsoftmax_backward::desc>(src_desc, src_desc, axis);
+  auto backward_prim_desc = CreateDesc<dnnl::logsoftmax_backward::primitive_desc>(backward_desc, engine_, prim_desc);
+  primitive_ = CreatePrimitive<dnnl::logsoftmax_backward>(backward_prim_desc);
   AddArgument(DNNL_ARG_DST, src_desc);
   AddArgument(DNNL_ARG_DIFF_SRC, src_desc);
   AddArgument(DNNL_ARG_DIFF_DST, src_desc);
@@ -25,6 +25,8 @@ constexpr size_t kLstmInputsNum = 4;
 constexpr size_t kLstmOutputsNum = 5;
 constexpr int kMaxLSTMLayer = 100;
 constexpr int kOutputWorkSpaceIndex = 3;
+constexpr int kInputCIndex = 2;
+constexpr int kInputWeightIndex = 3;
 constexpr int kGateNum = 4;

 using tag = dnnl::memory::format_tag;
@@ -82,29 +84,46 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   if (!is_training) {
     prop_kind = dnnl::prop_kind::forward_inference;
   }
-  auto desc = std::make_shared<dnnl::lstm_forward::desc>(
-    prop_kind, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
-    formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc);
-  prim_desc_ = dnnl::lstm_forward::primitive_desc(*desc, eng);
-  primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_);
+  auto weights_desc = formatted_md(weights_dims_, tag::any);
+  auto weights_h_desc = formatted_md(weights_h_dims_, tag::any);
+  auto desc =
+    CreatePrimitive<dnnl::lstm_forward::desc>(prop_kind, direction, src_desc, src_h_desc, src_c_desc, weights_desc,
+                                              weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
+  prim_desc_ = CreateDesc<dnnl::lstm_forward::primitive_desc>(*desc, eng);
+  primitive_ = CreatePrimitive<dnnl::lstm_forward>(prim_desc_);
   if (is_training) {
-    reserve_size_ = static_cast<size_t>(prim_desc_.workspace_desc().get_size());
-    AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc());
+    auto wksp_desc = GetWorkspaceDesc(prim_desc_);
+    reserve_size_ = GetSize(wksp_desc);
+    AddArgument(DNNL_ARG_WORKSPACE, wksp_desc);
   } else {
     reserve_size_ = 1;
   }
+  auto weights_layer = GetWeightsLayerDesc(prim_desc_);
+  auto weights_iter = GetWeightsIterDesc(prim_desc_);
+  bias_desc_ = GetBiasDesc(prim_desc_);
   AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
   AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
   AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc);
-  AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc());
-  AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc());
+  AddArgument(DNNL_ARG_WEIGHTS_LAYER, weights_layer);
+  AddArgument(DNNL_ARG_WEIGHTS_ITER, weights_iter);
   AddArgument(DNNL_ARG_BIAS, bias_desc);
   AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
   AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
   AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc);
+
+  using dt = dnnl::memory::data_type;
+  using tag = dnnl::memory::format_tag;
+  auto weights_dims_desc = CreateDesc<dnnl::memory::desc>(weights_dims_, dt::f32, tag::ldgoi);
+  auto weights_h_dims_desc = CreateDesc<dnnl::memory::desc>(weights_h_dims_, dt::f32, tag::ldgoi);
+  user_weights_memory_ = CreateDesc<dnnl::memory>(weights_dims_desc, eng);
+  user_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_dims_desc, eng);
+  weights_memory_ = CreateDesc<dnnl::memory>(weights_layer, eng);
+  weights_h_memory_ = CreateDesc<dnnl::memory>(weights_iter, eng);
+  bias_memory_ = CreateDesc<dnnl::memory>(bias_desc_, eng);
 }

 void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {
+  constexpr int kBidirectional = 2;
   std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
   std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
   std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
@@ -117,7 +136,7 @@ void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {
   seq_len_ = SizeToInt(src_shape[0]);
   num_directions_ = 1;
   if (bidirectional_) {
-    num_directions_ = 2;
+    num_directions_ = kBidirectional;
   }
   const int gate_size = kGateNum * hidden_size_;
   if (num_layers_ <= 0) {
@@ -142,33 +161,26 @@ void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {

 bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                            const std::vector<kernel::AddressPtr> &outputs) {
-  using dt = dnnl::memory::data_type;
-  using tag = dnnl::memory::format_tag;
-  auto eng = engine_;
-  auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
-  auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
-  auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng);
-  auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng);
-  user_weights_memory.set_data_handle(inputs[3]->addr);
-  user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
-  Reorder(&user_weights_memory, &weights_memory);
-  Reorder(&user_weights_h_memory, &weights_h_memory);
-  auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng);
+  SetDataHandle(user_weights_memory_, inputs[kInputWeightIndex]->addr);
+  SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_);
+  Reorder(&user_weights_memory_, &weights_memory_);
+  Reorder(&user_weights_h_memory_, &weights_h_memory_);
   if (has_bias_) {
-    bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
+    SetDataHandle(bias_memory_,
+                  reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_ + weight_h_size_);
   } else {
-    if (memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0,
-                 prim_desc_.bias_desc().get_size())) {
+    auto size = GetSize(bias_desc_);
+    if (memset_s(GetDataHandle(bias_memory_), size, 0, size)) {
       MS_LOG(EXCEPTION) << "Bias memset error";
     }
   }
   // set handle
   SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
   SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
-  SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
-  SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
-  SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
-  SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
+  SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[kInputCIndex]->addr);
+  SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, GetDataHandle(weights_memory_));
+  SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, GetDataHandle(weights_h_memory_));
+  SetArgumentHandle(DNNL_ARG_BIAS, GetDataHandle(bias_memory_));
   SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr);
   SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr);
   SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr);
@@ -55,6 +55,12 @@ class LstmCPUKernel : public MKLCPUKernel {
   dnnl::memory::dims weights_h_dims_;
   dnnl::memory::dims bias_dims_;
   dnnl::lstm_forward::primitive_desc prim_desc_;
+  dnnl::memory::desc bias_desc_;
+  dnnl::memory user_weights_memory_;
+  dnnl::memory user_weights_h_memory_;
+  dnnl::memory weights_memory_;
+  dnnl::memory weights_h_memory_;
+  dnnl::memory bias_memory_;
 };

 MS_REG_CPU_KERNEL(LSTM,
@@ -27,6 +27,19 @@ constexpr size_t kLstmGradInputsNum = 11;
 constexpr size_t kLstmGradOutputsNum = 4;
 constexpr int kMaxLSTMLayer = 100;
 constexpr int kInputWorkSpaceIndex = 10;
+constexpr int kInputWeightIndex = 3;
+constexpr int kOutputWeightIndex = 3;
+
+constexpr int kSrcLayerIdx = 0;
+constexpr int kSrcIterIdx = 1;
+constexpr int kSrcIterCIdx = 2;
+constexpr int kDstLayerIdx = 4;
+constexpr int kDstIterIdx = 5;
+constexpr int kDstIterCIdx = 6;
+constexpr int kDiffDstLayerIdx = 7;
+constexpr int kDiffDstIterIdx = 8;
+constexpr int kDiffDstIterCIdx = 9;
+constexpr int kWorkspaceIdx = 10;

 using tag = dnnl::memory::format_tag;
 using dim = dnnl::memory::dims;
@@ -63,21 +76,45 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
   dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
   dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
-  auto forward_desc = std::make_shared<dnnl::lstm_forward::desc>(
-    dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
-    formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc,
-    dst_c_desc);
-  auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(*forward_desc, eng);
-  auto backward_desc = std::make_shared<dnnl::lstm_backward::desc>(
-    dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
-    formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc,
-    src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc,
-    dst_h_desc, dst_c_desc);
-  prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc);
-  primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_);
-  reserve_size_ = static_cast<size_t>(prim_forward_desc.workspace_desc().get_size());
-  AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc());
+  auto weights_desc = formatted_md(weights_dims_, tag::any);
+  auto weights_h_desc = formatted_md(weights_h_dims_, tag::any);
+
+  auto forward_desc = CreatePrimitive<dnnl::lstm_forward::desc>(dnnl::prop_kind::forward_training, direction, src_desc,
+                                                                src_h_desc, src_c_desc, weights_desc, weights_h_desc,
+                                                                bias_desc, dst_desc, dst_h_desc, dst_c_desc);
+  auto prim_forward_desc = CreateDesc<dnnl::lstm_forward::primitive_desc>(*forward_desc, eng);
+  auto backward_desc = CreatePrimitive<dnnl::lstm_backward::desc>(
+    dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, weights_desc, weights_h_desc, bias_desc,
+    dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, weights_desc, weights_h_desc, bias_desc,
+    dst_desc, dst_h_desc, dst_c_desc);
+  prim_backward_desc_ = CreateDesc<dnnl::lstm_backward::primitive_desc>(*backward_desc, eng, prim_forward_desc);
+  primitive_ = CreatePrimitive<dnnl::lstm_backward>(prim_backward_desc_);
+  auto wksp_desc = GetWorkspaceDesc(prim_forward_desc);
+  reserve_size_ = GetSize(wksp_desc);
+  AddArgument(DNNL_ARG_WORKSPACE, wksp_desc);
   AddArgumentOp(src_desc, src_h_desc, src_c_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
+
+  // construct fw memory
+  weights_layer_desc_ = GetWeightsLayerDesc(prim_backward_desc_);
+  weights_iter_desc_ = GetWeightsIterDesc(prim_backward_desc_);
+  bias_desc_ = GetBiasDesc(prim_backward_desc_);
+  auto weights_mem_desc = CreateDesc<dnnl::memory::desc>(weights_dims_, dt::f32, tag::ldgoi);
+  auto weights_h_mem_desc = CreateDesc<dnnl::memory::desc>(weights_h_dims_, dt::f32, tag::ldgoi);
+  user_weights_memory_ = CreateDesc<dnnl::memory>(weights_mem_desc, eng);
+  user_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_mem_desc, eng);
+  weights_memory_ = CreateDesc<dnnl::memory>(weights_layer_desc_, eng);
+  weights_h_memory_ = CreateDesc<dnnl::memory>(weights_iter_desc_, eng);
+  bias_memory_ = CreateDesc<dnnl::memory>(bias_desc_, eng);
+
+  // construct bw memory
+  diff_weights_layer_desc_ = GetDiffWeightsLayerDesc(prim_backward_desc_);
+  diff_weights_iter_desc_ = GetDiffWeightsIterDesc(prim_backward_desc_);
+  diff_bias_desc_ = GetDiffBiasDesc(prim_backward_desc_);
+  diff_weights_memory_ = CreateDesc<dnnl::memory>(diff_weights_layer_desc_, eng);
+  diff_weights_h_memory_ = CreateDesc<dnnl::memory>(diff_weights_iter_desc_, eng);
+  diff_bias_memory_ = CreateDesc<dnnl::memory>(diff_bias_desc_, eng);
+  user_diff_weights_memory_ = CreateDesc<dnnl::memory>(weights_mem_desc, eng);
+  user_diff_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_mem_desc, eng);
 }

 void LSTMGradCPUKernel::AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc,
@@ -87,8 +124,8 @@ void LSTMGradCPUKernel::AddArgumentOp(const dnnl::memory::desc &src_desc, const
   AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
   AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
   AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc);
-  AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_backward_desc_.weights_layer_desc());
-  AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_backward_desc_.weights_iter_desc());
+  AddArgument(DNNL_ARG_WEIGHTS_LAYER, weights_layer_desc_);
+  AddArgument(DNNL_ARG_WEIGHTS_ITER, weights_iter_desc_);
   AddArgument(DNNL_ARG_BIAS, bias_desc);
   AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
   AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
@@ -96,8 +133,8 @@ void LSTMGradCPUKernel::AddArgumentOp(const dnnl::memory::desc &src_desc, const
   AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc);
   AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc);
   AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc);
-  AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, prim_backward_desc_.diff_weights_layer_desc());
-  AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, prim_backward_desc_.diff_weights_iter_desc());
+  AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_layer_desc_);
+  AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_iter_desc_);
   AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc);
   AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc);
   AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc);
@@ -141,34 +178,33 @@ void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
 }

 void LSTMGradCPUKernel::SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
-                                            const std::vector<kernel::AddressPtr> &outputs,
-                                            const dnnl::memory &weights_memory, const dnnl::memory &weights_h_memory,
-                                            const dnnl::memory &bias_memory, const dnnl::memory &diff_weights_memory,
-                                            const dnnl::memory &diff_weights_h_memory,
-                                            const dnnl::memory &diff_bias_memory) {
-  SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
-  SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
-  SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
-  SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
-  SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
-  SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
-  SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr);
-  SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr);
-  SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr);
-  SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr);
-  SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr);
-  SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr);
-  SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr);
-  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle());
-  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle());
-  SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle());
-  SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr);
-  SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr);
-  SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr);
+                                            const std::vector<kernel::AddressPtr> &outputs) {
+  SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[kSrcLayerIdx]->addr);
+  SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[kSrcIterIdx]->addr);
+  SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[kSrcIterCIdx]->addr);
+  SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, GetDataHandle(weights_memory_));
+  SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, GetDataHandle(weights_h_memory_));
+  SetArgumentHandle(DNNL_ARG_BIAS, GetDataHandle(bias_memory_));
+  SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[kDstLayerIdx]->addr);
+  SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[kDstIterIdx]->addr);
+  SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[kDstIterCIdx]->addr);
+  SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[kWorkspaceIdx]->addr);
+  SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[kSrcLayerIdx]->addr);
+  SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[kSrcIterIdx]->addr);
+  SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[kSrcIterCIdx]->addr);
+  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, GetDataHandle(diff_weights_memory_));
+  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, GetDataHandle(diff_weights_h_memory_));
+  SetArgumentHandle(DNNL_ARG_DIFF_BIAS, GetDataHandle(diff_bias_memory_));
+  SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[kDiffDstLayerIdx]->addr);
+  SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[kDiffDstIterIdx]->addr);
+  SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[kDiffDstIterCIdx]->addr);
 }

 void LSTMGradCPUKernel::ResetMemory(const dnnl::memory &mem, const string name) const {
-  if (memset_s(mem.get_data_handle(), mem.get_desc().get_size(), 0, mem.get_desc().get_size())) {
+  auto dst_ptr = GetDataHandle(mem);
+  auto mem_desc = GetMemDesc(mem);
+  auto size = GetSize(mem_desc);
+  if (memset_s(dst_ptr, size, 0, size)) {
     MS_LOG(EXCEPTION) << name << " memset error";
   }
 }
@@ -177,49 +213,41 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, co
                                const std::vector<kernel::AddressPtr> &outputs) {
   CHECK_KERNEL_INPUTS_NUM(inputs.size(), kLstmGradInputsNum, kernel_name_);
   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kLstmGradOutputsNum, kernel_name_);
-  auto eng = engine_;
-  // construct fw memory
-  auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
-  auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
-  auto weights_memory = dnnl::memory(prim_backward_desc_.weights_layer_desc(), eng);
-  auto weights_h_memory = dnnl::memory(prim_backward_desc_.weights_iter_desc(), eng);
-  auto bias_memory = dnnl::memory(prim_backward_desc_.bias_desc(), eng);
-  user_weights_memory.set_data_handle(inputs[3]->addr);
-  user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
-  Reorder(&user_weights_memory, &weights_memory);
-  Reorder(&user_weights_h_memory, &weights_h_memory);
+  SetDataHandle(user_weights_memory_, inputs[kInputWeightIndex]->addr);
+  SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_);
+  Reorder(&user_weights_memory_, &weights_memory_);
+  Reorder(&user_weights_h_memory_, &weights_h_memory_);
   if (has_bias_) {
-    bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
+    SetDataHandle(bias_memory_,
+                  reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_ + weight_h_size_);
   } else {
-    if (memset_s(bias_memory.get_data_handle(), prim_backward_desc_.bias_desc().get_size(), 0,
-                 prim_backward_desc_.bias_desc().get_size())) {
+    auto dst_ptr = GetDataHandle(bias_memory_);
+    auto size = GetSize(bias_desc_);
+    if (memset_s(dst_ptr, size, 0, size)) {
      MS_LOG(EXCEPTION) << "Bias memset error";
    }
  }
-  // construct bw memory
-  auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng);
-  auto diff_weights_h_memory = dnnl::memory(prim_backward_desc_.diff_weights_iter_desc(), eng);
-  auto diff_bias_memory = dnnl::memory(prim_backward_desc_.diff_bias_desc(), eng);
-  auto user_diff_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
-  auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
-  user_diff_weights_memory.set_data_handle(outputs[3]->addr);
-  user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
-  ResetMemory(user_diff_weights_memory, "user weights grad");
-  ResetMemory(user_diff_weights_h_memory, "user weights iter grad");
-  ResetMemory(diff_weights_memory, "weights grad");
-  ResetMemory(diff_weights_h_memory, "weights iter grad");
+
+  SetDataHandle(user_diff_weights_memory_, outputs[kOutputWeightIndex]->addr);
+  SetDataHandle(user_diff_weights_h_memory_,
+                reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + weight_size_);
+  ResetMemory(user_diff_weights_memory_, "user weights grad");
+  ResetMemory(user_diff_weights_h_memory_, "user weights iter grad");
+  ResetMemory(diff_weights_memory_, "weights grad");
+  ResetMemory(diff_weights_h_memory_, "weights iter grad");
   if (has_bias_) {
-    diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
+    SetDataHandle(diff_bias_memory_,
+                  reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + weight_size_ + weight_h_size_);
   }
-  if (memset_s(diff_bias_memory.get_data_handle(), prim_backward_desc_.diff_bias_desc().get_size(), 0,
-               prim_backward_desc_.diff_bias_desc().get_size())) {
+  auto dst_ptr = GetDataHandle(diff_bias_memory_);
+  auto size = GetSize(diff_bias_desc_);
+  if (memset_s(dst_ptr, size, 0, size)) {
     MS_LOG(EXCEPTION) << "Bias grad memset error";
   }
-  SetArgumentHandleOp(inputs, outputs, weights_memory, weights_h_memory, bias_memory, diff_weights_memory,
-                      diff_weights_h_memory, diff_bias_memory);
+  SetArgumentHandleOp(inputs, outputs);
   ExecutePrimitive();
-  Reorder(&diff_weights_memory, &user_diff_weights_memory);
-  Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory);
+  Reorder(&diff_weights_memory_, &user_diff_weights_memory_);
+  Reorder(&diff_weights_h_memory_, &user_diff_weights_h_memory_);
   return true;
 }
 } // namespace kernel
@@ -42,10 +42,7 @@ class LSTMGradCPUKernel : public MKLCPUKernel {
                      const dnnl::memory::desc &dst_desc, const dnnl::memory::desc &dst_h_desc,
                      const dnnl::memory::desc &dst_c_desc);
   void SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
-                           const std::vector<kernel::AddressPtr> &outputs, const dnnl::memory &weights_memory,
-                           const dnnl::memory &weights_h_memory, const dnnl::memory &bias_memory,
-                           const dnnl::memory &diff_weights_memory, const dnnl::memory &diff_weights_h_memory,
-                           const dnnl::memory &diff_bias_memory);
+                           const std::vector<kernel::AddressPtr> &outputs);
   void ResetMemory(const dnnl::memory &mem, const string name) const;
   void CheckParam(const CNodePtr &kernel_node);

@@ -65,6 +62,23 @@ class LSTMGradCPUKernel : public MKLCPUKernel {
   dnnl::memory::dims weights_h_dims_;
   dnnl::memory::dims bias_dims_;
   dnnl::lstm_backward::primitive_desc prim_backward_desc_;
+
+  dnnl::memory::desc weights_layer_desc_;
+  dnnl::memory::desc weights_iter_desc_;
+  dnnl::memory::desc bias_desc_;
+  dnnl::memory::desc diff_weights_layer_desc_;
+  dnnl::memory::desc diff_weights_iter_desc_;
+  dnnl::memory::desc diff_bias_desc_;
+  dnnl::memory user_weights_memory_;
+  dnnl::memory user_weights_h_memory_;
+  dnnl::memory weights_memory_;
+  dnnl::memory weights_h_memory_;
+  dnnl::memory bias_memory_;
+  dnnl::memory diff_weights_memory_;
+  dnnl::memory diff_weights_h_memory_;
+  dnnl::memory diff_bias_memory_;
+  dnnl::memory user_diff_weights_memory_;
+  dnnl::memory user_diff_weights_h_memory_;
 };

 MS_REG_CPU_KERNEL(LSTMGrad,
@@ -76,12 +76,12 @@ void MatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     o_strides = {dim_n, 1};
   }

-  dnnl::memory::desc src_md(src_dims, dnnl::memory::data_type::f32, a_strides);
-  dnnl::memory::desc weights_md(weights_dims, dnnl::memory::data_type::f32, b_strides);
-  dnnl::memory::desc dst_md(dst_dims, dnnl::memory::data_type::f32, o_strides);
-  dnnl::matmul::desc matmul_desc(src_md, weights_md, dst_md);
-  dnnl::matmul::primitive_desc prim_desc(matmul_desc, engine_);
-  primitive_ = std::make_shared<dnnl::matmul>(prim_desc);
+  auto src_md = CreateDesc<dnnl::memory::desc>(src_dims, dnnl::memory::data_type::f32, a_strides);
+  auto weights_md = CreateDesc<dnnl::memory::desc>(weights_dims, dnnl::memory::data_type::f32, b_strides);
+  auto dst_md = CreateDesc<dnnl::memory::desc>(dst_dims, dnnl::memory::data_type::f32, o_strides);
+  auto matmul_desc = CreateDesc<dnnl::matmul::desc>(src_md, weights_md, dst_md);
+  auto prim_desc = CreateDesc<dnnl::matmul::primitive_desc>(matmul_desc, engine_);
+  primitive_ = CreatePrimitive<dnnl::matmul>(prim_desc);

   AddArgument(DNNL_ARG_SRC, src_md);
   AddArgument(DNNL_ARG_WEIGHTS, weights_md);
@@ -134,7 +134,7 @@ dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector<size_t> &sh
     (void)dims.insert(dims.end(), shape.begin(), shape.end());
   }
   dnnl::memory::format_tag mem_tag = GetDefaultFormatTag(dims);
-  dnnl::memory::desc mem_desc(dims, dnnl::memory::data_type::f32, mem_tag);
+  auto mem_desc = CreateDesc<dnnl::memory::desc>(dims, dnnl::memory::data_type::f32, mem_tag);
   return mem_desc;
 }

@@ -149,7 +149,9 @@ void MKLCPUKernel::AddArgument(int arg_key, const dnnl::memory::desc &mem_desc,
 void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) {
   auto arg_iter = arguments_.find(arg_key);
   if (arg_iter != arguments_.end()) {
+    MS_LOG(DEBUG) << "begin to invoke dnnl::memory::set_data_handle";
     arg_iter->second.set_data_handle(ptr);
+    MS_LOG(DEBUG) << "end to invoke dnnl::memory::set_data_handle";
   }
 }

@@ -169,7 +171,9 @@ void MKLCPUKernel::ExecutePrimitive() {
     }
     double start_time = GetTime();
     mkl_pool->set_num_threads(current_thread_nums);
+    MS_LOG(DEBUG) << "begin to invoke primitive::execute";
     primitive_->execute(stream_, arguments_);
+    MS_LOG(DEBUG) << "end to invoke primitive::execute";
     double cost_time = GetTime() - start_time;
     parallel_search_info_.tmp_sum_cost_time += cost_time;
     parallel_search_info_.search_count++;
@@ -184,16 +188,45 @@ void MKLCPUKernel::ExecutePrimitive() {
   } else {
     int best_thread_nums = static_cast<int>(std::pow(2.0f, parallel_search_info_.best_pow));
     mkl_pool->set_num_threads(best_thread_nums);
+    MS_LOG(DEBUG) << "begin to invoke primitive::execute";
     primitive_->execute(stream_, arguments_);
+    MS_LOG(DEBUG) << "end to invoke primitive::execute";
   }
 #else
+  MS_LOG(DEBUG) << "begin to invoke primitive::execute";
   primitive_->execute(stream_, arguments_);
+  MS_LOG(DEBUG) << "end to invoke primitive::execute";
 #endif
   (void)stream_.wait();
 }

+void MKLCPUKernel::SetDataHandle(dnnl::memory mem, void *ptr) {
+  MS_LOG(DEBUG) << "begin to invoke dnnl::memory::set_data_handle";
+  mem.set_data_handle(ptr);
+  MS_LOG(DEBUG) << "end to invoke dnnl::memory::set_data_handle";
+}
+
+void *MKLCPUKernel::GetDataHandle(const dnnl::memory &mem) const {
+  MS_LOG(DEBUG) << "begin to invoke dnnl::memory::get_data_handle";
+  auto ptr = mem.get_data_handle();
+  MS_LOG(DEBUG) << "end to invoke dnnl::memory::get_data_handle";
+  return ptr;
+}
+
+size_t MKLCPUKernel::GetSize(const dnnl::memory::desc &desc) const {
+  MS_LOG(DEBUG) << "begin to invoke dnnl::memory::desc::get_size()";
+  auto size = desc.get_size();
+  MS_LOG(DEBUG) << "end to invoke dnnl::memory::desc::get_size()";
+  return size;
+}
+
 void MKLCPUKernel::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) {
-  dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem);
+  MS_LOG(DEBUG) << "begin to invoke constructor of dnnl::reorder";
+  auto desc = dnnl::reorder(*src_mem, *dst_mem);
+  MS_LOG(DEBUG) << "end to invoke constructor of dnnl::reorder";
+  MS_LOG(DEBUG) << "begin to invoke primitive::execute";
+  desc.execute(stream_, *src_mem, *dst_mem);
+  MS_LOG(DEBUG) << "end to invoke primitive::execute";
 }
 } // namespace kernel
 } // namespace mindspore
@@ -22,6 +22,7 @@
 #include <algorithm>
 #include <memory>
 #include <vector>
+#include <utility>
 #include "dnnl.hpp"
 #include "backend/kernel_compiler/cpu/cpu_kernel.h"
 #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
@@ -32,6 +33,102 @@

 namespace mindspore {
 namespace kernel {
+template <class T, class... Args>
+auto CreateDesc(Args &&... args) {
+  MS_LOG(DEBUG) << "begin to invoke constructor of " << demangle(typeid(T).name());
+  auto desc = T(std::forward<Args>(args)...);
+  MS_LOG(DEBUG) << "end to invoke constructor of " << demangle(typeid(T).name());
+  return desc;
+}
+
+template <class T, class... Args>
+auto CreatePrimitive(Args &&... args) {
+  MS_LOG(DEBUG) << "begin to invoke constructor of " << demangle(typeid(T).name());
+  auto prim = std::make_shared<T>(std::forward<Args>(args)...);
+  MS_LOG(DEBUG) << "end to invoke constructor of " << demangle(typeid(T).name());
+  return prim;
+}
+
+template <class T>
+auto GetWorkspaceDesc(const T &prim_desc) {
+  MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::workspace_desc()";
+  auto desc = prim_desc.workspace_desc();
+  MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::workspace_desc()";
+  return desc;
+}
+
+template <class T>
+auto GetMeanDesc(const T &prim_desc) {
+  MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::mean_desc()";
+  auto desc = prim_desc.mean_desc();
+  MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::mean_desc()";
+  return desc;
+}
+
+template <class T>
+auto GetVarianceDesc(const T &prim_desc) {
+  MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::variance_desc()";
+  auto desc = prim_desc.variance_desc();
+  MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::variance_desc()";
+  return desc;
+}
+
+template <class T>
+auto GetWeightsLayerDesc(const T &prim_desc) {
+  MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::weights_layer_desc()";
+  auto desc = prim_desc.weights_layer_desc();
+  MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::weights_layer_desc()";
+  return desc;
+}
+
+template <class T>
+auto GetWeightsIterDesc(const T &prim_desc) {
+  MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::weights_iter_desc()";
+  auto desc = prim_desc.weights_iter_desc();
+  MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::weights_iter_desc()";
+  return desc;
+}
+
+template <class T>
+auto GetBiasDesc(const T &prim_desc) {
+  MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::bias_desc()";
+  auto desc = prim_desc.bias_desc();
+  MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::bias_desc()";
+  return desc;
+}
+
+template <class T>
+auto GetDiffWeightsLayerDesc(const T &prim_desc) {
+  MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::diff_weights_layer_desc()";
+  auto desc = prim_desc.diff_weights_layer_desc();
+  MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::diff_weights_layer_desc()";
+  return desc;
+}
+
+template <class T>
+auto GetDiffWeightsIterDesc(const T &prim_desc) {
+  MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::diff_weights_iter_desc()";
+  auto desc = prim_desc.diff_weights_iter_desc();
+  MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::diff_weights_iter_desc()";
+  return desc;
+}
+
+template <class T>
+auto GetDiffBiasDesc(const T &prim_desc) {
+  MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::diff_bias_desc()";
+  auto desc = prim_desc.diff_bias_desc();
+  MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::diff_bias_desc()";
+  return desc;
+}
+
+template <class T>
+auto GetMemDesc(const T &prim_desc) {
+  MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::get_desc()";
+  auto desc = prim_desc.get_desc();
+  MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::get_desc()";
+  return desc;
+}
+
 #ifdef USE_MS_THREADPOOL_FOR_DNNL
 class mkl_threadpool : public dnnl::threadpool_interop::threadpool_iface {
  private:
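A note on the design of the helpers above: demangle(typeid(T).name()) turns the compiler's mangled type name into a readable one, so each log line names the exact dnnl type being constructed or queried. The strict begin/end pairing is the contract the DFX test at the bottom of this diff relies on: it counts "begin to invoke" and "end to invoke" lines per thread and asserts the counts match, so any oneDNN call that never returned shows up as the last unpaired "begin" line.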
@@ -62,7 +159,9 @@ class MKLCPUKernel : public CPUKernel {
   MKLCPUKernel() : engine_(dnnl::engine::kind::cpu, 0) {
     auto thread_pool = GetActorMgrInnerThreadPool();
     mkl_threadpool_ = std::make_shared<mkl_threadpool>(thread_pool);
+    MS_LOG(DEBUG) << "begin to invoke dnnl::threadpool_interop::make_stream";
     stream_ = dnnl::threadpool_interop::make_stream(engine_, mkl_threadpool_.get());
+    MS_LOG(DEBUG) << "end to invoke dnnl::threadpool_interop::make_stream";
   }
 #else
   MKLCPUKernel() : engine_(dnnl::engine::kind::cpu, 0), stream_(engine_) {}
@@ -81,10 +180,17 @@ class MKLCPUKernel : public CPUKernel {
   dnnl::memory::desc GetDefaultMemDesc(const std::vector<size_t> &shape) const;
   void ExecutePrimitive();
   inline dnnl::memory::desc formatted_md(const dnnl::memory::dims &dimensions, dnnl::memory::format_tag layout) {
-    return dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout};
+    MS_LOG(DEBUG) << "begin to invoke constructor of dnnl::memory::desc";
+    auto desc = dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout};
+    MS_LOG(DEBUG) << "end to invoke constructor of dnnl::memory::desc";
+    return desc;
   }
   void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem);
+
+  size_t GetSize(const dnnl::memory::desc &desc) const;
+  void SetDataHandle(dnnl::memory mem, void *ptr);
+  void *GetDataHandle(const dnnl::memory &mem) const;

   std::unordered_map<int, dnnl::memory> arguments_;
   std::shared_ptr<dnnl::primitive> primitive_{nullptr};
   dnnl::engine engine_;
@@ -65,17 +65,17 @@ void AvgPoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]};

   // pooling_avg forward description
-  dnnl::pooling_forward::desc desc =
-    dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc, dst_desc,
-                                strides_dims, kernels_dims, padding_l, padding_r);
-  auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, engine_);
+  auto desc =
+    CreateDesc<dnnl::pooling_forward::desc>(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc,
+                                            dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
+  auto prim_desc = CreateDesc<dnnl::pooling_forward::primitive_desc>(desc, engine_);

   // pooling_avg backward description
-  dnnl::pooling_backward::desc backward_desc = dnnl::pooling_backward::desc(
-    dnnl::algorithm::pooling_avg, src_desc, dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
-  auto backward_prim_desc = dnnl::pooling_backward::primitive_desc(backward_desc, engine_, prim_desc);
+  auto backward_desc = CreateDesc<dnnl::pooling_backward::desc>(dnnl::algorithm::pooling_avg, src_desc, dst_desc,
+                                                                strides_dims, kernels_dims, padding_l, padding_r);
+  auto backward_prim_desc = CreateDesc<dnnl::pooling_backward::primitive_desc>(backward_desc, engine_, prim_desc);

-  primitive_ = std::make_shared<dnnl::pooling_backward>(backward_prim_desc);
+  primitive_ = CreatePrimitive<dnnl::pooling_backward>(backward_prim_desc);
   AddArgument(DNNL_ARG_SRC, src_desc);
   AddArgument(DNNL_ARG_DST, dst_desc);
   AddArgument(DNNL_ARG_DIFF_SRC, src_desc);
@@ -76,19 +76,22 @@ void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     (void)padding_l.emplace_back(int_padding_l[i]);
     (void)padding_r.emplace_back(int_padding_r[i]);
   }
-  dnnl::pooling_forward::desc desc =
-    dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_max, src_desc, dst_desc,
-                                strides_dims, kernels_dims, padding_l, padding_r);
+  auto desc =
+    CreateDesc<dnnl::pooling_forward::desc>(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_max, src_desc,
+                                            dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
   if (kernel_name_ == prim::kPrimAvgPool->name()) {
-    desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc,
-                                       dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
+    desc =
+      CreateDesc<dnnl::pooling_forward::desc>(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg,
+                                              src_desc, dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
   }
-  auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, engine_);
-  workspace_size_ = prim_desc.workspace_desc().get_size();
-  primitive_ = std::make_shared<dnnl::pooling_forward>(prim_desc);
+
+  auto prim_desc = CreateDesc<dnnl::pooling_forward::primitive_desc>(desc, engine_);
+  auto wksp_desc = GetWorkspaceDesc(prim_desc);
+  workspace_size_ = GetSize(wksp_desc);
+  primitive_ = CreatePrimitive<dnnl::pooling_forward>(prim_desc);
   AddArgument(DNNL_ARG_SRC, src_desc);
   AddArgument(DNNL_ARG_DST, dst_desc);
-  AddArgument(DNNL_ARG_WORKSPACE, prim_desc.workspace_desc());
+  AddArgument(DNNL_ARG_WORKSPACE, wksp_desc);
 }

 bool PoolingCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
@@ -45,9 +45,9 @@ void SoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     axis += SizeToInt(src_shape.size());
   }
   dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
-  dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis);
-  auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, engine_);
-  primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc);
+  auto desc = CreateDesc<dnnl::softmax_forward::desc>(dnnl::prop_kind::forward_training, src_desc, axis);
+  auto prim_desc = CreateDesc<dnnl::softmax_forward::primitive_desc>(desc, engine_);
+  primitive_ = CreatePrimitive<dnnl::softmax_forward>(prim_desc);
   AddArgument(DNNL_ARG_SRC, src_desc);
   AddArgument(DNNL_ARG_DST, src_desc);
 }
@@ -52,11 +52,11 @@ void SoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_n
   if (batch_size_ == 0 || class_num_ == 0) {
     MS_LOG(EXCEPTION) << "Invalid batch size or class num input!";
   }
-  dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
+  auto mem_desc = CreateDesc<dnnl::memory::desc>(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);

-  dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1);
-  auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, engine_);
-  primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc);
+  auto desc = CreateDesc<dnnl::softmax_forward::desc>(dnnl::prop_kind::forward_training, mem_desc, 1);
+  auto prim_desc = CreateDesc<dnnl::softmax_forward::primitive_desc>(desc, engine_);
+  primitive_ = CreatePrimitive<dnnl::softmax_forward>(prim_desc);

   AddArgument(DNNL_ARG_SRC, mem_desc);
   AddArgument(DNNL_ARG_DST, mem_desc);
@@ -58,11 +58,11 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &ke
     MS_LOG(EXCEPTION) << "Invalid batch size or class num input!";
   }
   is_grad_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, IS_GRAD);
-  dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
+  auto mem_desc = CreateDesc<dnnl::memory::desc>(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);

-  dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1);
-  auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, engine_);
-  primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc);
+  auto desc = CreateDesc<dnnl::softmax_forward::desc>(dnnl::prop_kind::forward_training, mem_desc, 1);
+  auto prim_desc = CreateDesc<dnnl::softmax_forward::primitive_desc>(desc, engine_);
+  primitive_ = CreatePrimitive<dnnl::softmax_forward>(prim_desc);

   AddArgument(DNNL_ARG_SRC, mem_desc);
   AddArgument(DNNL_ARG_DST, mem_desc);
@@ -0,0 +1,79 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+import re
+from collections import defaultdict
+import pytest
+import mindspore.context as context
+
+
+context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_onednn_dfx_log():
+    """
+    Feature: onednn dfx
+    Description: run with 'python -O -m pytest -s xxx.py'; -O sets __debug__ to False, enabling the per-thread log dump
+    Expectation: begin_cnt == end_cnt
+    """
+    os.environ['GLOG_v'] = '0'
+    log_name = './onednn_dfx.log'
+    # need cd to current dir
+    cmd = "pytest -s test_maxpool_op.py::test_max_pool3d_1 --disable-warnings > {} 2>&1".format(log_name)
+    out = os.popen(cmd)
+    out.read()
+    keyword_begin = "begin to invoke"
+    keyword_end = "end to invoke"
+    begin_cnt = 0
+    end_cnt = 0
+    last_begin_line = ''
+    log_pattern = re.compile(r'\[[A-Z]+\] [A-Z]+\([\d]+\,[a-f\d]+,')
+    # {tid0: [log_line0, log_line1], tid1: [log_line2, log_line3, log_line4], ...}
+    multi_dict = defaultdict(list)
+
+    with open(log_name, "r") as f:
+        for line in f.readlines():
+            if not log_pattern.match(line):
+                continue
+            # log example: '[DEBUG] KERNEL(48810, 7f42b77fd700,python):2022-01-10-16:19:13.521.086 xxxx.'
+            tid = line.split(' ')[1].split(',')[1]
+            line = line.replace('\n', '').replace('\r', '')
+            multi_dict[tid].append(line)
+
+    assert multi_dict.keys()  # check empty
+    for tid in multi_dict.keys():
+        if not __debug__:
+            f = open('./onednn_dfx-{}.log'.format(tid), "w")
+        for line in multi_dict[tid]:
+            if not __debug__:
+                f.write(line + '\n')
+
+            if keyword_begin in line:
+                begin_cnt += 1
+                last_begin_line = line
+            elif keyword_end in line:
+                end_cnt += 1
+
+        if begin_cnt != end_cnt:
+            print("\n=======The line below has no pairing end line======")
+            print(last_begin_line)
+            print("===================================================")
+        assert begin_cnt == end_cnt
+
+        if not __debug__:
+            f.close()
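Usage note on the test above: it shells out to pytest on test_maxpool_op.py::test_max_pool3d_1 with GLOG_v=0 (most verbose logging), filters the captured log down to kernel DEBUG lines grouped by thread id, and asserts the begin/end counts match. Per its docstring, running the test file itself under `python -O -m pytest -s xxx.py` (substitute the actual file name; it is not given here) sets __debug__ to False, which additionally dumps each thread's filtered lines to ./onednn_dfx-<tid>.log for manual inspection when the pairing check fails.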