add validation check for sequence_mask

This commit is contained in:
TFBunny 2021-02-24 15:06:40 -05:00
parent c8657a6723
commit 09c5624cc4
1 changed files with 7 additions and 2 deletions

View File

@ -108,8 +108,13 @@ def repeat_elements(x, rep, axis=0):
@constexpr
def _check_sequence_mask_input_len(input_shape):
if not input_shape:
raise ValueError(f"sequence_mask input lengths_shape should be > 0. "
f"current lengths_shape is {input_shape}.")
raise ValueError(f"Sequence_mask lengths_shape should be > 0. "
f"Current lengths_shape is {input_shape}.")
# broadcast only supports 7d shape
shape_size = len(input_shape)
if shape_size >= 7:
raise ValueError(f"Sequence_mask lengths_shape's size only support a value less than 7. "
f"Current lengths_shape is {shape_size}d.")
def sequence_mask(lengths, maxlen=None):