!36606 [assistant][ops][aicpu][五期][I4XJI8] Add MaskedScatter operator

Merge pull request !36606 from 桂宁馨/MaskedScatter
This commit is contained in:
i-robot 2022-11-17 11:09:03 +00:00 committed by Gitee
commit 2d62478890
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
13 changed files with 512 additions and 1 deletions

View File

@ -0,0 +1,176 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/masked_scatter_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include <functional>
#include "mindspore/core/ops/masked_scatter.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kMaskedScatterInputsNum = 3;
constexpr size_t kMaskedScatterOutputsNum = 1;
} // namespace
bool MaskedScatterCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "MaskedScatter does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
int MaskedScatterCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
int ret = KRET_OK;
if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs)) != 0) {
return ret;
}
std::vector<int64_t> x_shape = inputs.at(kIndex0)->GetShapeVector();
std::vector<int64_t> mask_shape = inputs.at(kIndex1)->GetShapeVector();
std::vector<int64_t> updates_shape = inputs.at(kIndex2)->GetShapeVector();
std::vector<int64_t> output_shape = outputs.at(kIndex0)->GetShapeVector();
std::transform(x_shape.begin(), x_shape.end(), std::back_inserter(x_shape_), LongToSize);
std::transform(mask_shape.begin(), mask_shape.end(), std::back_inserter(mask_shape_), LongToSize);
std::transform(updates_shape.begin(), updates_shape.end(), std::back_inserter(updates_shape_), LongToSize);
std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(output_shape_), LongToSize);
x_numElements_ = std::accumulate(x_shape_.begin(), x_shape_.end(), 1, std::multiplies<size_t>());
updates_numElements_ = std::accumulate(updates_shape_.begin(), updates_shape_.end(), 1, std::multiplies<size_t>());
need_broadcast_ = (x_shape_ == mask_shape_) ? false : true;
size_t mask_dims = mask_shape.size();
std::vector<int64_t> x_shape_reverse = x_shape_;
std::vector<int64_t> mask_shape_reverse = mask_shape_;
std::reverse(x_shape_reverse.begin(), x_shape_reverse.end());
std::reverse(mask_shape_reverse.begin(), mask_shape_reverse.end());
for (size_t i = 0; i < mask_dims; i++) {
if (mask_shape_reverse[i] != x_shape_reverse[i] && mask_shape_reverse[i] != 1) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the shape of 'mask': " << ShapeVectorToStr(mask_shape)
<< " can not be broadcast to the shape of 'x': " << ShapeVectorToStr(x_shape) << ".";
}
}
return ret;
}
template <typename T>
bool MaskedScatterCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMaskedScatterInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMaskedScatterOutputsNum, kernel_name_);
auto x = reinterpret_cast<T *>(inputs[0]->addr);
auto mask = reinterpret_cast<bool *>(inputs[1]->addr);
auto updates = reinterpret_cast<T *>(inputs[2]->addr);
auto y = reinterpret_cast<T *>(outputs[0]->addr);
uint64_t j = 0;
if (!need_broadcast_) {
for (uint64_t i = 0; i < x_numElements_; i++) {
if (mask[i]) {
if (j >= updates_numElements_) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
<< "', number of elements of updates < number of ones in mask.";
}
y[i] = updates[j], j += 1;
} else {
y[i] = x[i];
}
}
} else {
BroadcastIterator iter(x_shape_, mask_shape_, output_shape_);
iter.SetPos(0);
for (uint64_t i = 0; i < x_numElements_; i++, iter.GenNextPos()) {
if (mask[iter.GetInputPosB()]) {
if (j >= updates_numElements_) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
<< "', number of elements of updates < number of ones in mask.";
}
y[iter.GetInputPosA()] = updates[j], j += 1;
} else {
y[i] = x[i];
}
}
}
return true;
}
std::vector<std::pair<KernelAttr, MaskedScatterCpuKernelMod::MaskedScatterFunc>> MaskedScatterCpuKernelMod::func_list_ =
{{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&MaskedScatterCpuKernelMod::LaunchKernel<float16>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&MaskedScatterCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&MaskedScatterCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
&MaskedScatterCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
&MaskedScatterCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
&MaskedScatterCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&MaskedScatterCpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&MaskedScatterCpuKernelMod::LaunchKernel<int64_t>}};
std::vector<KernelAttr> MaskedScatterCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MaskedScatterFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaskedScatter, MaskedScatterCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SCATTER_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SCATTER_CPU_KERNEL_H_
#include <vector>
#include <map>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MaskedScatterCpuKernelMod : public NativeCpuKernelMod {
public:
MaskedScatterCpuKernelMod() = default;
~MaskedScatterCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using MaskedScatterFunc = std::function<bool(MaskedScatterCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, MaskedScatterFunc>> func_list_;
MaskedScatterFunc kernel_func_;
std::vector<int64_t> x_shape_;
std::vector<int64_t> mask_shape_;
std::vector<int64_t> updates_shape_;
std::vector<int64_t> output_shape_;
uint64_t x_numElements_ = 1;
uint64_t updates_numElements_ = 1;
bool need_broadcast_{false};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SCATTER_CPU_KERNEL_H_

View File

@ -666,6 +666,7 @@ GVAR_DEF(PrimitivePtr, kPrimResizeLinear1D, std::make_shared<Primitive>("ResizeL
GVAR_DEF(PrimitivePtr, kPrimResizeLinear1DGrad, std::make_shared<Primitive>("ResizeLinear1DGrad"));
GVAR_DEF(PrimitivePtr, kPrimSort, std::make_shared<Primitive>("Sort"));
GVAR_DEF(PrimitivePtr, kPrimMaskedFill, std::make_shared<Primitive>("MaskedFill"));
GVAR_DEF(PrimitivePtr, kPrimMaskedScatter, std::make_shared<Primitive>("MaskedScatter"));
GVAR_DEF(PrimitivePtr, kPrimMaskedSelect, std::make_shared<Primitive>("MaskedSelect"));
GVAR_DEF(PrimitivePtr, kPrimMaskedSelectGrad, std::make_shared<Primitive>("MaskedSelectGrad"));
GVAR_DEF(PrimitivePtr, kPrimDiag, std::make_shared<Primitive>(kDiag));

View File

@ -0,0 +1,73 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include <string>
#include <set>
#include "ops/masked_scatter.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr MaskedScatterInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
const int64_t input_num = 3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
auto x_shape = x_shape_map[kShape];
auto mask_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
auto mask_shape = mask_shape_map[kShape];
CheckAndConvertUtils::CheckInteger("dim of input_x", x_shape.size(), kGreaterEqual, mask_shape.size(), op_name);
return std::make_shared<abstract::Shape>(x_shape);
}
TypePtr MaskedScatterInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = prim->name();
(void)CheckAndConvertUtils::CheckTensorTypeValid("mask", input_args[1]->BuildType(), {kBool}, op_name);
std::set<TypePtr> valid_types;
valid_types = {kFloat16, kFloat32, kFloat64, kUInt8, kInt8, kInt16, kInt32, kInt64};
auto x_type = input_args[kInputIndex0]->BuildType();
auto updates_type = input_args[kInputIndex2]->BuildType();
std::map<std::string, TypePtr> types;
(void)types.emplace("x", x_type);
(void)types.emplace("updates", updates_type);
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates", updates_type, valid_types, op_name);
return x_type;
}
} // namespace
MIND_API_OPERATOR_IMPL(MaskedScatter, BaseOperator);
AbstractBasePtr MaskedScatterInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kMaskedScaterInputsNum = 3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kMaskedScaterInputsNum, primitive->name());
auto infer_type = MaskedScatterInferType(primitive, input_args);
auto infer_shape = MaskedScatterInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MaskedScatter, prim::kPrimMaskedScatter, MaskedScatterInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MASKED_SCATTER_H_
#define MINDSPORE_CORE_OPS_MASKED_SCATTER_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMaskedScatter = "MaskedScatter";
/// \brief Returns a Tensor as same shape and type with the input tensor according to the boolean mask and updates.
/// Refer to Python API @ref mindspore.ops.MaskedScatter for more details.
class MIND_API MaskedScatter : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MaskedScatter);
/// \brief Constructor.
MaskedScatter() : BaseOperator(kNameMaskedScatter) { InitIOName({"x", "mask", "updates"}, {"y"}); }
};
AbstractBasePtr MaskedScatterInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimMaskedScatterPtr = std::shared_ptr<MaskedScatter>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MASKED_SCATTER_H_

View File

@ -124,6 +124,44 @@ def get_bprop_masked_select(self):
return bprop
@bprop_getters.register(P.MaskedScatter)
def get_bprop_masked_scatter(self):
"""Generate bprop for MaskedScatter"""
sort_ = P.Sort(descending=True)
masked_scatter = P.MaskedScatter()
masked_fill = P.MaskedFill()
masked_select = P.MaskedSelect()
size = P.Size()
zeros = P.Zeros()
concat = P.Concat(axis=0)
reshape = P.Reshape()
shape = P.Shape()
def bprop(x, mask, updates, out, dout):
dx = masked_fill(F.cast(dout, mstype.float32), mask, 0.0)
mask_selected = masked_select(F.cast(dout, mstype.float32), mask)
mask_broad = mask
if shape(mask) != shape(x):
broad_cast = P.BroadcastTo(shape(x))
mask_broad = broad_cast(mask)
mask_broad_vec = mask_broad.reshape(-1)
mask_sorted = F.cast(sort_(F.cast(mask_broad_vec, mstype.float32))[0], F.dtype(mask))
diff_num = size(updates) - size(mask_broad)
if diff_num > 0:
zeros_pad = zeros(diff_num, F.dtype(mask))
mask_sorted = concat((mask_sorted, zeros_pad))
zeros_tensor = zeros(size(updates), mstype.float32)
dupdates = masked_scatter(zeros_tensor, mask_sorted, mask_selected)
if shape(updates) != ():
dupdates = reshape(dupdates, shape(updates))
else:
zeros_tensor = zeros(shape(updates), mstype.float32)
dupdates = masked_scatter(zeros_tensor, mask, mask_selected)
return F.cast(dx, F.dtype(x)), zeros_like(mask), F.cast(dupdates, F.dtype(updates))
return bprop
@bprop_getters.register(Mvlgamma)
def get_bprop_mvlgamma(self):
"""Grad definition for Mvlgamma"""

View File

@ -42,6 +42,7 @@ from .print_tensor import _print_aicpu
from .topk import _top_k_aicpu
from .log1p import _log1p_aicpu
from .asin import _asin_aicpu
from .masked_scatter import _masked_scatter_aicpu
from .is_finite import _is_finite_aicpu
from .is_inf import _is_inf_aicpu
from .is_nan import _is_nan_aicpu

View File

@ -0,0 +1,39 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskedScatter op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
masked_scatter_op_info = AiCPURegOp("MaskedScatter") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.input(1, "mask", "required") \
.input(2, "updates", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.U8_Default, DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I8_Default, DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.BOOL_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F16_Default, DataType.BOOL_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.BOOL_Default, DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(masked_scatter_op_info)
def _masked_scatter_aicpu():
"""MaskedScatter AiCPU register"""
return

View File

@ -102,6 +102,7 @@ from .array_func import (
matrix_diag_part,
matrix_set_diag,
diag,
masked_scatter,
masked_select,
meshgrid,
affine_grid,

View File

@ -25,6 +25,7 @@ from mindspore.ops.operations.array_ops import (
UniqueConsecutive,
SearchSorted,
NonZero,
MaskedScatter,
MatrixDiagV3,
MatrixDiagPartV3,
MatrixSetDiagV3,
@ -4132,6 +4133,43 @@ def tuple_to_array(input_x):
return tuple_to_array_(input_x)
def masked_scatter(x, mask, updates):
"""
Updates the value in the input with the updates value according to the mask.
The shapes of `mask` and `x` must be the same or broadcastable.
Args:
x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
mask (Tensor[bool]): A bool tensor with a shape broadcastable to x.
updates (Tensor): A tensor with the same data type as x. The
number of elements must be greater than or equal to the number of True's in `mask`.
Outputs:
y (Tensor), with the same type and shape as x.
Raises:
TypeError: If `x`, `mask` or `updates` is not a Tensor.
TypeError: If data type of `x` is not be supported.
TypeError: If dtype of `mask` is not bool.
TypeError: If the dim of `x` less than the dim of `mask`.
ValueError: If `mask` can not be broadcastable to `x`.
ValueError: If the number of elements in `updates` is less than the number required for the updates.
Supported Platforms:
``CPU``
Examples:
>>> x= Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
>>> mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
>>> updates = Tensor(np.array([5., 6., 7.]), mindspore.float32)
>>> output = ops.MaskedScatter()(input_X, mask, updates)
>>> print(output)
[5. 6. 3. 7.]
"""
masked_scatter_ = MaskedScatter()
return masked_scatter_(x, mask, updates)
def masked_select(x, mask):
"""
Returns a new 1-D Tensor which indexes the `x` tensor according to the boolean `mask`.
@ -5067,6 +5105,7 @@ __all__ = [
'gather_nd',
'one_hot',
'masked_fill',
'masked_scatter',
'masked_select',
'narrow',
'scatter_add',

View File

@ -35,7 +35,7 @@ from .array_ops import (ArgMaxWithValue, ArgMinWithValue, Argmax, Argmin, BatchT
Eye, Fill, Gather, GatherD, GatherNd, GatherV2, Identity, Im2Col, InvertPermutation,
LowerBound, Lstsq, MaskedFill, MaskedSelect, Meshgrid, Mvlgamma, Ones, OnesLike,
Pack, Padding, ParallelConcat, PopulationCount, Range, Rank, Reshape, ResizeNearestNeighbor,
ReverseSequence, ReverseV2, Rint, ScalarToTensor, ScatterAdd,
ReverseSequence, ReverseV2, Rint, ScalarToTensor, ScatterAdd, MaskedScatter,
ScatterDiv, ScatterMax, ScatterMin, ScatterMul, ScatterNd, ScatterNdAdd, ScatterNdDiv,
ScatterNdMax, ScatterNdMin, ScatterNdSub, ScatterNdUpdate, ScatterNonAliasingAdd, ScatterSub,
ScatterUpdate, SearchSorted, Select, Shape, Size, Slice, Sort, SpaceToBatch, SpaceToBatchND,
@ -148,6 +148,7 @@ __all__ = [
'BatchMatMul',
'Mul',
'MaskedFill',
'MaskedScatter',
'MaskedSelect',
'Meshgrid',
'MultiMarginLoss',

View File

@ -6018,6 +6018,32 @@ class MaskedFill(Primitive):
self.init_prim_io_names(inputs=['input', 'mask', 'value'], outputs=['output'])
class MaskedScatter(Primitive):
"""
Updates the value in the input with the updates value according to the mask.
The shapes of `mask` and `x` must be the same or broadcastable.
Refer to :func:`mindspore.ops.masked_scatter' for more details.
Supported Platforms:
``CPU``
Examples:
>>> x= Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
>>> mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
>>> updates = Tensor(np.array([5., 6., 7.]), mindspore.float32)
>>> output = ops.MaskedScatter()(input_X, mask, updates)
>>> print(output)
[5. 6. 3. 7.]
"""
@prim_attr_register
def __init__(self):
"""Initialize MaskedScatter"""
self.init_prim_io_names(inputs=['x', 'mask', 'updates'], outputs=['y'])
self.add_prim_attr("cust_aicpu", "MaskedScatter")
class MaskedSelect(PrimitiveWithCheck):
"""
Returns a new 1-D Tensor which indexes the `x` tensor according to the boolean `mask`.

View File

@ -89,6 +89,7 @@ from mindspore.ops.operations.other_ops import SampleDistortedBoundingBoxV2
from mindspore.ops.operations.array_ops import Triu
from mindspore.ops.operations.array_ops import ResizeNearestNeighborV2
from mindspore.ops.operations._grad_ops import ResizeNearestNeighborV2Grad
from mindspore.ops.operations.array_ops import MaskedScatter
from mindspore.ops.operations.array_ops import MatrixDiagV3
from mindspore.ops.operations.array_ops import MatrixDiagPartV3
from mindspore.ops.operations.array_ops import MatrixSetDiagV3
@ -3943,6 +3944,13 @@ test_case_array_ops = [
Tensor(4.0, mstype.float32)],
'desc_bprop': [Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mstype.float32)],
}),
('MaskedScatter', {
'block': MaskedScatter(),
'desc_inputs': [Tensor(np.array([[1.0, 2.0, 3.0]]), mstype.float32),
Tensor(np.array([[True, True, False]]), mstype.bool_),
Tensor(np.array([[4.0, 5.0]]), mstype.float32)],
'desc_bprop': [Tensor(np.array([[4.0, 5.0, 3.0]]), mstype.float32)],
}),
('MaskedFill', {
'block': P.MaskedFill(),
'desc_inputs': [Tensor(np.array([[1.0, 2.0, 3.0]]), mstype.float32),