forked from mindspore-Ecosystem/mindspore
!10312 SequenceMask move API to composite
From: @peilin-wang Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @liangchenghui
This commit is contained in:
commit
2b64b7f295
|
@ -28,7 +28,7 @@ from .multitype_ops.ones_like_impl import ones_like
|
|||
from .multitype_ops.zeros_like_impl import zeros_like
|
||||
from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
|
||||
from .math_ops import count_nonzero, tensor_dot
|
||||
from .array_ops import repeat_elements
|
||||
from .array_ops import repeat_elements, sequence_mask
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -53,4 +53,5 @@ __all__ = [
|
|||
'clip_by_global_norm',
|
||||
'count_nonzero',
|
||||
'tensor_dot',
|
||||
'repeat_elements']
|
||||
'repeat_elements',
|
||||
'sequence_mask']
|
||||
|
|
|
@ -20,6 +20,7 @@ from mindspore._checkparam import Rel
|
|||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import functional as F
|
||||
from .. import operations as P
|
||||
from ..operations import _inner_ops as inner
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -103,3 +104,35 @@ def repeat_elements(x, rep, axis=0):
|
|||
x_rep = reshape_op(x_expand, x_reshape)
|
||||
|
||||
return x_rep
|
||||
|
||||
def sequence_mask(lengths, maxlen):
|
||||
"""
|
||||
Returns a mask tensor representing the first N positions of each cell.
|
||||
|
||||
If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type dtype and shape
|
||||
[d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])
|
||||
|
||||
Args:
|
||||
length (Tensor): Tensor to calculate the mask for. All values in this tensor must be
|
||||
less than `maxlen`. Must be type int32 or int64.
|
||||
|
||||
maxlen (int): size of the last dimension of returned tensor. Must be positive and same
|
||||
type as elements in `lengths`.
|
||||
|
||||
Outputs:
|
||||
One mask tensor of shape lengths.shape + (maxlen,).
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[1, 3], [2, 0]])
|
||||
>>> sequence_mask = P.SequenceMask()
|
||||
>>> output = sequence_mask(x, 3)
|
||||
>>> print(output)
|
||||
[[[True, False, False],
|
||||
[True, True, True]],
|
||||
[[True, True, False],
|
||||
[False, False, False]]]
|
||||
"""
|
||||
return inner.SequenceMask()(lengths, maxlen)
|
||||
|
|
|
@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
|
||||
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
|
||||
Unique, GatherD, Identity, SequenceMask)
|
||||
Unique, GatherD, Identity)
|
||||
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
|
||||
_MirrorOperator, ReduceOp, _VirtualDataset,
|
||||
_VirtualDiv, _GetTensorSlice,
|
||||
|
@ -400,7 +400,6 @@ __all__ = [
|
|||
"Pull",
|
||||
"ReLUV2",
|
||||
"SparseToDense",
|
||||
"SequenceMask",
|
||||
]
|
||||
|
||||
__all__.sort()
|
||||
|
|
|
@ -679,3 +679,47 @@ class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
|
|||
|
||||
def infer_value(self, input_tensor):
|
||||
return input_tensor
|
||||
|
||||
|
||||
class SequenceMask(PrimitiveWithCheck):
|
||||
"""
|
||||
Returns a mask tensor representing the first N positions of each cell.
|
||||
|
||||
If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type dtype and shape
|
||||
[d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])
|
||||
|
||||
Inputs:
|
||||
- **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor must be
|
||||
less than `maxlen`. Must be type int32 or int64.
|
||||
|
||||
- **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same
|
||||
type as elements in `lengths`.
|
||||
|
||||
Outputs:
|
||||
One mask tensor of shape lengths.shape + (maxlen,).
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[1, 3], [2, 0]])
|
||||
>>> sequence_mask = P.SequenceMask()
|
||||
>>> output = sequence_mask(x, 3)
|
||||
>>> print(output)
|
||||
[[[True, False, False],
|
||||
[True, True, True]],
|
||||
[[True, True, False],
|
||||
[False, False, False]]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"])
|
||||
|
||||
def check_shape(self, lengths_shape, maxlen_shape):
|
||||
validator.check("lengths_shape", len(lengths_shape), "", 0, Rel.GT, self.name)
|
||||
validator.check("maxlen_shape", len(maxlen_shape), "", 0, Rel.EQ, self.name)
|
||||
|
||||
def check_dtype(self, lengths_dtype, maxlen_dtype):
|
||||
validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name)
|
||||
validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)
|
||||
|
|
|
@ -4720,47 +4720,3 @@ class Identity(PrimitiveWithInfer):
|
|||
'dtype': x['dtype'],
|
||||
'value': None}
|
||||
return out
|
||||
|
||||
|
||||
class SequenceMask(PrimitiveWithCheck):
|
||||
"""
|
||||
Returns a mask tensor representing the first N positions of each cell.
|
||||
|
||||
If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type dtype and shape
|
||||
[d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])
|
||||
|
||||
Inputs:
|
||||
- **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor must be
|
||||
less than `maxlen`. Must be type int32 or int64.
|
||||
|
||||
- **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same
|
||||
tyupe as elements in `lengths`.
|
||||
|
||||
Outputs:
|
||||
One mask tensor of shape lengths.shape + (maxlen,).
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[1, 3], [2, 0]])
|
||||
>>> sequence_mask = P.SequenceMask()
|
||||
>>> output = sequence_mask(x, 3)
|
||||
>>> print(output)
|
||||
[[[True, False, False],
|
||||
[True, True, True]],
|
||||
[[True, True, False],
|
||||
[False, False, False]]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"])
|
||||
|
||||
def check_shape(self, lengths_shape, maxlen_shape):
|
||||
validator.check("lengths_shape", len(lengths_shape), "", 0, Rel.GT, self.name)
|
||||
validator.check("maxlen_shape", len(maxlen_shape), "", 0, Rel.EQ, self.name)
|
||||
|
||||
def check_dtype(self, lengths_dtype, maxlen_dtype):
|
||||
validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name)
|
||||
validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)
|
||||
|
|
|
@ -2,14 +2,13 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
|
||||
def sequence_mask(x, maxlen):
|
||||
sequence_mask_op = P.SequenceMask()
|
||||
return sequence_mask_op(Tensor(x.astype(np.int32)), maxlen)
|
||||
return C.sequence_mask(Tensor(x.astype(np.int32)), maxlen)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -87,11 +86,10 @@ def test_sequence_mask_dynamic():
|
|||
super(SequenceMaskDynamicNet, self).__init__()
|
||||
self.maxlen = maxlen
|
||||
self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
|
||||
self.sequence_mask = P.SequenceMask()
|
||||
|
||||
def construct(self, x):
|
||||
converted_to_dynamic_shape = self.convert_to_dynamic_shape(x)
|
||||
return self.sequence_mask(converted_to_dynamic_shape, self.maxlen)
|
||||
return C.sequence_mask(converted_to_dynamic_shape, self.maxlen)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
|
Loading…
Reference in New Issue