From 576a73cd2a7037dcd361c98d728d1dbe3e4b59e3 Mon Sep 17 00:00:00 2001 From: baihuawei Date: Thu, 18 Jun 2020 21:30:44 +0800 Subject: [PATCH] fix stack overflow and memset use --- .../kernel/cpu/mkldnn/lstm_cpu_kernel.cc | 16 ++++--- .../kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc | 44 +++++++++++++------ 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc index 4fefc2db98d..0a343785f75 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc @@ -81,11 +81,11 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); - dnnl::lstm_forward::desc desc = - dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, - formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, - dst_desc, dst_h_desc, dst_c_desc); - prim_desc_ = dnnl::lstm_forward::primitive_desc(desc, eng); + auto desc = std::make_shared(dnnl::prop_kind::forward_training, direction, src_desc, + src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), + formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, + dst_h_desc, dst_c_desc); + prim_desc_ = dnnl::lstm_forward::primitive_desc(*desc, eng); primitive_ = std::make_shared(prim_desc_); AddArgument(DNNL_ARG_SRC_LAYER, src_desc); AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); @@ -117,7 +117,11 @@ bool LstmCPUKernel::Launch(const std::vector &inputs, if (has_bias_) { bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); } else { - std::memset(bias_memory.get_data_handle(), 0, prim_desc_.bias_desc().get_size()); + auto ret = + memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0, prim_desc_.bias_desc().get_size()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "bias memset error"; + } } // set handle SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc index 5d138089d6a..d7e7701d85d 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc @@ -79,17 +79,17 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); - dnnl::lstm_forward::desc forward_desc = - dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, - formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, - dst_desc, dst_h_desc, dst_c_desc); - auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng); - dnnl::lstm_backward::desc backward_desc = dnnl::lstm_backward::desc( + auto forward_desc = std::make_shared( + dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, + formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, + dst_c_desc); + auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(*forward_desc, eng); + auto backward_desc = std::make_shared( dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc); - prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc); + prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc); primitive_ = std::make_shared(prim_backward_desc_); AddArgument(DNNL_ARG_SRC_LAYER, src_desc); @@ -132,7 +132,10 @@ bool LSTMGradCPUKernel::Launch(const std::vector &inputs, if (has_bias_) { bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); } else { - std::memset(bias_memory.get_data_handle(), 0, prim_backward_desc_.bias_desc().get_size()); + if (memset_s(bias_memory.get_data_handle(), prim_backward_desc_.bias_desc().get_size(), 0, + prim_backward_desc_.bias_desc().get_size())) { + MS_LOG(EXCEPTION) << "bias memset error"; + } } // construct bw memory auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng); @@ -142,14 +145,29 @@ bool LSTMGradCPUKernel::Launch(const std::vector &inputs, auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); user_diff_weights_memory.set_data_handle(outputs[3]->addr); user_diff_weights_h_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_); - std::memset(user_diff_weights_memory.get_data_handle(), 0, user_diff_weights_memory.get_desc().get_size()); - std::memset(user_diff_weights_h_memory.get_data_handle(), 0, user_diff_weights_h_memory.get_desc().get_size()); + if (memset_s(user_diff_weights_memory.get_data_handle(), user_diff_weights_memory.get_desc().get_size(), 0, + user_diff_weights_memory.get_desc().get_size())) { + MS_LOG(EXCEPTION) << "user weights grad memset error"; + } + if (memset_s(user_diff_weights_h_memory.get_data_handle(), user_diff_weights_h_memory.get_desc().get_size(), 0, + user_diff_weights_h_memory.get_desc().get_size())) { + MS_LOG(EXCEPTION) << "user weights iter grad memset error"; + } if (has_bias_) { diff_bias_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_ + weight_h_size_); } - std::memset(diff_bias_memory.get_data_handle(), 0, prim_backward_desc_.diff_bias_desc().get_size()); - std::memset(diff_weights_memory.get_data_handle(), 0, diff_weights_memory.get_desc().get_size()); - std::memset(diff_weights_h_memory.get_data_handle(), 0, diff_weights_h_memory.get_desc().get_size()); + if (memset_s(diff_bias_memory.get_data_handle(), prim_backward_desc_.diff_bias_desc().get_size(), 0, + prim_backward_desc_.diff_bias_desc().get_size())) { + MS_LOG(EXCEPTION) << "bias grad memset error"; + } + if (memset_s(diff_weights_memory.get_data_handle(), diff_weights_memory.get_desc().get_size(), 0, + diff_weights_memory.get_desc().get_size())) { + MS_LOG(EXCEPTION) << "weights grad memset error"; + } + if (memset_s(diff_weights_h_memory.get_data_handle(), diff_weights_h_memory.get_desc().get_size(), 0, + diff_weights_h_memory.get_desc().get_size())) { + MS_LOG(EXCEPTION) << "weights iter grad memset error"; + } SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);