forked from mindspore-Ecosystem/mindspore
fix lstm bug when hidden_size is zero
This commit is contained in:
parent
30b19165a2
commit
7f7c006acf
|
@ -22,6 +22,10 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
|
#ifdef PLATFORM_86
|
||||||
|
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
|
||||||
|
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
|
||||||
|
#endif
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
using tag = dnnl::memory::format_tag;
|
using tag = dnnl::memory::format_tag;
|
||||||
using dim = dnnl::memory::dims;
|
using dim = dnnl::memory::dims;
|
||||||
|
|
|
@ -16,6 +16,12 @@
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_
|
#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_
|
||||||
|
#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64)
|
||||||
|
#define PLATFORM_86
|
||||||
|
#endif
|
||||||
|
#ifdef PLATFORM_86
|
||||||
|
#include <pmmintrin.h>
|
||||||
|
#endif
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h"
|
#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h"
|
||||||
|
|
|
@ -23,6 +23,7 @@ from mindspore.common.parameter import Parameter, ParameterTuple
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.nn.cell import Cell
|
from mindspore.nn.cell import Cell
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from ..._checkparam import Rel
|
||||||
|
|
||||||
__all__ = ['LSTM', 'LSTMCell']
|
__all__ = ['LSTM', 'LSTMCell']
|
||||||
|
|
||||||
|
@ -123,6 +124,8 @@ class LSTM(Cell):
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.has_bias = has_bias
|
self.has_bias = has_bias
|
||||||
self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
|
self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
|
||||||
|
self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.cls_name)
|
||||||
|
self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.cls_name)
|
||||||
self.dropout = float(dropout)
|
self.dropout = float(dropout)
|
||||||
self.bidirectional = bidirectional
|
self.bidirectional = bidirectional
|
||||||
if self.batch_first:
|
if self.batch_first:
|
||||||
|
|
Loading…
Reference in New Issue