From e895365c3119e30416b27b482a5df486ed9b757d Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Fri, 16 Apr 2021 19:00:14 +0800 Subject: [PATCH] Fix check string attr of DynamicRnn. --- mindspore/ops/operations/nn_ops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 52f0992504e..a1b1475c48b 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -7261,8 +7261,11 @@ class DynamicRNN(PrimitiveWithInfer): self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name) self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name) + validator.check_value_type("cell_type", cell_type, [str], self.name) self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name) + validator.check_value_type("direction", direction, [str], self.name) self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) + validator.check_value_type("activation", activation, [str], self.name) self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):