forked from mindspore-Ecosystem/mindspore
add validation check for sequence_mask
This commit is contained in:
parent
c8657a6723
commit
09c5624cc4
|
@ -108,8 +108,13 @@ def repeat_elements(x, rep, axis=0):
|
|||
@constexpr
|
||||
def _check_sequence_mask_input_len(input_shape):
|
||||
if not input_shape:
|
||||
raise ValueError(f"sequence_mask input lengths_shape should be > 0. "
|
||||
f"current lengths_shape is {input_shape}.")
|
||||
raise ValueError(f"Sequence_mask lengths_shape should be > 0. "
|
||||
f"Current lengths_shape is {input_shape}.")
|
||||
# broadcast only supports 7d shape
|
||||
shape_size = len(input_shape)
|
||||
if shape_size >= 7:
|
||||
raise ValueError(f"Sequence_mask lengths_shape's size only support a value less than 7. "
|
||||
f"Current lengths_shape is {shape_size}d.")
|
||||
|
||||
|
||||
def sequence_mask(lengths, maxlen=None):
|
||||
|
|
Loading…
Reference in New Issue