From 09c5624cc4e85adb14330f523f211a5f78a44915 Mon Sep 17 00:00:00 2001 From: TFBunny Date: Wed, 24 Feb 2021 15:06:40 -0500 Subject: [PATCH] add validation check for sequence_mask --- mindspore/ops/composite/array_ops.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mindspore/ops/composite/array_ops.py b/mindspore/ops/composite/array_ops.py index ebe5aa80135..277fb3186f1 100644 --- a/mindspore/ops/composite/array_ops.py +++ b/mindspore/ops/composite/array_ops.py @@ -108,8 +108,13 @@ def repeat_elements(x, rep, axis=0): @constexpr def _check_sequence_mask_input_len(input_shape): if not input_shape: - raise ValueError(f"sequence_mask input lengths_shape should be > 0. " - f"current lengths_shape is {input_shape}.") + raise ValueError(f"Sequence_mask lengths_shape should be > 0. " + f"Current lengths_shape is {input_shape}.") + # broadcast only supports 7d shape + shape_size = len(input_shape) + if shape_size >= 7: + raise ValueError(f"Sequence_mask lengths_shape's size only support a value less than 7. " + f"Current lengths_shape is {shape_size}d.") def sequence_mask(lengths, maxlen=None):