forked from mindspore-Ecosystem/mindspore

fix lstm

This commit is contained in: parent bbc64b8873, commit ea78e16e74
@@ -24,17 +24,20 @@ namespace kernel {
void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
num_layers_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "num_layers");
has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
batch_size_ = SizeToInt(src_shape[1]);
seq_len_ = SizeToInt(src_shape[0]);
num_directions_ = 1;
if (bidirectional_) {
num_directions_ = 2;
}
int gate_size = 4 * hidden_size_;
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!";
const int gate_size = 4 * hidden_size_;
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
weight_h_size_ += gate_size * hidden_size_;
@@ -52,11 +55,11 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
auto eng = MKLKernelEngine::Get().engine();
dnnl::stream s(eng);
auto formatted_md = [](dim dimensions, tag layout) { return dnnl::memory::desc{{dimensions}, dt::f32, layout}; };
auto generic_md = [](dim dimensions) { return dnnl::memory::desc{{dimensions}, dt::f32, tag::any}; };
dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
if (bidirectional_) {
direction = dnnl::rnn_direction::bidirectional_concat;
}

dim src_dims = {seq_len_, batch_size_, input_size_};
dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
@@ -69,35 +72,43 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc);
dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc);
dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc);
dnnl::memory::desc weights_desc = formatted_md(weights_dims, tag::ldigo);
dnnl::memory::desc weights_h_desc = formatted_md(weights_h_dims, tag::ldigo);
dnnl::memory::desc bias_desc = formatted_md(bias_dims, tag::ldgo);
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
dnnl::lstm_forward::desc desc =
dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
dnnl::lstm_forward::desc desc = dnnl::lstm_forward::desc(
dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc);
auto prim_desc = dnnl::lstm_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());

// construct fw memory
auto workspace_memory = dnnl::memory(prim_desc.workspace_desc(), eng);
auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng);
write_to_dnnl_memory(inputs[0]->addr, src_memory);

auto src_h_memory = dnnl::memory(prim_desc.src_iter_desc(), eng);
auto src_c_memory = dnnl::memory(prim_desc.src_iter_c_desc(), eng);
write_to_dnnl_memory(inputs[1]->addr, src_h_memory);
write_to_dnnl_memory(inputs[2]->addr, src_c_memory);

auto weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldigo), eng);
auto weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldigo), eng);
auto bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
write_to_dnnl_memory(inputs[3]->addr, weights_memory);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_, weights_h_memory);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_, bias_memory);
src_memory.set_data_handle(inputs[0]->addr);
auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng);
auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng);
src_h_memory.set_data_handle(inputs[1]->addr);
src_c_memory.set_data_handle(inputs[2]->addr);
auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng);
auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng);
user_weights_memory.set_data_handle(inputs[3]->addr);
user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
auto weights_memory = dnnl::memory(prim_desc.weights_layer_desc(), eng);
auto weights_h_memory = dnnl::memory(prim_desc.weights_iter_desc(), eng);
dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory);
dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory);

auto bias_memory = dnnl::memory(prim_desc.bias_desc(), eng);
if (has_bias_) {
auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
user_bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory);
} else {
std::vector<float> net_bias(bias_memory.get_desc().get_size(), 0.0f);
write_to_dnnl_memory(net_bias.data(), bias_memory);
}
auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
auto dst_h_memory = dnnl::memory(prim_desc.dst_iter_desc(), eng);
auto dst_c_memory = dnnl::memory(prim_desc.dst_iter_c_desc(), eng);
auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng);
auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng);
dnnl::lstm_forward fw_layer(prim_desc);
workspace_memory.set_data_handle(outputs[3]->addr);
dst_memory.set_data_handle(outputs[0]->addr);
@@ -113,8 +124,8 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
{DNNL_ARG_DST_ITER, dst_h_memory},
{DNNL_ARG_DST_ITER_C, dst_c_memory},
{DNNL_ARG_WORKSPACE, workspace_memory}});
s.wait();
return true;
}

} // namespace kernel
} // namespace mindspore
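For reference, a minimal sketch of the oneDNN pattern the updated forward kernel relies on: bind user buffers directly instead of copying, request primitive-preferred weight layouts with tag::any, reorder the user-layout (ldgoi) weights into that layout, and zero-fill the bias when the op has no bias. This assumes the oneDNN 1.x C++ API (dnnl.hpp); the helper name bind_lstm_weights and its parameters are illustrative only, not part of the kernel.

```cpp
// Sketch only, not the kernel itself.
#include <cstring>
#include <vector>
#include "dnnl.hpp"

void bind_lstm_weights(const dnnl::lstm_forward::primitive_desc &pd,
                       const dnnl::engine &eng, dnnl::stream &s,
                       const dnnl::memory::desc &user_w_md,   // tag::ldgoi layout
                       const dnnl::memory::desc &user_wh_md,  // tag::ldgoi layout
                       float *user_weights_ptr, float *user_weights_h_ptr,
                       dnnl::memory &weights_memory, dnnl::memory &weights_h_memory,
                       dnnl::memory &bias_memory, bool has_bias) {
  // Wrap the user-provided flat buffers without copying them.
  dnnl::memory user_w(user_w_md, eng, user_weights_ptr);
  dnnl::memory user_wh(user_wh_md, eng, user_weights_h_ptr);

  // The primitive's preferred layouts were requested with tag::any, so reorder into them.
  weights_memory = dnnl::memory(pd.weights_layer_desc(), eng);
  weights_h_memory = dnnl::memory(pd.weights_iter_desc(), eng);
  dnnl::reorder(user_w, weights_memory).execute(s, user_w, weights_memory);
  dnnl::reorder(user_wh, weights_h_memory).execute(s, user_wh, weights_h_memory);

  // When the op carries no bias, fill the bias memory with zeros.
  bias_memory = dnnl::memory(pd.bias_desc(), eng);
  if (!has_bias) {
    std::memset(bias_memory.get_data_handle(), 0, bias_memory.get_desc().get_size());
  }
}
```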
@@ -40,6 +40,7 @@ class LstmCPUKernel : public MKLCPUKernel {
int seq_len_;
int num_directions_;
bool bidirectional_;
bool has_bias_;
};

MS_REG_CPU_KERNEL(LSTM,
@@ -28,17 +28,20 @@ namespace kernel {
void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
num_layers_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "num_layers");
has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
batch_size_ = SizeToInt(src_shape[1]);
seq_len_ = SizeToInt(src_shape[0]);
num_directions_ = 1;
if (bidirectional_) {
num_directions_ = 2;
}
int gate_size = 4 * hidden_size_;
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!";
const int gate_size = 4 * hidden_size_;
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
weight_h_size_ += gate_size * hidden_size_;
@@ -70,79 +73,92 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_};
dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};

dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc);
dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc);
dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc);
dnnl::memory::desc weights_desc = formatted_md(weights_dims, tag::ldigo);
dnnl::memory::desc weights_h_desc = formatted_md(weights_h_dims, tag::ldigo);
dnnl::memory::desc bias_desc = formatted_md(bias_dims, tag::ldgo);
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);

dnnl::lstm_forward::desc forward_desc =
dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
dnnl::lstm_forward::desc forward_desc = dnnl::lstm_forward::desc(
dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc);
auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng);

dnnl::lstm_backward::desc backward_desc = dnnl::lstm_backward::desc(
dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc,
src_c_desc, weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
dnnl::lstm_backward::desc backward_desc =
dnnl::lstm_backward::desc(dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc,
generic_md(weights_dims), generic_md(weights_h_dims), generic_md(bias_dims), dst_desc,
dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc);
auto prim_backward_desc = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc);

// construct fw memory
auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng);
write_to_dnnl_memory(inputs[0]->addr, src_memory);

auto src_h_memory = dnnl::memory(prim_forward_desc.src_iter_desc(), eng);
auto src_c_memory = dnnl::memory(prim_forward_desc.src_iter_c_desc(), eng);
write_to_dnnl_memory(inputs[1]->addr, src_h_memory);
write_to_dnnl_memory(inputs[2]->addr, src_c_memory);

auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldigo), eng);
auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldigo), eng);
auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
write_to_dnnl_memory(inputs[3]->addr, user_weights_memory);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_, user_weights_h_memory);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_, user_bias_memory);
src_memory.set_data_handle(inputs[0]->addr);
auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng);
auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng);
src_h_memory.set_data_handle(inputs[1]->addr);
src_c_memory.set_data_handle(inputs[2]->addr);
auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng);
auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng);
user_weights_memory.set_data_handle(inputs[3]->addr);
user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
auto weights_memory = dnnl::memory(prim_backward_desc.weights_layer_desc(), eng);
auto weights_h_memory = dnnl::memory(prim_backward_desc.weights_iter_desc(), eng);
auto bias_memory = dnnl::memory(prim_forward_desc.bias_desc(), eng);
dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory);
dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory);
dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory);

// construct bias memory
auto bias_memory = dnnl::memory(prim_backward_desc.bias_desc(), eng);
if (has_bias_) {
auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
user_bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory);
} else {
std::vector<float> net_bias(bias_memory.get_desc().get_size(), 0.0f);
write_to_dnnl_memory(net_bias.data(), bias_memory);
}

auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[4]->addr), dst_memory);
auto dst_h_memory = dnnl::memory(prim_backward_desc.dst_iter_desc(), eng);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[5]->addr), dst_h_memory);
auto dst_c_memory = dnnl::memory(prim_backward_desc.dst_iter_c_desc(), eng);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[6]->addr), dst_c_memory);
dst_memory.set_data_handle(inputs[4]->addr);
auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng);
auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng);
dst_h_memory.set_data_handle(inputs[5]->addr);
dst_c_memory.set_data_handle(inputs[6]->addr);
auto workspace_memory = dnnl::memory(prim_forward_desc.workspace_desc(), eng);
write_to_dnnl_memory(inputs[10]->addr, workspace_memory);
workspace_memory.set_data_handle(inputs[10]->addr);

// construct diff memory
// construct bw memory
std::vector<float> net_w(weights_memory.get_desc().get_size(), 0.0f);
std::vector<float> net_wh(weights_h_memory.get_desc().get_size(), 0.0f);
auto diff_src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng);
auto diff_src_h_memory = dnnl::memory(prim_backward_desc.diff_src_iter_desc(), eng);
auto diff_src_c_memory = dnnl::memory(prim_backward_desc.diff_src_iter_c_desc(), eng);

auto diff_src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng);
auto diff_src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng);
auto user_diff_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng);
auto user_diff_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng);
auto diff_weights_memory = dnnl::memory(prim_backward_desc.diff_weights_layer_desc(), eng);
auto diff_weights_h_memory = dnnl::memory(prim_backward_desc.diff_weights_iter_desc(), eng);
write_to_dnnl_memory(net_w.data(), diff_weights_memory);
write_to_dnnl_memory(net_wh.data(), diff_weights_h_memory);
auto user_diff_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
auto diff_bias_memory = dnnl::memory(prim_backward_desc.diff_bias_desc(), eng);
auto diff_dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[7]->addr), diff_dst_memory);
auto diff_dst_h_memory = dnnl::memory(prim_backward_desc.diff_dst_iter_desc(), eng);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[8]->addr), diff_dst_h_memory);
auto diff_dst_c_memory = dnnl::memory(prim_backward_desc.diff_dst_iter_c_desc(), eng);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[9]->addr), diff_dst_c_memory);
write_to_dnnl_memory(net_w.data(), diff_bias_memory);

auto diff_dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
diff_dst_memory.set_data_handle(inputs[7]->addr);
auto diff_dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng);
diff_dst_h_memory.set_data_handle(inputs[8]->addr);
auto diff_dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng);
diff_dst_c_memory.set_data_handle(inputs[9]->addr);
diff_src_memory.set_data_handle(outputs[0]->addr);
diff_src_h_memory.set_data_handle(outputs[1]->addr);
diff_src_c_memory.set_data_handle(outputs[2]->addr);
diff_weights_memory.set_data_handle(outputs[3]->addr);
diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
user_diff_weights_memory.set_data_handle(outputs[3]->addr);
user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
write_to_dnnl_memory(net_w.data(), user_diff_weights_memory);
write_to_dnnl_memory(net_wh.data(), user_diff_weights_h_memory);

// construct bw bias memory
user_diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
write_to_dnnl_memory(net_w.data(), user_diff_bias_memory);
dnnl::lstm_backward bwd_layer(prim_backward_desc);
bwd_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory},
{DNNL_ARG_SRC_ITER, src_h_memory},
@@ -163,6 +179,16 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
{DNNL_ARG_DIFF_DST_ITER, diff_dst_h_memory},
{DNNL_ARG_DIFF_DST_ITER_C, diff_dst_c_memory},
{DNNL_ARG_WORKSPACE, workspace_memory}});
dnnl::reorder(diff_weights_memory, user_diff_weights_memory)
.execute(s, diff_weights_memory, user_diff_weights_memory);
dnnl::reorder(diff_weights_h_memory, user_diff_weights_h_memory)
.execute(s, diff_weights_h_memory, user_diff_weights_h_memory);
if (has_bias_) {
dnnl::reorder(diff_bias_memory, user_diff_bias_memory).execute(s, diff_bias_memory, user_diff_bias_memory);
} else {
write_to_dnnl_memory(net_w.data(), user_diff_bias_memory);
}
s.wait();
return true;
}
} // namespace kernel
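A minimal sketch of the post-processing pattern the backward kernel follows, under the assumption (from oneDNN's RNN documentation) that the primitive accumulates into the diff-weights memories, which is why they are zero-filled before execution and reordered back into the user (ldgoi) layout that aliases the output buffer afterwards. The helper name collect_diff_weights and its parameters are illustrative only.

```cpp
// Sketch only, not the kernel itself.
#include <cstring>
#include "dnnl.hpp"

void collect_diff_weights(dnnl::stream &s,
                          dnnl::memory &diff_weights_memory,          // pd.diff_weights_layer_desc() layout
                          dnnl::memory &user_diff_weights_memory) {   // tag::ldgoi over the output buffer
  // Zero-initialize before execute(): gradients are accumulated, not overwritten.
  std::memset(diff_weights_memory.get_data_handle(), 0,
              diff_weights_memory.get_desc().get_size());
  std::memset(user_diff_weights_memory.get_data_handle(), 0,
              user_diff_weights_memory.get_desc().get_size());

  // ... the lstm_backward primitive executes here ...

  // Copy the accumulated gradients back into the user layout for the output tensor.
  dnnl::reorder(diff_weights_memory, user_diff_weights_memory)
      .execute(s, diff_weights_memory, user_diff_weights_memory);
  s.wait();
}
```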
@@ -41,6 +41,7 @@ class LSTMGradCPUKernel : public MKLCPUKernel {
int seq_len_;
int num_directions_;
bool bidirectional_;
bool has_bias_;
};

MS_REG_CPU_KERNEL(LSTMGrad,
@@ -15,7 +15,7 @@
"""lstm"""
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore._checkparam import Validator as validator
from mindspore import context
@@ -149,21 +149,28 @@ class LSTM(Cell):
weight_size += increment_size * num_directions
self.weight = Parameter(initializer(0.0, [weight_size, 1, 1]), name='weight')
else:
layer = []
layer.append(nn.LSTMCell(input_size=self.input_size,
hidden_size=self.hidden_size,
layer_index=0,
has_bias=self.has_bias,
bidirectional=self.bidirectional,
dropout=self.dropout))
for i in range(num_layers - 1):
layer.append(nn.LSTMCell(input_size=self.hidden_size * num_directions,
hidden_size=self.hidden_size,
layer_index=i + 1,
has_bias=self.has_bias,
bidirectional=self.bidirectional,
dropout=self.dropout))
self.lstms = layer
input_size_list = []
input_size_list.append(self.input_size)
for i in range(self.num_layers - 1):
input_size_list.append(self.hidden_size * num_directions)
weights = []
layers = []
bias_size = 0 if not self.has_bias else num_directions * self.hidden_size * 4
for i in range(num_layers):
weight_size = (input_size_list[i] + self.hidden_size) * num_directions * self.hidden_size * 4
w_np = np.ones([weight_size, 1, 1]).astype(np.float32) * 0.01
if has_bias:
bias_np = np.zeros([bias_size, 1, 1]).astype(np.float32)
w_np = np.concatenate([w_np, bias_np], axis=0)
weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i)))

layers.append(nn.LSTMCell(input_size=input_size_list[i],
hidden_size=self.hidden_size,
has_bias=self.has_bias,
bidirectional=self.bidirectional,
dropout=self.dropout))
self.lstms = layers
self.weight = ParameterTuple(tuple(weights))
self.fill = P.Fill()
self.shape = P.Shape()
@@ -177,12 +184,12 @@ class LSTM(Cell):
output = self.transpose2(output, (1, 0, 2))
return (output, (h, c))
h, c = hx
output, hn, cn, _, _ = self.lstms[0](x, h[0], c[0])
output, hn, cn, _, _ = self.lstms[0](x, h[0], c[0], self.weight[0])
for i in range(1, self.num_layers):
output, hn, cn, _, _ = self.lstms[i](output, h[i], c[i])
output, hn, cn, _, _ = self.lstms[i](output, h[i], c[i], self.weight[i])
if self.batch_first:
output = self.transpose2(output, (1, 0, 2))
return output, hn, cn, _, _
return (output, (hn, cn))


class LSTMCell(Cell):
@@ -271,11 +278,9 @@ class LSTMCell(Cell):
>>> output, hn, cn, _, _ = net(input, h0, c0)
"""

def __init__(self,
input_size,
hidden_size,
layer_index=0,
has_bias=True,
batch_first=False,
dropout=0,
@@ -283,8 +288,6 @@ class LSTMCell(Cell):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = 1
self.layer_index = layer_index
self.has_bias = has_bias
self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
self.dropout = float(dropout)
@@ -295,16 +298,7 @@ class LSTMCell(Cell):
if self.batch_first:
self.transpose1 = P.Transpose()
self.transpose2 = P.Transpose()
w_np = np.ones([(self.input_size + self.hidden_size) * self.num_directions * self.hidden_size * 4, 1]).astype(
np.float32) * 0.01
if has_bias:
b_np = np.ones([self.num_directions * self.hidden_size * 4, 1]).astype(
np.float32) * 0.01
else:
b_np = np.zeros([self.num_directions * self.hidden_size * 4, 1]).astype(
np.float32) * 0.01
wb_np = np.concatenate((w_np, b_np), axis=0).reshape([-1, 1, 1])
self.w = Parameter(initializer(Tensor(wb_np), wb_np.shape), name='w' + str(self.layer_index))

self.lstm = P.LSTM(input_size=self.input_size,
hidden_size=self.hidden_size,
num_layers=1,
@@ -312,10 +306,10 @@ class LSTMCell(Cell):
bidirectional=self.bidirectional,
dropout=self.dropout)

def construct(self, x, h, c):
def construct(self, x, h, c, w):
if self.batch_first:
x = self.transpose1(x, (1, 0, 2))
output, hn, cn, _, _ = self.lstm(x, h, c, self.w)
output, hn, cn, _, _ = self.lstm(x, h, c, w)
if self.batch_first:
output = self.transpose2(output, (1, 0, 2))
return output, hn, cn, _, _
@@ -35,27 +35,22 @@ class LstmNet(nn.Cell):
if bidirectional:
num_directions = 2

self.lstm = P.LSTM(input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)
input_np = np.array([[[0.6755, -1.6607, 0.1367], [0.4276, -0.7850, -0.3758]],
[[-0.6424, -0.6095, 0.6639], [0.7918, 0.4147, -0.5089]],
[[-1.5612, 0.0120, -0.7289], [-0.6656, -0.6626, -0.5883]],
[[-0.9667, -0.6296, -0.7310], [0.1026, -0.6821, -0.4387]],
[[-0.4710, 0.6558, -0.3144], [-0.8449, -0.2184, -0.1806]]
]).astype(np.float32)
self.x = Parameter(initializer(Tensor(input_np), [seq_len, batch_size, input_size]), name='x')
self.x = Tensor(input_np)

self.h = Parameter(initializer(
Tensor(
np.array([0.1, 0.1, 0.1, 0.1]).reshape((num_layers * num_directions, batch_size, hidden_size)).astype(
np.float32)),
[num_layers * num_directions, batch_size, hidden_size]), name='h')

self.c = Parameter(initializer(
Tensor(
np.array([0.2, 0.2, 0.2, 0.2]).reshape((num_layers * num_directions, batch_size, hidden_size)).astype(
np.float32)),
[num_layers * num_directions, batch_size, hidden_size]), name='c')
self.h = Tensor(np.array([0., 0., 0., 0.]).reshape((num_directions, batch_size, hidden_size)).astype(
np.float32))

self.c = Tensor(np.array([0., 0., 0., 0.]).reshape((num_directions, batch_size, hidden_size)).astype(
np.float32))
self.h = tuple((self.h,))
self.c = tuple((self.c,))
wih = np.array([[3.4021e-01, -4.6622e-01, 4.5117e-01],
[-6.4257e-02, -2.4807e-01, 1.3550e-02], # i
[-3.2140e-01, 5.5578e-01, 6.3589e-01],
@@ -63,7 +58,7 @@ class LstmNet(nn.Cell):
[-6.9863e-01, 5.9773e-01, -3.9062e-01],
[-3.0253e-01, -1.9464e-01, 7.0591e-01],
[-4.0835e-01, 3.6751e-01, 4.7989e-01],
[-5.6894e-01, -5.0359e-01, 4.7491e-01]]).astype(np.float32) # .reshape([1,-1])
[-5.6894e-01, -5.0359e-01, 4.7491e-01]]).astype(np.float32).reshape([1, -1])
whh = np.array([[-0.4820, -0.2350],
[-0.1195, 0.0519],
[0.2162, -0.1178],
@@ -71,16 +66,16 @@ class LstmNet(nn.Cell):
[0.4511, -0.3961],
[-0.5962, 0.0906],
[0.1867, -0.1225],
[0.1831, 0.0850]]).astype(np.float32) # .reshape([1,-1])
wih = wih.transpose((1, 0))
whh = whh.transpose((1, 0))
[0.1831, 0.0850]]).astype(np.float32).reshape([1, -1])
bih = np.zeros((1, 8)).astype(np.float32)
w_np = np.concatenate((wih, whh, bih), axis=0).reshape([-1, 1, 1])
w_np = np.concatenate((wih, whh, bih), axis=1).reshape([-1, 1, 1])
self.w = Parameter(initializer(Tensor(w_np), w_np.shape), name='w')
self.lstm.weight = ParameterTuple((self.w,))

@ms_function
def construct(self):
return self.lstm(self.x, self.h, self.c, self.w)
return self.lstm(self.x, (self.h, self.c))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@@ -98,40 +93,41 @@ def test_lstm():
if bidirectional:
num_directions = 2
net = LstmNet(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)
y, h, c, _, _ = net()
y, (h, c) = net()
print(y)
print(c)
print(h)
expect_y = np.array([[[-0.16709016, 0.13125697],
[-0.08438572, -0.01969833]],
[[-0.2746155, 0.32764038],
[-0.06504016, -0.07770399]],
[[-0.00140004, 0.17706314],
[0.03244496, -0.10135599]],
[[0.08328028, 0.06437367],
[-0.04133911, -0.11072896]],
[[0.19004421, -0.02852732],
[0.09138509, -0.00344161]]]
)
error = np.ones([num_layers, batch_size, hidden_size]) * 1.0e-4
diff = y.asnumpy() - expect_y
assert np.all(diff < error)
assert np.all(-diff < error)
#
expect_h = np.array([[[0.19004421, -0.02852732],
[0.09138509, -0.00344161]]])
expect_y = [[[-0.17992045, 0.07819052],
[-0.10745212, -0.06291768]],

error = np.ones((num_layers * num_directions, batch_size, hidden_size)) * 1.0e-4
diff = h.asnumpy() - expect_h
assert np.all(diff < error)
assert np.all(-diff < error)
#
expect_c = np.array([[[0.34533143, -0.06313794],
[0.169008, -0.00555446]]])
error = np.ones((num_layers * num_directions, batch_size, hidden_size)) * 1.0e-4
diff = c.asnumpy() - expect_c
assert np.all(diff < error)
assert np.all(-diff < error)
[[-0.28830513, 0.30579978],
[-0.07570618, -0.08868407]],

[[-0.00814095, 0.16889746],
[0.02814853, -0.11208838]],

[[0.08157863, 0.06088024],
[-0.04227093, -0.11514835]],

[[0.18908429, -0.02963362],
[0.09106826, -0.00602506]]]
expect_h = [[[0.18908429, -0.02963362],
[0.09106826, -0.00602506]]]
expect_c = [[[0.3434288, -0.06561527],
[0.16838229, -0.00972614]]]

diff_y = y.asnumpy() - expect_y
error_y = np.ones([seq_len, batch_size, hidden_size]) * 1.0e-4
assert np.all(diff_y < error_y)
assert np.all(-diff_y < error_y)
diff_h = h.asnumpy() - expect_h
error_h = np.ones([num_layers * num_directions, batch_size, hidden_size]) * 1.0e-4
assert np.all(diff_h < error_h)
assert np.all(-diff_h < error_h)
diff_c = c.asnumpy() - expect_c
error_c = np.ones([num_layers * num_directions, batch_size, hidden_size]) * 1.0e-4
assert np.all(diff_c < error_c)
assert np.all(-diff_c < error_c)


class MultiLayerBiLstmNet(nn.Cell):
@@ -161,22 +157,15 @@ class MultiLayerBiLstmNet(nn.Cell):
[1.2223, -1.3248, 0.1207, -0.8256, 0.1816, 0.7057, -0.3105, 0.5713, 0.2804,
-1.0685]]]).astype(np.float32)

self.x = Parameter(initializer(Tensor(input_np), [seq_len, batch_size, input_size]), name='x')
self.x = Tensor(input_np)

self.h0 = Parameter(initializer(
Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]), name='h0')
self.c0 = Parameter(initializer(
Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]), name='c0')
self.h1 = Parameter(initializer(
Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]), name='h1')
self.c1 = Parameter(initializer(
Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]), name='c1')
self.h = ParameterTuple((self.h0, self.h1))
self.c = ParameterTuple((self.c0, self.c1))
self.h0 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32))
self.c0 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32))
self.h1 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32))
self.c1 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32))

self.h = tuple((self.h0, self.h1))
self.c = tuple((self.c0, self.c1))

@ms_function
def construct(self):
@@ -202,7 +191,7 @@ def test_multi_layer_bilstm():

net = MultiLayerBiLstmNet(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional,
dropout)
y, h, c, _, _ = net()
y, (h, c) = net()
print(y)
print(h)
print(c)
@@ -231,66 +220,53 @@ class Net(nn.Cell):
num_directions = 1
if bidirectional:
num_directions = 2
input_np = np.array([[[-0.5907, 1.0557, 1.7283, 0.6706, -1.2550, -0.5298, -0.2290, -0.6735, 0.8555, 1.4836],
[-1.7070, -0.5347, -0.9105, -0.2598, 0.0588, 1.5496, 1.0757, 0.3760, -1.2020, -0.2868]],

[[0.0151, 0.2126, 0.8090, -0.5292, -2.5590, 0.4279, -0.3081, -1.4706, -0.0498, 1.2301],
[0.4165, -0.5391, -0.0996, 0.1928, -0.4909, -0.1255, 0.4444, -1.3687, 1.3096, 0.6553]],

[[-0.7802, -0.2083, -0.6388, 1.3757, 0.4293, 0.5363, 0.3202, -0.6687, -1.3864, -0.2953],
[1.0799, -0.7204, 0.1130, -0.5857, -0.4855, -1.1068, 1.0126, 0.8716, 1.5460, -0.7392]],

[[2.2645, -0.6586, -0.2227, 1.4290, -0.5006, -1.6576, -0.1793, 0.5319, 0.1360, 0.2707],
[-0.4071, 0.1575, 1.4199, -0.9156, 0.1855, 0.4947, 1.0460, -0.6365, 0.1191, -0.6374]],

[[0.2468, 1.0815, -0.4893, 0.0664, 0.6405, -2.2967, 0.7612, 0.8759, 0.5685, -1.0999],
[-0.7272, -1.7750, -0.1164, -0.7159, 0.0061, -0.7839, -1.8329, 0.3434, -0.5634,
0.5384]]]).astype(np.float32)

input_np = np.array([[[0.6755, -1.6607, 0.1367], [0.4276, -0.7850, -0.3758]],
[[-0.6424, -0.6095, 0.6639], [0.7918, 0.4147, -0.5089]],
[[-1.5612, 0.0120, -0.7289], [-0.6656, -0.6626, -0.5883]],
[[-0.9667, -0.6296, -0.7310], [0.1026, -0.6821, -0.4387]],
[[-0.4710, 0.6558, -0.3144], [-0.8449, -0.2184, -0.1806]]
]).astype(np.float32)
self.x = Parameter(initializer(Tensor(input_np), [seq_len, batch_size, input_size]), name='x')

self.h0 = Parameter(initializer(
Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]), name='h0')

self.c0 = Parameter(initializer(
Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]), name='c0')

wih_l0 = np.array([[0.2300, 0.6668, 0.4703, 0.0425, 0.0464, 0.6825, 0.2249, -0.4315, -0.2449, 0.2964],
[-0.2811, -0.3444, 0.2557, -0.5137, -0.5518, 0.1652, -0.6720, 0.1066, 0.3586, 0.6299],
[0.5728, -0.1784, 0.5661, 0.4012, 0.3856, -0.1899, 0.3102, 0.3717, -0.5651, 0.1952],
[0.1026, -0.0527, 0.1198, -0.3080, 0.2292, 0.5757, -0.3567, -0.2731, -0.0586, -0.2849],
[0.2194, -0.1622, 0.3219, -0.3008, -0.3713, -0.3034, -0.2385, 0.0412, -0.5205, 0.0280],
[-0.5499, -0.0733, -0.5236, -0.6753, -0.7045, -0.1839, -0.1037, -0.5026, -0.4055, -0.3416],
[0.1573, -0.1301, -0.2882, -0.3464, 0.6643, 0.1980, -0.6804, 0.5359, 0.5996, 0.0124],
[-0.6436, 0.0587, -0.6520, -0.0471, 0.1667, 0.6042, 0.5752, -0.6296, -0.2976,
-0.3757]]).astype(np.float32).reshape([1, -1])

whh_l0 = np.array([[0.3358, 0.2790],
[-0.5355, 0.0989],
[-0.1402, 0.5120],
[0.1335, 0.1653],
[0.3533, -0.3531],
[0.4166, -0.4420],
[-0.5454, -0.1720],
[0.0041, -0.0799]]).astype(np.float32).reshape([1, -1])

bih_l0 = np.array([0.5518, 0.1083, 0.4829, 0.0607, -0.1770, -0.6944, 0.3059, 0.5354]).astype(
np.float32).reshape([1, -1])
bhh_l0 = np.array([0.5025, -0.1261, -0.5405, 0.3220, -0.3441, 0.6488, -0.0284, -0.2334]).astype(
np.float32).reshape([1, -1])

w0_np = np.concatenate(
(wih_l0, whh_l0, bih_l0 + bhh_l0),
axis=1).reshape([-1, 1, 1])
self.w0 = Parameter(initializer(Tensor(w0_np), w0_np.shape), name='w0')
self.lstm = P.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
has_bias=has_bias, bidirectional=bidirectional, dropout=dropout)
self.hlist = []
self.clist = []
self.hlist.append(Parameter(initializer(
Tensor(
np.array([0.1, 0.1, 0.1, 0.1]).reshape((num_directions, batch_size, hidden_size)).astype(
np.float32)),
[num_directions, batch_size, hidden_size]), name='h'))
self.clist.append(Parameter(initializer(
Tensor(
np.array([0.2, 0.2, 0.2, 0.2]).reshape((num_directions, batch_size, hidden_size)).astype(
np.float32)),
[num_directions, batch_size, hidden_size]), name='c'))
self.h = ParameterTuple(tuple(self.hlist))
self.c = ParameterTuple(tuple(self.clist))
wih = np.array([[3.4021e-01, -4.6622e-01, 4.5117e-01],
[-6.4257e-02, -2.4807e-01, 1.3550e-02], # i
[-3.2140e-01, 5.5578e-01, 6.3589e-01],
[1.6547e-01, -7.9030e-02, -2.0045e-01],
[-6.9863e-01, 5.9773e-01, -3.9062e-01],
[-3.0253e-01, -1.9464e-01, 7.0591e-01],
[-4.0835e-01, 3.6751e-01, 4.7989e-01],
[-5.6894e-01, -5.0359e-01, 4.7491e-01]]).astype(np.float32).reshape([1, -1])
whh = np.array([[-0.4820, -0.2350],
[-0.1195, 0.0519],
[0.2162, -0.1178],
[0.6237, 0.0711],
[0.4511, -0.3961],
[-0.5962, 0.0906],
[0.1867, -0.1225],
[0.1831, 0.0850]]).astype(np.float32).reshape([1, -1])
bih = np.zeros((1, 8)).astype(np.float32)
w_np = np.concatenate((wih, whh, bih), axis=1).reshape([-1, 1, 1])
self.w = Parameter(initializer(Tensor(w_np), w_np.shape), name='weight0')
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
has_bias=has_bias, bidirectional=bidirectional, dropout=dropout)
self.lstm.weight = ParameterTuple(tuple([self.w]))

@ms_function
def construct(self):
return self.lstm(self.x, self.h0, self.c0, self.w0)[0]
return self.lstm(self.x, (self.h, self.c))[0]


@pytest.mark.level0
@@ -299,7 +275,7 @@ class Net(nn.Cell):
def test_grad():
seq_len = 5
batch_size = 2
input_size = 10
input_size = 3
hidden_size = 2
num_layers = 1
has_bias = True
@@ -329,7 +305,6 @@ def test_grad():
print(dcx)
print(dw)

# test_multi_layer_bilstm()
# test_lstm()
# tf_lstm_test()
# test_grad()
test_multi_layer_bilstm()
test_lstm()
test_grad()