forked from mindspore-Ecosystem/mindspore
!2307 fix stack overflow and memset use risk
Merge pull request !2307 from baihuawei/cpulstm
This commit is contained in:
commit
a39f5452c5
|
@ -81,11 +81,11 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
|
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
|
||||||
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
|
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
|
||||||
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
|
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
|
||||||
dnnl::lstm_forward::desc desc =
|
auto desc = std::make_shared<dnnl::lstm_forward::desc>(dnnl::prop_kind::forward_training, direction, src_desc,
|
||||||
dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
|
src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
|
||||||
formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc,
|
formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc,
|
||||||
dst_desc, dst_h_desc, dst_c_desc);
|
dst_h_desc, dst_c_desc);
|
||||||
prim_desc_ = dnnl::lstm_forward::primitive_desc(desc, eng);
|
prim_desc_ = dnnl::lstm_forward::primitive_desc(*desc, eng);
|
||||||
primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_);
|
primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_);
|
||||||
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
|
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
|
||||||
AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
|
AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
|
||||||
|
@ -117,7 +117,11 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||||
if (has_bias_) {
|
if (has_bias_) {
|
||||||
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
|
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
|
||||||
} else {
|
} else {
|
||||||
std::memset(bias_memory.get_data_handle(), 0, prim_desc_.bias_desc().get_size());
|
auto ret =
|
||||||
|
memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0, prim_desc_.bias_desc().get_size());
|
||||||
|
if (ret != 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "bias memset error";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// set handle
|
// set handle
|
||||||
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
|
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
|
||||||
|
|
|
@ -79,17 +79,17 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
|
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
|
||||||
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
|
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
|
||||||
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
|
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
|
||||||
dnnl::lstm_forward::desc forward_desc =
|
auto forward_desc = std::make_shared<dnnl::lstm_forward::desc>(
|
||||||
dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
|
dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
|
||||||
formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc,
|
formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc,
|
||||||
dst_desc, dst_h_desc, dst_c_desc);
|
dst_c_desc);
|
||||||
auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng);
|
auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(*forward_desc, eng);
|
||||||
dnnl::lstm_backward::desc backward_desc = dnnl::lstm_backward::desc(
|
auto backward_desc = std::make_shared<dnnl::lstm_backward::desc>(
|
||||||
dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
|
dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
|
||||||
formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc,
|
formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc,
|
||||||
src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc,
|
src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc,
|
||||||
dst_h_desc, dst_c_desc);
|
dst_h_desc, dst_c_desc);
|
||||||
prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc);
|
prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc);
|
||||||
primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_);
|
primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_);
|
||||||
|
|
||||||
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
|
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
|
||||||
|
@ -132,7 +132,10 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||||
if (has_bias_) {
|
if (has_bias_) {
|
||||||
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
|
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
|
||||||
} else {
|
} else {
|
||||||
std::memset(bias_memory.get_data_handle(), 0, prim_backward_desc_.bias_desc().get_size());
|
if (memset_s(bias_memory.get_data_handle(), prim_backward_desc_.bias_desc().get_size(), 0,
|
||||||
|
prim_backward_desc_.bias_desc().get_size())) {
|
||||||
|
MS_LOG(EXCEPTION) << "bias memset error";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// construct bw memory
|
// construct bw memory
|
||||||
auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng);
|
auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng);
|
||||||
|
@ -142,14 +145,29 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||||
auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
|
auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
|
||||||
user_diff_weights_memory.set_data_handle(outputs[3]->addr);
|
user_diff_weights_memory.set_data_handle(outputs[3]->addr);
|
||||||
user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
|
user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
|
||||||
std::memset(user_diff_weights_memory.get_data_handle(), 0, user_diff_weights_memory.get_desc().get_size());
|
if (memset_s(user_diff_weights_memory.get_data_handle(), user_diff_weights_memory.get_desc().get_size(), 0,
|
||||||
std::memset(user_diff_weights_h_memory.get_data_handle(), 0, user_diff_weights_h_memory.get_desc().get_size());
|
user_diff_weights_memory.get_desc().get_size())) {
|
||||||
|
MS_LOG(EXCEPTION) << "user weights grad memset error";
|
||||||
|
}
|
||||||
|
if (memset_s(user_diff_weights_h_memory.get_data_handle(), user_diff_weights_h_memory.get_desc().get_size(), 0,
|
||||||
|
user_diff_weights_h_memory.get_desc().get_size())) {
|
||||||
|
MS_LOG(EXCEPTION) << "user weights iter grad memset error";
|
||||||
|
}
|
||||||
if (has_bias_) {
|
if (has_bias_) {
|
||||||
diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
|
diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
|
||||||
}
|
}
|
||||||
std::memset(diff_bias_memory.get_data_handle(), 0, prim_backward_desc_.diff_bias_desc().get_size());
|
if (memset_s(diff_bias_memory.get_data_handle(), prim_backward_desc_.diff_bias_desc().get_size(), 0,
|
||||||
std::memset(diff_weights_memory.get_data_handle(), 0, diff_weights_memory.get_desc().get_size());
|
prim_backward_desc_.diff_bias_desc().get_size())) {
|
||||||
std::memset(diff_weights_h_memory.get_data_handle(), 0, diff_weights_h_memory.get_desc().get_size());
|
MS_LOG(EXCEPTION) << "bias grad memset error";
|
||||||
|
}
|
||||||
|
if (memset_s(diff_weights_memory.get_data_handle(), diff_weights_memory.get_desc().get_size(), 0,
|
||||||
|
diff_weights_memory.get_desc().get_size())) {
|
||||||
|
MS_LOG(EXCEPTION) << "weights grad memset error";
|
||||||
|
}
|
||||||
|
if (memset_s(diff_weights_h_memory.get_data_handle(), diff_weights_h_memory.get_desc().get_size(), 0,
|
||||||
|
diff_weights_h_memory.get_desc().get_size())) {
|
||||||
|
MS_LOG(EXCEPTION) << "weights iter grad memset error";
|
||||||
|
}
|
||||||
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
|
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
|
||||||
SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
|
SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
|
||||||
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
|
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
|
||||||
|
|
Loading…
Reference in New Issue