forked from mindspore-Ecosystem/mindspore
!36606 [assistant][ops][aicpu][五期][I4XJI8] Add MaskedScatter operator
Merge pull request !36606 from 桂宁馨/MaskedScatter
This commit is contained in:
commit
2d62478890
|
@ -0,0 +1,176 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/masked_scatter_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <functional>
|
||||
#include "mindspore/core/ops/masked_scatter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kMaskedScatterInputsNum = 3;
|
||||
constexpr size_t kMaskedScatterOutputsNum = 1;
|
||||
} // namespace
|
||||
|
||||
bool MaskedScatterCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "MaskedScatter does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
int MaskedScatterCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
int ret = KRET_OK;
|
||||
if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs)) != 0) {
|
||||
return ret;
|
||||
}
|
||||
std::vector<int64_t> x_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
std::vector<int64_t> mask_shape = inputs.at(kIndex1)->GetShapeVector();
|
||||
std::vector<int64_t> updates_shape = inputs.at(kIndex2)->GetShapeVector();
|
||||
std::vector<int64_t> output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
std::transform(x_shape.begin(), x_shape.end(), std::back_inserter(x_shape_), LongToSize);
|
||||
std::transform(mask_shape.begin(), mask_shape.end(), std::back_inserter(mask_shape_), LongToSize);
|
||||
std::transform(updates_shape.begin(), updates_shape.end(), std::back_inserter(updates_shape_), LongToSize);
|
||||
std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(output_shape_), LongToSize);
|
||||
x_numElements_ = std::accumulate(x_shape_.begin(), x_shape_.end(), 1, std::multiplies<size_t>());
|
||||
updates_numElements_ = std::accumulate(updates_shape_.begin(), updates_shape_.end(), 1, std::multiplies<size_t>());
|
||||
need_broadcast_ = (x_shape_ == mask_shape_) ? false : true;
|
||||
size_t mask_dims = mask_shape.size();
|
||||
std::vector<int64_t> x_shape_reverse = x_shape_;
|
||||
std::vector<int64_t> mask_shape_reverse = mask_shape_;
|
||||
std::reverse(x_shape_reverse.begin(), x_shape_reverse.end());
|
||||
std::reverse(mask_shape_reverse.begin(), mask_shape_reverse.end());
|
||||
for (size_t i = 0; i < mask_dims; i++) {
|
||||
if (mask_shape_reverse[i] != x_shape_reverse[i] && mask_shape_reverse[i] != 1) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the shape of 'mask': " << ShapeVectorToStr(mask_shape)
|
||||
<< " can not be broadcast to the shape of 'x': " << ShapeVectorToStr(x_shape) << ".";
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool MaskedScatterCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMaskedScatterInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMaskedScatterOutputsNum, kernel_name_);
|
||||
|
||||
auto x = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto mask = reinterpret_cast<bool *>(inputs[1]->addr);
|
||||
auto updates = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
auto y = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
uint64_t j = 0;
|
||||
if (!need_broadcast_) {
|
||||
for (uint64_t i = 0; i < x_numElements_; i++) {
|
||||
if (mask[i]) {
|
||||
if (j >= updates_numElements_) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
|
||||
<< "', number of elements of updates < number of ones in mask.";
|
||||
}
|
||||
y[i] = updates[j], j += 1;
|
||||
} else {
|
||||
y[i] = x[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
BroadcastIterator iter(x_shape_, mask_shape_, output_shape_);
|
||||
iter.SetPos(0);
|
||||
for (uint64_t i = 0; i < x_numElements_; i++, iter.GenNextPos()) {
|
||||
if (mask[iter.GetInputPosB()]) {
|
||||
if (j >= updates_numElements_) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
|
||||
<< "', number of elements of updates < number of ones in mask.";
|
||||
}
|
||||
y[iter.GetInputPosA()] = updates[j], j += 1;
|
||||
} else {
|
||||
y[i] = x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, MaskedScatterCpuKernelMod::MaskedScatterFunc>> MaskedScatterCpuKernelMod::func_list_ =
|
||||
{{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&MaskedScatterCpuKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&MaskedScatterCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&MaskedScatterCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
&MaskedScatterCpuKernelMod::LaunchKernel<uint8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
&MaskedScatterCpuKernelMod::LaunchKernel<int8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
&MaskedScatterCpuKernelMod::LaunchKernel<int16_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&MaskedScatterCpuKernelMod::LaunchKernel<int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&MaskedScatterCpuKernelMod::LaunchKernel<int64_t>}};
|
||||
|
||||
std::vector<KernelAttr> MaskedScatterCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, MaskedScatterFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaskedScatter, MaskedScatterCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SCATTER_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SCATTER_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class MaskedScatterCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
MaskedScatterCpuKernelMod() = default;
|
||||
~MaskedScatterCpuKernelMod() override = default;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
using MaskedScatterFunc = std::function<bool(MaskedScatterCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, MaskedScatterFunc>> func_list_;
|
||||
MaskedScatterFunc kernel_func_;
|
||||
std::vector<int64_t> x_shape_;
|
||||
std::vector<int64_t> mask_shape_;
|
||||
std::vector<int64_t> updates_shape_;
|
||||
std::vector<int64_t> output_shape_;
|
||||
uint64_t x_numElements_ = 1;
|
||||
uint64_t updates_numElements_ = 1;
|
||||
bool need_broadcast_{false};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SCATTER_CPU_KERNEL_H_
|
|
@ -666,6 +666,7 @@ GVAR_DEF(PrimitivePtr, kPrimResizeLinear1D, std::make_shared<Primitive>("ResizeL
|
|||
GVAR_DEF(PrimitivePtr, kPrimResizeLinear1DGrad, std::make_shared<Primitive>("ResizeLinear1DGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSort, std::make_shared<Primitive>("Sort"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMaskedFill, std::make_shared<Primitive>("MaskedFill"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMaskedScatter, std::make_shared<Primitive>("MaskedScatter"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMaskedSelect, std::make_shared<Primitive>("MaskedSelect"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMaskedSelectGrad, std::make_shared<Primitive>("MaskedSelectGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimDiag, std::make_shared<Primitive>(kDiag));
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include "ops/masked_scatter.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr MaskedScatterInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
const int64_t input_num = 3;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
|
||||
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
|
||||
auto x_shape = x_shape_map[kShape];
|
||||
auto mask_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
|
||||
auto mask_shape = mask_shape_map[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("dim of input_x", x_shape.size(), kGreaterEqual, mask_shape.size(), op_name);
|
||||
|
||||
return std::make_shared<abstract::Shape>(x_shape);
|
||||
}
|
||||
|
||||
TypePtr MaskedScatterInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("mask", input_args[1]->BuildType(), {kBool}, op_name);
|
||||
std::set<TypePtr> valid_types;
|
||||
valid_types = {kFloat16, kFloat32, kFloat64, kUInt8, kInt8, kInt16, kInt32, kInt64};
|
||||
auto x_type = input_args[kInputIndex0]->BuildType();
|
||||
auto updates_type = input_args[kInputIndex2]->BuildType();
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", x_type);
|
||||
(void)types.emplace("updates", updates_type);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates", updates_type, valid_types, op_name);
|
||||
return x_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(MaskedScatter, BaseOperator);
|
||||
AbstractBasePtr MaskedScatterInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kMaskedScaterInputsNum = 3;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kMaskedScaterInputsNum, primitive->name());
|
||||
auto infer_type = MaskedScatterInferType(primitive, input_args);
|
||||
auto infer_shape = MaskedScatterInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MaskedScatter, prim::kPrimMaskedScatter, MaskedScatterInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_MASKED_SCATTER_H_
|
||||
#define MINDSPORE_CORE_OPS_MASKED_SCATTER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameMaskedScatter = "MaskedScatter";
|
||||
/// \brief Returns a Tensor as same shape and type with the input tensor according to the boolean mask and updates.
|
||||
/// Refer to Python API @ref mindspore.ops.MaskedScatter for more details.
|
||||
class MIND_API MaskedScatter : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(MaskedScatter);
|
||||
/// \brief Constructor.
|
||||
MaskedScatter() : BaseOperator(kNameMaskedScatter) { InitIOName({"x", "mask", "updates"}, {"y"}); }
|
||||
};
|
||||
AbstractBasePtr MaskedScatterInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimMaskedScatterPtr = std::shared_ptr<MaskedScatter>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_MASKED_SCATTER_H_
|
|
@ -124,6 +124,44 @@ def get_bprop_masked_select(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.MaskedScatter)
|
||||
def get_bprop_masked_scatter(self):
|
||||
"""Generate bprop for MaskedScatter"""
|
||||
sort_ = P.Sort(descending=True)
|
||||
masked_scatter = P.MaskedScatter()
|
||||
masked_fill = P.MaskedFill()
|
||||
masked_select = P.MaskedSelect()
|
||||
size = P.Size()
|
||||
zeros = P.Zeros()
|
||||
concat = P.Concat(axis=0)
|
||||
reshape = P.Reshape()
|
||||
shape = P.Shape()
|
||||
|
||||
def bprop(x, mask, updates, out, dout):
|
||||
dx = masked_fill(F.cast(dout, mstype.float32), mask, 0.0)
|
||||
mask_selected = masked_select(F.cast(dout, mstype.float32), mask)
|
||||
mask_broad = mask
|
||||
if shape(mask) != shape(x):
|
||||
broad_cast = P.BroadcastTo(shape(x))
|
||||
mask_broad = broad_cast(mask)
|
||||
mask_broad_vec = mask_broad.reshape(-1)
|
||||
mask_sorted = F.cast(sort_(F.cast(mask_broad_vec, mstype.float32))[0], F.dtype(mask))
|
||||
diff_num = size(updates) - size(mask_broad)
|
||||
if diff_num > 0:
|
||||
zeros_pad = zeros(diff_num, F.dtype(mask))
|
||||
mask_sorted = concat((mask_sorted, zeros_pad))
|
||||
zeros_tensor = zeros(size(updates), mstype.float32)
|
||||
dupdates = masked_scatter(zeros_tensor, mask_sorted, mask_selected)
|
||||
if shape(updates) != ():
|
||||
dupdates = reshape(dupdates, shape(updates))
|
||||
else:
|
||||
zeros_tensor = zeros(shape(updates), mstype.float32)
|
||||
dupdates = masked_scatter(zeros_tensor, mask, mask_selected)
|
||||
return F.cast(dx, F.dtype(x)), zeros_like(mask), F.cast(dupdates, F.dtype(updates))
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(Mvlgamma)
|
||||
def get_bprop_mvlgamma(self):
|
||||
"""Grad definition for Mvlgamma"""
|
||||
|
|
|
@ -42,6 +42,7 @@ from .print_tensor import _print_aicpu
|
|||
from .topk import _top_k_aicpu
|
||||
from .log1p import _log1p_aicpu
|
||||
from .asin import _asin_aicpu
|
||||
from .masked_scatter import _masked_scatter_aicpu
|
||||
from .is_finite import _is_finite_aicpu
|
||||
from .is_inf import _is_inf_aicpu
|
||||
from .is_nan import _is_nan_aicpu
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""MaskedScatter op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
masked_scatter_op_info = AiCPURegOp("MaskedScatter") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "mask", "required") \
|
||||
.input(2, "updates", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.U8_Default, DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.BOOL_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.BOOL_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.BOOL_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(masked_scatter_op_info)
|
||||
def _masked_scatter_aicpu():
|
||||
"""MaskedScatter AiCPU register"""
|
||||
return
|
|
@ -102,6 +102,7 @@ from .array_func import (
|
|||
matrix_diag_part,
|
||||
matrix_set_diag,
|
||||
diag,
|
||||
masked_scatter,
|
||||
masked_select,
|
||||
meshgrid,
|
||||
affine_grid,
|
||||
|
|
|
@ -25,6 +25,7 @@ from mindspore.ops.operations.array_ops import (
|
|||
UniqueConsecutive,
|
||||
SearchSorted,
|
||||
NonZero,
|
||||
MaskedScatter,
|
||||
MatrixDiagV3,
|
||||
MatrixDiagPartV3,
|
||||
MatrixSetDiagV3,
|
||||
|
@ -4132,6 +4133,43 @@ def tuple_to_array(input_x):
|
|||
return tuple_to_array_(input_x)
|
||||
|
||||
|
||||
def masked_scatter(x, mask, updates):
|
||||
"""
|
||||
Updates the value in the input with the updates value according to the mask.
|
||||
The shapes of `mask` and `x` must be the same or broadcastable.
|
||||
|
||||
Args:
|
||||
x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
mask (Tensor[bool]): A bool tensor with a shape broadcastable to x.
|
||||
updates (Tensor): A tensor with the same data type as x. The
|
||||
number of elements must be greater than or equal to the number of True's in `mask`.
|
||||
|
||||
Outputs:
|
||||
y (Tensor), with the same type and shape as x.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x`, `mask` or `updates` is not a Tensor.
|
||||
TypeError: If data type of `x` is not be supported.
|
||||
TypeError: If dtype of `mask` is not bool.
|
||||
TypeError: If the dim of `x` less than the dim of `mask`.
|
||||
ValueError: If `mask` can not be broadcastable to `x`.
|
||||
ValueError: If the number of elements in `updates` is less than the number required for the updates.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x= Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
|
||||
>>> mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
|
||||
>>> updates = Tensor(np.array([5., 6., 7.]), mindspore.float32)
|
||||
>>> output = ops.MaskedScatter()(input_X, mask, updates)
|
||||
>>> print(output)
|
||||
[5. 6. 3. 7.]
|
||||
"""
|
||||
masked_scatter_ = MaskedScatter()
|
||||
return masked_scatter_(x, mask, updates)
|
||||
|
||||
|
||||
def masked_select(x, mask):
|
||||
"""
|
||||
Returns a new 1-D Tensor which indexes the `x` tensor according to the boolean `mask`.
|
||||
|
@ -5067,6 +5105,7 @@ __all__ = [
|
|||
'gather_nd',
|
||||
'one_hot',
|
||||
'masked_fill',
|
||||
'masked_scatter',
|
||||
'masked_select',
|
||||
'narrow',
|
||||
'scatter_add',
|
||||
|
|
|
@ -35,7 +35,7 @@ from .array_ops import (ArgMaxWithValue, ArgMinWithValue, Argmax, Argmin, BatchT
|
|||
Eye, Fill, Gather, GatherD, GatherNd, GatherV2, Identity, Im2Col, InvertPermutation,
|
||||
LowerBound, Lstsq, MaskedFill, MaskedSelect, Meshgrid, Mvlgamma, Ones, OnesLike,
|
||||
Pack, Padding, ParallelConcat, PopulationCount, Range, Rank, Reshape, ResizeNearestNeighbor,
|
||||
ReverseSequence, ReverseV2, Rint, ScalarToTensor, ScatterAdd,
|
||||
ReverseSequence, ReverseV2, Rint, ScalarToTensor, ScatterAdd, MaskedScatter,
|
||||
ScatterDiv, ScatterMax, ScatterMin, ScatterMul, ScatterNd, ScatterNdAdd, ScatterNdDiv,
|
||||
ScatterNdMax, ScatterNdMin, ScatterNdSub, ScatterNdUpdate, ScatterNonAliasingAdd, ScatterSub,
|
||||
ScatterUpdate, SearchSorted, Select, Shape, Size, Slice, Sort, SpaceToBatch, SpaceToBatchND,
|
||||
|
@ -148,6 +148,7 @@ __all__ = [
|
|||
'BatchMatMul',
|
||||
'Mul',
|
||||
'MaskedFill',
|
||||
'MaskedScatter',
|
||||
'MaskedSelect',
|
||||
'Meshgrid',
|
||||
'MultiMarginLoss',
|
||||
|
|
|
@ -6018,6 +6018,32 @@ class MaskedFill(Primitive):
|
|||
self.init_prim_io_names(inputs=['input', 'mask', 'value'], outputs=['output'])
|
||||
|
||||
|
||||
class MaskedScatter(Primitive):
|
||||
"""
|
||||
Updates the value in the input with the updates value according to the mask.
|
||||
The shapes of `mask` and `x` must be the same or broadcastable.
|
||||
|
||||
Refer to :func:`mindspore.ops.masked_scatter' for more details.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x= Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
|
||||
>>> mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
|
||||
>>> updates = Tensor(np.array([5., 6., 7.]), mindspore.float32)
|
||||
>>> output = ops.MaskedScatter()(input_X, mask, updates)
|
||||
>>> print(output)
|
||||
[5. 6. 3. 7.]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize MaskedScatter"""
|
||||
self.init_prim_io_names(inputs=['x', 'mask', 'updates'], outputs=['y'])
|
||||
self.add_prim_attr("cust_aicpu", "MaskedScatter")
|
||||
|
||||
|
||||
class MaskedSelect(PrimitiveWithCheck):
|
||||
"""
|
||||
Returns a new 1-D Tensor which indexes the `x` tensor according to the boolean `mask`.
|
||||
|
|
|
@ -89,6 +89,7 @@ from mindspore.ops.operations.other_ops import SampleDistortedBoundingBoxV2
|
|||
from mindspore.ops.operations.array_ops import Triu
|
||||
from mindspore.ops.operations.array_ops import ResizeNearestNeighborV2
|
||||
from mindspore.ops.operations._grad_ops import ResizeNearestNeighborV2Grad
|
||||
from mindspore.ops.operations.array_ops import MaskedScatter
|
||||
from mindspore.ops.operations.array_ops import MatrixDiagV3
|
||||
from mindspore.ops.operations.array_ops import MatrixDiagPartV3
|
||||
from mindspore.ops.operations.array_ops import MatrixSetDiagV3
|
||||
|
@ -3943,6 +3944,13 @@ test_case_array_ops = [
|
|||
Tensor(4.0, mstype.float32)],
|
||||
'desc_bprop': [Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mstype.float32)],
|
||||
}),
|
||||
('MaskedScatter', {
|
||||
'block': MaskedScatter(),
|
||||
'desc_inputs': [Tensor(np.array([[1.0, 2.0, 3.0]]), mstype.float32),
|
||||
Tensor(np.array([[True, True, False]]), mstype.bool_),
|
||||
Tensor(np.array([[4.0, 5.0]]), mstype.float32)],
|
||||
'desc_bprop': [Tensor(np.array([[4.0, 5.0, 3.0]]), mstype.float32)],
|
||||
}),
|
||||
('MaskedFill', {
|
||||
'block': P.MaskedFill(),
|
||||
'desc_inputs': [Tensor(np.array([[1.0, 2.0, 3.0]]), mstype.float32),
|
||||
|
|
Loading…
Reference in New Issue