!45687 Remove functional interface of maskedscatter

Merge pull request !45687 from panzhihui/maskedscatter_ops
This commit is contained in:
i-robot 2022-11-18 08:13:26 +00:00 committed by Gitee
commit be06b3cf58
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 20 additions and 47 deletions

View File

@ -46,6 +46,7 @@ from mindspore.ops.operations.array_ops import AffineGrid
from mindspore.ops.operations.array_ops import Im2Col
from mindspore.ops.operations.array_ops import Col2Im
from mindspore.ops.operations.array_ops import StridedSliceV2
from mindspore.ops.operations.array_ops import MaskedScatter
from mindspore.ops.operations._grad_ops import StridedSliceV2Grad
from mindspore.ops.operations.random_ops import LogNormalReverse
from mindspore.ops.operations import _inner_ops as inner
@ -124,11 +125,11 @@ def get_bprop_masked_select(self):
return bprop
@bprop_getters.register(P.MaskedScatter)
@bprop_getters.register(MaskedScatter)
def get_bprop_masked_scatter(self):
"""Generate bprop for MaskedScatter"""
sort_ = P.Sort(descending=True)
masked_scatter = P.MaskedScatter()
masked_scatter = MaskedScatter()
masked_fill = P.MaskedFill()
masked_select = P.MaskedSelect()
size = P.Size()

View File

@ -42,7 +42,6 @@ from .print_tensor import _print_aicpu
from .topk import _top_k_aicpu
from .log1p import _log1p_aicpu
from .asin import _asin_aicpu
from .masked_scatter import _masked_scatter_aicpu
from .is_finite import _is_finite_aicpu
from .is_inf import _is_inf_aicpu
from .is_nan import _is_nan_aicpu

View File

@ -102,7 +102,6 @@ from .array_func import (
matrix_diag_part,
matrix_set_diag,
diag,
masked_scatter,
masked_select,
meshgrid,
affine_grid,

View File

@ -25,7 +25,6 @@ from mindspore.ops.operations.array_ops import (
UniqueConsecutive,
SearchSorted,
NonZero,
MaskedScatter,
MatrixDiagV3,
MatrixDiagPartV3,
MatrixSetDiagV3,
@ -4133,43 +4132,6 @@ def tuple_to_array(input_x):
return tuple_to_array_(input_x)
def masked_scatter(x, mask, updates):
"""
Updates the value in the input with the updates value according to the mask.
The shapes of `mask` and `x` must be the same or broadcastable.
Args:
x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
mask (Tensor[bool]): A bool tensor with a shape broadcastable to x.
updates (Tensor): A tensor with the same data type as x. The
number of elements must be greater than or equal to the number of True's in `mask`.
Outputs:
y (Tensor), with the same type and shape as x.
Raises:
TypeError: If `x`, `mask` or `updates` is not a Tensor.
TypeError: If data type of `x` is not be supported.
TypeError: If dtype of `mask` is not bool.
TypeError: If the dim of `x` less than the dim of `mask`.
ValueError: If `mask` can not be broadcastable to `x`.
ValueError: If the number of elements in `updates` is less than the number required for the updates.
Supported Platforms:
``CPU``
Examples:
>>> x= Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
>>> mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
>>> updates = Tensor(np.array([5., 6., 7.]), mindspore.float32)
>>> output = ops.MaskedScatter()(input_X, mask, updates)
>>> print(output)
[5. 6. 3. 7.]
"""
masked_scatter_ = MaskedScatter()
return masked_scatter_(x, mask, updates)
def masked_select(x, mask):
"""
Returns a new 1-D Tensor which indexes the `x` tensor according to the boolean `mask`.
@ -5105,7 +5067,6 @@ __all__ = [
'gather_nd',
'one_hot',
'masked_fill',
'masked_scatter',
'masked_select',
'narrow',
'scatter_add',

View File

@ -35,7 +35,7 @@ from .array_ops import (ArgMaxWithValue, ArgMinWithValue, Argmax, Argmin, BatchT
Eye, Fill, Gather, GatherD, GatherNd, GatherV2, Identity, Im2Col, InvertPermutation,
LowerBound, Lstsq, MaskedFill, MaskedSelect, Meshgrid, Mvlgamma, Ones, OnesLike,
Pack, Padding, ParallelConcat, PopulationCount, Range, Rank, Reshape, ResizeNearestNeighbor,
ReverseSequence, ReverseV2, Rint, ScalarToTensor, ScatterAdd, MaskedScatter,
ReverseSequence, ReverseV2, Rint, ScalarToTensor, ScatterAdd,
ScatterDiv, ScatterMax, ScatterMin, ScatterMul, ScatterNd, ScatterNdAdd, ScatterNdDiv,
ScatterNdMax, ScatterNdMin, ScatterNdSub, ScatterNdUpdate, ScatterNonAliasingAdd, ScatterSub,
ScatterUpdate, SearchSorted, Select, Shape, Size, Slice, Sort, SpaceToBatch, SpaceToBatchND,
@ -151,7 +151,6 @@ __all__ = [
'BatchMatMul',
'Mul',
'MaskedFill',
'MaskedScatter',
'MaskedSelect',
'Meshgrid',
'MultiMarginLoss',

View File

@ -6030,7 +6030,22 @@ class MaskedScatter(Primitive):
Updates the value in the input with the updates value according to the mask.
The shapes of `mask` and `x` must be the same or broadcastable.
Refer to :func:`mindspore.ops.masked_scatter' for more details.
Inputs:
- **x** (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
- **mask** (Tensor[bool]): A bool tensor with a shape broadcastable to x.
- **updates** (Tensor): A tensor with the same data type as x. The
number of elements must be greater than or equal to the number of True's in `mask`.
Outputs:
Tensor, with the same type and shape as x.
Raises:
TypeError: If `x`, `mask` or `updates` is not a Tensor.
TypeError: If data type of `x` is not be supported.
TypeError: If dtype of `mask` is not bool.
TypeError: If the dim of `x` less than the dim of `mask`.
ValueError: If `mask` can not be broadcastable to `x`.
ValueError: If the number of elements in `updates` is less than the number required for the updates.
Supported Platforms:
``CPU``
@ -6048,7 +6063,6 @@ class MaskedScatter(Primitive):
def __init__(self):
"""Initialize MaskedScatter"""
self.init_prim_io_names(inputs=['x', 'mask', 'updates'], outputs=['y'])
self.add_prim_attr("cust_aicpu", "MaskedScatter")
class MaskedSelect(PrimitiveWithCheck):