!45687 Remove functional interface of maskedscatter
Merge pull request !45687 from panzhihui/maskedscatter_ops
This commit is contained in:
commit
be06b3cf58
|
@ -46,6 +46,7 @@ from mindspore.ops.operations.array_ops import AffineGrid
|
|||
from mindspore.ops.operations.array_ops import Im2Col
|
||||
from mindspore.ops.operations.array_ops import Col2Im
|
||||
from mindspore.ops.operations.array_ops import StridedSliceV2
|
||||
from mindspore.ops.operations.array_ops import MaskedScatter
|
||||
from mindspore.ops.operations._grad_ops import StridedSliceV2Grad
|
||||
from mindspore.ops.operations.random_ops import LogNormalReverse
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
|
@ -124,11 +125,11 @@ def get_bprop_masked_select(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.MaskedScatter)
|
||||
@bprop_getters.register(MaskedScatter)
|
||||
def get_bprop_masked_scatter(self):
|
||||
"""Generate bprop for MaskedScatter"""
|
||||
sort_ = P.Sort(descending=True)
|
||||
masked_scatter = P.MaskedScatter()
|
||||
masked_scatter = MaskedScatter()
|
||||
masked_fill = P.MaskedFill()
|
||||
masked_select = P.MaskedSelect()
|
||||
size = P.Size()
|
||||
|
|
|
@ -42,7 +42,6 @@ from .print_tensor import _print_aicpu
|
|||
from .topk import _top_k_aicpu
|
||||
from .log1p import _log1p_aicpu
|
||||
from .asin import _asin_aicpu
|
||||
from .masked_scatter import _masked_scatter_aicpu
|
||||
from .is_finite import _is_finite_aicpu
|
||||
from .is_inf import _is_inf_aicpu
|
||||
from .is_nan import _is_nan_aicpu
|
||||
|
|
|
@ -102,7 +102,6 @@ from .array_func import (
|
|||
matrix_diag_part,
|
||||
matrix_set_diag,
|
||||
diag,
|
||||
masked_scatter,
|
||||
masked_select,
|
||||
meshgrid,
|
||||
affine_grid,
|
||||
|
|
|
@ -25,7 +25,6 @@ from mindspore.ops.operations.array_ops import (
|
|||
UniqueConsecutive,
|
||||
SearchSorted,
|
||||
NonZero,
|
||||
MaskedScatter,
|
||||
MatrixDiagV3,
|
||||
MatrixDiagPartV3,
|
||||
MatrixSetDiagV3,
|
||||
|
@ -4133,43 +4132,6 @@ def tuple_to_array(input_x):
|
|||
return tuple_to_array_(input_x)
|
||||
|
||||
|
||||
def masked_scatter(x, mask, updates):
|
||||
"""
|
||||
Updates the value in the input with the updates value according to the mask.
|
||||
The shapes of `mask` and `x` must be the same or broadcastable.
|
||||
|
||||
Args:
|
||||
x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
mask (Tensor[bool]): A bool tensor with a shape broadcastable to x.
|
||||
updates (Tensor): A tensor with the same data type as x. The
|
||||
number of elements must be greater than or equal to the number of True's in `mask`.
|
||||
|
||||
Outputs:
|
||||
y (Tensor), with the same type and shape as x.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x`, `mask` or `updates` is not a Tensor.
|
||||
TypeError: If data type of `x` is not be supported.
|
||||
TypeError: If dtype of `mask` is not bool.
|
||||
TypeError: If the dim of `x` less than the dim of `mask`.
|
||||
ValueError: If `mask` can not be broadcastable to `x`.
|
||||
ValueError: If the number of elements in `updates` is less than the number required for the updates.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x= Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
|
||||
>>> mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
|
||||
>>> updates = Tensor(np.array([5., 6., 7.]), mindspore.float32)
|
||||
>>> output = ops.MaskedScatter()(input_X, mask, updates)
|
||||
>>> print(output)
|
||||
[5. 6. 3. 7.]
|
||||
"""
|
||||
masked_scatter_ = MaskedScatter()
|
||||
return masked_scatter_(x, mask, updates)
|
||||
|
||||
|
||||
def masked_select(x, mask):
|
||||
"""
|
||||
Returns a new 1-D Tensor which indexes the `x` tensor according to the boolean `mask`.
|
||||
|
@ -5105,7 +5067,6 @@ __all__ = [
|
|||
'gather_nd',
|
||||
'one_hot',
|
||||
'masked_fill',
|
||||
'masked_scatter',
|
||||
'masked_select',
|
||||
'narrow',
|
||||
'scatter_add',
|
||||
|
|
|
@ -35,7 +35,7 @@ from .array_ops import (ArgMaxWithValue, ArgMinWithValue, Argmax, Argmin, BatchT
|
|||
Eye, Fill, Gather, GatherD, GatherNd, GatherV2, Identity, Im2Col, InvertPermutation,
|
||||
LowerBound, Lstsq, MaskedFill, MaskedSelect, Meshgrid, Mvlgamma, Ones, OnesLike,
|
||||
Pack, Padding, ParallelConcat, PopulationCount, Range, Rank, Reshape, ResizeNearestNeighbor,
|
||||
ReverseSequence, ReverseV2, Rint, ScalarToTensor, ScatterAdd, MaskedScatter,
|
||||
ReverseSequence, ReverseV2, Rint, ScalarToTensor, ScatterAdd,
|
||||
ScatterDiv, ScatterMax, ScatterMin, ScatterMul, ScatterNd, ScatterNdAdd, ScatterNdDiv,
|
||||
ScatterNdMax, ScatterNdMin, ScatterNdSub, ScatterNdUpdate, ScatterNonAliasingAdd, ScatterSub,
|
||||
ScatterUpdate, SearchSorted, Select, Shape, Size, Slice, Sort, SpaceToBatch, SpaceToBatchND,
|
||||
|
@ -151,7 +151,6 @@ __all__ = [
|
|||
'BatchMatMul',
|
||||
'Mul',
|
||||
'MaskedFill',
|
||||
'MaskedScatter',
|
||||
'MaskedSelect',
|
||||
'Meshgrid',
|
||||
'MultiMarginLoss',
|
||||
|
|
|
@ -6030,7 +6030,22 @@ class MaskedScatter(Primitive):
|
|||
Updates the value in the input with the updates value according to the mask.
|
||||
The shapes of `mask` and `x` must be the same or broadcastable.
|
||||
|
||||
Refer to :func:`mindspore.ops.masked_scatter' for more details.
|
||||
Inputs:
|
||||
- **x** (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
- **mask** (Tensor[bool]): A bool tensor with a shape broadcastable to x.
|
||||
- **updates** (Tensor): A tensor with the same data type as x. The
|
||||
number of elements must be greater than or equal to the number of True's in `mask`.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as x.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x`, `mask` or `updates` is not a Tensor.
|
||||
TypeError: If data type of `x` is not be supported.
|
||||
TypeError: If dtype of `mask` is not bool.
|
||||
TypeError: If the dim of `x` less than the dim of `mask`.
|
||||
ValueError: If `mask` can not be broadcastable to `x`.
|
||||
ValueError: If the number of elements in `updates` is less than the number required for the updates.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
@ -6048,7 +6063,6 @@ class MaskedScatter(Primitive):
|
|||
def __init__(self):
|
||||
"""Initialize MaskedScatter"""
|
||||
self.init_prim_io_names(inputs=['x', 'mask', 'updates'], outputs=['y'])
|
||||
self.add_prim_attr("cust_aicpu", "MaskedScatter")
|
||||
|
||||
|
||||
class MaskedSelect(PrimitiveWithCheck):
|
||||
|
|
Loading…
Reference in New Issue