forked from mindspore-Ecosystem/mindspore
fix stack overflow and memset use
This commit is contained in:
parent 7f7c006acf
commit 576a73cd2a
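The stack-overflow part of the fix moves the large dnnl LSTM operation descriptors off the stack: in both kernels the dnnl::lstm_forward::desc / dnnl::lstm_backward::desc locals are replaced with std::make_shared allocations, and the primitive descriptors are built from the dereferenced pointers. A minimal sketch of the pattern, assuming the oneDNN 1.x descriptor API used in this diff (the helper name and trimmed argument list are illustrative, not the actual kernel code):

#include <memory>
#include "dnnl.hpp"

// Sketch only: build an LSTM forward primitive descriptor without placing the
// (potentially large) operation descriptor in the caller's stack frame.
dnnl::lstm_forward::primitive_desc MakeLstmForwardPrimDesc(
    const dnnl::engine &eng, dnnl::rnn_direction direction, const dnnl::memory::desc &src,
    const dnnl::memory::desc &src_h, const dnnl::memory::desc &src_c, const dnnl::memory::desc &weights,
    const dnnl::memory::desc &weights_h, const dnnl::memory::desc &bias, const dnnl::memory::desc &dst,
    const dnnl::memory::desc &dst_h, const dnnl::memory::desc &dst_c) {
  // The descriptor lives on the heap inside the shared_ptr; only a pointer occupies the stack.
  auto desc = std::make_shared<dnnl::lstm_forward::desc>(dnnl::prop_kind::forward_training, direction, src, src_h,
                                                         src_c, weights, weights_h, bias, dst, dst_h, dst_c);
  return dnnl::lstm_forward::primitive_desc(*desc, eng);
}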
@@ -81,11 +81,11 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
   dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
   dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
-  dnnl::lstm_forward::desc desc =
-    dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
-                             formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc,
-                             dst_desc, dst_h_desc, dst_c_desc);
-  prim_desc_ = dnnl::lstm_forward::primitive_desc(desc, eng);
+  auto desc = std::make_shared<dnnl::lstm_forward::desc>(dnnl::prop_kind::forward_training, direction, src_desc,
+                                                         src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
+                                                         formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc,
+                                                         dst_h_desc, dst_c_desc);
+  prim_desc_ = dnnl::lstm_forward::primitive_desc(*desc, eng);
   primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_);
   AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
   AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
@@ -117,7 +117,11 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
   if (has_bias_) {
     bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
   } else {
-    std::memset(bias_memory.get_data_handle(), 0, prim_desc_.bias_desc().get_size());
+    auto ret =
+      memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0, prim_desc_.bias_desc().get_size());
+    if (ret != 0) {
+      MS_LOG(EXCEPTION) << "bias memset error";
+    }
   }
   // set handle
   SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
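The memset part of the fix swaps std::memset for the bounds-checked memset_s from the securec library and treats a non-zero return value as fatal, which is what the error branches in the Launch() hunks above and below do for every cleared bias and weight-gradient buffer. A minimal sketch of that pattern, assuming the securec signature memset_s(dest, destMax, ch, count) returning 0 on success and MindSpore's MS_LOG(EXCEPTION) macro for reporting (the helper name and header paths are assumptions, not the kernel code):

#include <cstddef>
#include "securec.h"            // memset_s from securec; exact header path is an assumption
#include "utils/log_adapter.h"  // MS_LOG(EXCEPTION); exact header path is an assumption

// Sketch only: zero `size` bytes at `handle` and abort the kernel launch on failure,
// mirroring how Launch() clears bias and gradient buffers in this commit.
inline void ZeroDnnlBuffer(void *handle, std::size_t size) {
  if (memset_s(handle, size, 0, size) != 0) {
    MS_LOG(EXCEPTION) << "memset error";
  }
}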
@@ -79,17 +79,17 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
   dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
   dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
-  dnnl::lstm_forward::desc forward_desc =
-    dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
-                             formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc,
-                             dst_desc, dst_h_desc, dst_c_desc);
-  auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng);
-  dnnl::lstm_backward::desc backward_desc = dnnl::lstm_backward::desc(
+  auto forward_desc = std::make_shared<dnnl::lstm_forward::desc>(
+    dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
+    formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc,
+    dst_c_desc);
+  auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(*forward_desc, eng);
+  auto backward_desc = std::make_shared<dnnl::lstm_backward::desc>(
     dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
     formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc,
     src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc,
     dst_h_desc, dst_c_desc);
-  prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc);
+  prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc);
   primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_);
 
   AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
@@ -132,7 +132,10 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
   if (has_bias_) {
     bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
   } else {
-    std::memset(bias_memory.get_data_handle(), 0, prim_backward_desc_.bias_desc().get_size());
+    if (memset_s(bias_memory.get_data_handle(), prim_backward_desc_.bias_desc().get_size(), 0,
+                 prim_backward_desc_.bias_desc().get_size())) {
+      MS_LOG(EXCEPTION) << "bias memset error";
+    }
   }
   // construct bw memory
   auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng);
@@ -142,14 +145,29 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
   auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
   user_diff_weights_memory.set_data_handle(outputs[3]->addr);
   user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
-  std::memset(user_diff_weights_memory.get_data_handle(), 0, user_diff_weights_memory.get_desc().get_size());
-  std::memset(user_diff_weights_h_memory.get_data_handle(), 0, user_diff_weights_h_memory.get_desc().get_size());
+  if (memset_s(user_diff_weights_memory.get_data_handle(), user_diff_weights_memory.get_desc().get_size(), 0,
+               user_diff_weights_memory.get_desc().get_size())) {
+    MS_LOG(EXCEPTION) << "user weights grad memset error";
+  }
+  if (memset_s(user_diff_weights_h_memory.get_data_handle(), user_diff_weights_h_memory.get_desc().get_size(), 0,
+               user_diff_weights_h_memory.get_desc().get_size())) {
+    MS_LOG(EXCEPTION) << "user weights iter grad memset error";
+  }
   if (has_bias_) {
     diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
   }
-  std::memset(diff_bias_memory.get_data_handle(), 0, prim_backward_desc_.diff_bias_desc().get_size());
-  std::memset(diff_weights_memory.get_data_handle(), 0, diff_weights_memory.get_desc().get_size());
-  std::memset(diff_weights_h_memory.get_data_handle(), 0, diff_weights_h_memory.get_desc().get_size());
+  if (memset_s(diff_bias_memory.get_data_handle(), prim_backward_desc_.diff_bias_desc().get_size(), 0,
+               prim_backward_desc_.diff_bias_desc().get_size())) {
+    MS_LOG(EXCEPTION) << "bias grad memset error";
+  }
+  if (memset_s(diff_weights_memory.get_data_handle(), diff_weights_memory.get_desc().get_size(), 0,
+               diff_weights_memory.get_desc().get_size())) {
+    MS_LOG(EXCEPTION) << "weights grad memset error";
+  }
+  if (memset_s(diff_weights_h_memory.get_data_handle(), diff_weights_h_memory.get_desc().get_size(), 0,
+               diff_weights_h_memory.get_desc().get_size())) {
+    MS_LOG(EXCEPTION) << "weights iter grad memset error";
+  }
   SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
   SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
   SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);