diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc index c8e1c0403b5..4fefc2db98d 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc @@ -22,6 +22,10 @@ namespace mindspore { namespace kernel { void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { +#ifdef PLATFORM_86 + _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); + _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); +#endif MS_EXCEPTION_IF_NULL(kernel_node); using tag = dnnl::memory::format_tag; using dim = dnnl::memory::dims; diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h index f864009d5f2..d42ff803f07 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h @@ -16,6 +16,12 @@ #ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ +#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64) +#define PLATFORM_86 +#endif +#ifdef PLATFORM_86 +#include +#endif #include #include #include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py index e3fba5ff3fb..71c29208501 100755 --- a/mindspore/nn/layer/lstm.py +++ b/mindspore/nn/layer/lstm.py @@ -23,6 +23,7 @@ from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.tensor import Tensor from mindspore.nn.cell import Cell from mindspore.ops import operations as P +from ..._checkparam import Rel __all__ = ['LSTM', 'LSTMCell'] @@ -123,6 +124,8 @@ class LSTM(Cell): self.num_layers = num_layers self.has_bias = has_bias self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name) + self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.cls_name) + self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.cls_name) self.dropout = float(dropout) self.bidirectional = bidirectional if self.batch_first: