forked from mindspore-Ecosystem/mindspore
!8954 Add SequenceMask operator.
From: @liangzhibo Reviewed-by: @chenfei52,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
d9b4b5c750
|
@ -22,7 +22,7 @@ A collection of operators to build neural networks or to compute functions.
|
|||
from .image_ops import (CropAndResize)
|
||||
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||
Diag, DiagPart, DType, ExpandDims, Eye,
|
||||
Fill, Ones, Zeros, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
|
||||
Fill, Ones, Zeros, SequenceMask, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
|
||||
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid,
|
||||
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
|
||||
|
@ -182,6 +182,7 @@ __all__ = [
|
|||
'Fill',
|
||||
'Ones',
|
||||
'Zeros',
|
||||
'SequenceMask',
|
||||
'OnesLike',
|
||||
'ZerosLike',
|
||||
'Select',
|
||||
|
|
|
@ -1149,7 +1149,7 @@ class Ones(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Fill"""
|
||||
"""Initialize Ones"""
|
||||
|
||||
def __infer__(self, dims, dtype):
|
||||
if isinstance(dims['value'], int):
|
||||
|
@ -1203,7 +1203,7 @@ class Zeros(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Fill"""
|
||||
"""Initialize Zeros"""
|
||||
|
||||
def __infer__(self, dims, dtype):
|
||||
if isinstance(dims['value'], int):
|
||||
|
@ -1227,6 +1227,65 @@ class Zeros(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class SequenceMask(PrimitiveWithInfer):
|
||||
r"""
|
||||
Generates sequence mask according to input lengths.
|
||||
|
||||
Creates a mask tensor which retains the first N elements in tensor by setting the values
|
||||
to be True or one. The rest values in mask are set to False or zero.
|
||||
|
||||
Args:
|
||||
max_length (int): Nonnegative integer, size of the last dimension in mask. Default: None.
|
||||
|
||||
Inputs:
|
||||
- **lengths** (Union[tuple[int], list[int]]) - Defines the first N elements that are retained.
|
||||
Only constant value is allowed.
|
||||
- **dtype** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
|
||||
|
||||
Outputs:
|
||||
Tensor.
|
||||
If max_length is set, the shape of the output is (lengths.shape, max_length).
|
||||
If max_length is not set and the biggest value in lengths is x. Then, the shape of
|
||||
the output is (lengths.shape, x).
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> sequence_mask = P.SequenceMask()
|
||||
>>> mask = sequence_mask([2, 2, 4], mindspore.int32)
|
||||
>>> print(mask)
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[1, 1, 1, 1]]
|
||||
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize SequenceMask"""
|
||||
|
||||
def __infer__(self, lengths, dtype, max_length=None):
|
||||
validator.check_value_type("shape", lengths['value'], [tuple, list], self.name)
|
||||
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
|
||||
mstype.uint8, mstype.uint32, mstype.uint64,
|
||||
mstype.float16, mstype.float32, mstype.float64]
|
||||
validator.check_subclass("dtype", dtype['value'], valid_types, self.name)
|
||||
nptype = mstype.dtype_to_nptype(dtype['value'])
|
||||
if max_length is None:
|
||||
max_length = np.max(lengths['value'])
|
||||
else:
|
||||
validator.check_non_negative_int(max_length['value'])
|
||||
max_length = max_length['value']
|
||||
row_vector = np.arange(0, max_length)
|
||||
col_matrix = np.expand_dims(lengths['value'], -1)
|
||||
result = (row_vector < col_matrix).astype(nptype)
|
||||
out = {
|
||||
'value': Tensor(result),
|
||||
'shape': result.shape,
|
||||
'dtype': dtype['value']
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
class OnesLike(PrimitiveWithInfer):
|
||||
"""
|
||||
Creates a new tensor. The values of all elements are 1.
|
||||
|
|
|
@ -42,6 +42,28 @@ def test_expand_dims():
|
|||
assert output.asnumpy().shape == (1, 2, 2)
|
||||
|
||||
|
||||
def test_sequence_mask():
|
||||
list_ = [2, 2, 4]
|
||||
sequence_mask = P.SequenceMask()
|
||||
mask1 = sequence_mask(list_, mstype.int32)
|
||||
mask2 = sequence_mask(list_, mstype.int32, 5)
|
||||
assert mask1.shape == (3, 4)
|
||||
assert mask1.dtype == mstype.int32
|
||||
assert mask2.shape == (3, 5)
|
||||
assert mask2.dtype == mstype.int32
|
||||
|
||||
|
||||
def test_sequence_mask_1():
|
||||
list_ = [[2, 2, 4], [3, 4, 4]]
|
||||
sequence_mask = P.SequenceMask()
|
||||
mask1 = sequence_mask(list_, mstype.bool_)
|
||||
mask2 = sequence_mask(list_, mstype.bool_, 5)
|
||||
assert mask1.shape == (2, 3, 4)
|
||||
assert mask1.dtype == mstype.bool_
|
||||
assert mask2.shape == (2, 3, 5)
|
||||
assert mask2.dtype == mstype.bool_
|
||||
|
||||
|
||||
def test_cast():
|
||||
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
||||
input_x = Tensor(input_np)
|
||||
|
|
Loading…
Reference in New Issue