Add MaskedSelect CPU Operation

wuxuejian 2021-05-31 15:30:24 +08:00
parent 2c0dd077f5
commit 5ad5e427d0
20 changed files with 532 additions and 3 deletions

View File

@@ -58,7 +58,9 @@ constexpr auto kPadAndShift = "PadAndShift";
constexpr auto kCustRunApi = "RunCpuKernel";
constexpr auto kDropout2D = "Dropout2D";
constexpr auto kDropout3D = "Dropout3D";
const std::set<std::string> kCustAiCpuKernelOps{kIdentity};
constexpr auto kMaskedSelect = "MaskedSelect";
constexpr auto kMaskedSelectGrad = "MaskedSelectGrad";
const std::set<std::string> kCustAiCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter,
kPadAndShift, kDropout3D, kDropout2D};

View File

@@ -236,5 +236,36 @@ void TransposeIterator::GenNextPos() {
}
}
}
std::vector<size_t> CPUKernelUtils::GetBroadcastShape(const std::vector<size_t> &x, const std::vector<size_t> &y) {
size_t x_len = x.size();
size_t y_len = y.size();
size_t length = x_len < y_len ? x_len : y_len;
std::vector<size_t> broadcast_shape;
std::vector<size_t> broadcast_shape_back;
// Align both shapes from the trailing dimension, NumPy-style.
for (int i = -static_cast<int>(length); i < 0; ++i) {
if (x[x_len + i] == 1) {
broadcast_shape_back.push_back(y[y_len + i]);
} else if (y[y_len + i] == 1) {
broadcast_shape_back.push_back(x[x_len + i]);
} else if (x[x_len + i] == y[y_len + i]) {
broadcast_shape_back.push_back(x[x_len + i]);
} else {
MS_LOG(EXCEPTION) << "Shapes cannot be broadcast: dimension " << x[x_len + i] << " vs " << y[y_len + i];
}
}
if (length == x_len) {
for (size_t i = 0; i < y_len - length; ++i) {
broadcast_shape.push_back(y[i]);
}
} else {
for (size_t i = 0; i < x_len - length; ++i) {
broadcast_shape.push_back(x[i]);
}
}
for (size_t i = 0; i < length; ++i) {
broadcast_shape.push_back(broadcast_shape_back[i]);
}
return broadcast_shape;
}
} // namespace kernel
} // namespace mindspore
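
For intuition, GetBroadcastShape aligns the two shapes from the trailing dimension, NumPy-style. A minimal Python sketch of the same rule (the helper name is illustrative, not part of this commit):

def get_broadcast_shape(x, y):
    # Walk both shapes from the last dimension backwards.
    tail = []
    for a, b in zip(reversed(x), reversed(y)):
        if a == 1:
            tail.append(b)
        elif b == 1 or a == b:
            tail.append(a)
        else:
            raise ValueError(f"cannot broadcast {x} and {y}")
    head = x if len(x) > len(y) else y
    return list(head[:abs(len(x) - len(y))]) + list(reversed(tail))

assert get_broadcast_shape((4,), (2, 4, 1)) == [2, 4, 4]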

View File

@@ -153,6 +153,7 @@ class CPUKernelUtils {
static void GetElementNumEveryDim(const std::vector<size_t> &shape, std::vector<size_t> *element_num);
static void ParallelFor(const CTask &task, size_t count);
static std::vector<size_t> FlatShapeByAxis(const std::vector<size_t> &shape, int axis);
static std::vector<size_t> GetBroadcastShape(const std::vector<size_t> &x, const std::vector<size_t> &y);
};
class BroadcastIterator {

View File

@@ -0,0 +1,83 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/masked_select_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
template <typename T>
void MaskedSelectCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but MaskedSelectCPUKernel needs 2 input.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but MaskedSelectCPUKernel needs 1 output.";
}
input_shape_a_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
input_shape_b_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
output_shape_ = CPUKernelUtils::GetBroadcastShape(input_shape_a_, input_shape_b_);
for (size_t d : output_shape_) {
tensor_size_ *= d;
}
node_wpt_ = kernel_node;
}
template <typename T>
bool MaskedSelectCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
auto x = reinterpret_cast<T *>(inputs[0]->addr);
auto mask = reinterpret_cast<bool *>(inputs[1]->addr);
auto y = reinterpret_cast<T *>(outputs[0]->addr);
uint64_t j = 0;
if (input_shape_a_ == input_shape_b_) {
for (uint64_t i = 0; i < tensor_size_; ++i) {
if (mask[i]) {
y[j++] = x[i];
}
}
} else { // Broadcast
BroadcastIterator iter(input_shape_a_, input_shape_b_, output_shape_);
iter.SetPos(0);
for (uint64_t i = 0; i < tensor_size_; ++i) {
if (mask[iter.GetInputPosB()]) {
y[j++] = x[iter.GetInputPosA()];
}
iter.GenNextPos();
}
}
// The real output length is only known after selection, so refresh the
// node's inferred output shape to the actual number of selected elements.
auto node_ = node_wpt_.lock();
if (!node_) {
MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
}
std::vector<size_t> out_shape;
out_shape.emplace_back(j);
std::vector<TypeId> dtypes;
size_t output_num = AnfAlgo::GetOutputTensorNum(node_);
for (size_t i = 0; i < output_num; i++) {
dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
}
AnfAlgo::SetOutputInferTypeAndShape(dtypes, {out_shape}, node_.get());
return true;
}
} // namespace kernel
} // namespace mindspore
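
In NumPy terms, the same-shape fast path above is plain boolean indexing; the broadcast branch behaves as if both operands were first expanded to the common shape. A sketch of the fast path (illustrative, not part of the commit):

import numpy as np
x = np.array([1, 2, 3, 4], dtype=np.int32)
mask = np.array([True, False, True, False])
print(x[mask])  # [1 3]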

View File

@@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECT_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class MaskedSelectCPUKernel : public CPUKernel {
public:
MaskedSelectCPUKernel() = default;
~MaskedSelectCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
std::vector<size_t> input_shape_a_;
std::vector<size_t> input_shape_b_;
std::vector<size_t> output_shape_;
uint64_t tensor_size_ = 1;
CNodeWeakPtr node_wpt_;
};
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32),
MaskedSelectCPUKernel, float);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
MaskedSelectCPUKernel, int);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECT_CPU_KERNEL_H_

View File

@@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/masked_select_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
template <typename T>
void MaskedSelectGradCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but MaskedSelectGradCPUKernel needs 3 input.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but MaskedSelectGradCPUKernel needs 1 output.";
}
input_shape_a_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
input_shape_b_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
grad_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
output_shape_ = CPUKernelUtils::GetBroadcastShape(input_shape_a_, input_shape_b_);
for (size_t d : output_shape_) {
tensor_size_ *= d;
}
}
template <typename T>
bool MaskedSelectGradCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
auto mask = reinterpret_cast<bool *>(inputs[1]->addr);
auto grad = reinterpret_cast<T *>(inputs[2]->addr);
auto dx = reinterpret_cast<T *>(outputs[0]->addr);
uint64_t output_size = outputs[0]->size / sizeof(T);
// Unselected positions receive zero gradient, so clear the output before accumulating.
for (uint64_t i = 0; i < output_size; ++i) {
dx[i] = static_cast<T>(0);
}
uint64_t j = 0;
if (input_shape_a_ == input_shape_b_) {
for (uint64_t i = 0; i < output_size; ++i) {
if (mask[i]) {
dx[i] += grad[j++];
}
}
} else {
BroadcastIterator iter(input_shape_a_, input_shape_b_, output_shape_);
iter.SetPos(0);
for (uint64_t i = 0; i < tensor_size_; ++i) {
if (mask[iter.GetInputPosB()]) {
dx[iter.GetInputPosA()] += grad[j++];
}
iter.GenNextPos();
}
}
return true;
}
} // namespace kernel
} // namespace mindspore
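
The backward pass is the matching scatter-add: each grad element flows back to the x position the forward pass read, and positions that were broadcast accumulate. A NumPy sketch for x of shape (4,) against a (4, 1) mask (the axis-0 sum is specific to this example, where x was broadcast along axis 0; not part of the commit):

import numpy as np
mask = np.array([[False], [True], [False], [True]])  # shape (4, 1)
grad = np.arange(8, dtype=np.int32)  # dout from the forward op
full = np.zeros((4, 4), dtype=np.int32)
full[np.broadcast_to(mask, (4, 4))] = grad
print(full.sum(axis=0))  # [4 6 8 10], i.e. dx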

View File

@@ -0,0 +1,63 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECT_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECT_GRAD_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class MaskedSelectGradCPUKernel : public CPUKernel {
public:
MaskedSelectGradCPUKernel() = default;
~MaskedSelectGradCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
std::vector<size_t> input_shape_a_;
std::vector<size_t> input_shape_b_;
std::vector<size_t> grad_shape_;
std::vector<size_t> output_shape_;
uint64_t tensor_size_ = 1;
};
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
MaskedSelectGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
MaskedSelectGradCPUKernel, int);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECT_GRAD_CPU_KERNEL_H_

View File

@@ -32,6 +32,7 @@ namespace mindspore {
// op name. Op which not exists in operator/ops.h, so define it's name here
constexpr auto kConcatOpName = "Concat";
constexpr auto kUniqueOpName = "Unique";
constexpr auto kMaskedSelectOpName = "MaskedSelect";
constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder";
constexpr auto kFour2FiveOpName = "Four2Five";
@@ -555,7 +556,8 @@ const std::set<std::string> kHWSpecialFormatSet = {
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName,
kPadAndShiftOpName, kCTCGreedyDecoderOpName, kDropoutGenMaskOpName};
kPadAndShiftOpName, kCTCGreedyDecoderOpName, kDropoutGenMaskOpName,
kMaskedSelectOpName};
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};

View File

@@ -280,6 +280,8 @@ AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &pri
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list or dict.

View File

@@ -15,7 +15,9 @@
*/
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
@@ -1150,5 +1152,25 @@ AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &pri
return std::make_shared<AbstractTuple>(result);
}
AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
AbstractTensorPtr mask = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto x_shape = x->shape();
auto mask_shape = mask->shape();
auto broadcast_shape = BroadcastShape(x_shape->shape(), mask_shape->shape());
ShapeVector y_shape = {Shape::SHP_ANY};
ShapeVector min_shape = {1};
int64_t max_size = std::accumulate(broadcast_shape.begin(), broadcast_shape.end(), static_cast<int64_t>(1),
std::multiplies<int64_t>());
ShapeVector max_shape = {max_size};
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(y_shape, min_shape, max_shape));
}
} // namespace abstract
} // namespace mindspore
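
Because the selected length is data-dependent, the inferred output is a 1-D tensor of unknown length: y_shape is {Shape::SHP_ANY}, min_shape is pinned to {1}, and max_shape is the element count of the broadcast shape (the case where every mask element is True). A quick Python check of the bound (illustrative):

from functools import reduce
from operator import mul
broadcast_shape = [2, 4, 4]  # e.g. x of shape (4,) against a (2, 4, 1) mask
max_size = reduce(mul, broadcast_shape, 1)
assert max_size == 32  # max_shape = {32}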

View File

@@ -96,6 +96,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimSplit, {InferImplSplit, nullptr, true}},
{prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}},
{prim::kPrimSort, {InferImplSort, nullptr, true}},
{prim::kPrimMaskedSelect, {InferImplMaskedSelect, nullptr, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}},
{prim::kPrimMakeList, {InferImplMakeList, nullptr, true}},

View File

@@ -212,6 +212,7 @@ inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank");
inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear");
inline const PrimitivePtr kPrimResizeGrad = std::make_shared<Primitive>("ResizeGrad");
inline const PrimitivePtr kPrimSort = std::make_shared<Primitive>("Sort");
inline const PrimitivePtr kPrimMaskedSelect = std::make_shared<Primitive>("MaskedSelect");
// NN
inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam");

View File

@@ -1055,3 +1055,14 @@ def get_bprop_unique(self):
return (dx,)
return bprop
@bprop_getters.register(P.MaskedSelect)
def get_bprop_masked_select(self):
"""Generate bprop for MaskedSelect"""
op = G.MaskedSelectGrad()
def bprop(x, mask, dout):
dx = op(x, mask, dout)
return (dx,)
return bprop

View File

@@ -50,6 +50,8 @@ from .rnnt_loss import _rnnt_loss_aicpu
from .random_categorical import _random_categorical_aicpu
from .cast import _cast_aicpu
from .mirror_pad import _mirror_pad_aicpu
from .masked_select import _masked_select_aicpu
from .masked_select_grad import _masked_select_grad_aicpu
from .mirror_pad_grad import _mirror_pad_grad_aicpu
from .standard_normal import _standard_normal_aicpu
from .gamma import _gamma_aicpu

View File

@@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskedSelect op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
masked_select_op_info = AiCPURegOp("MaskedSelect") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.input(1, "mask", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(masked_select_op_info)
def _masked_select_aicpu():
"""MaskedSelect AiCPU register"""
return

View File

@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskedSelectGrad op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
masked_select_grad_op_info = AiCPURegOp("MaskedSelectGrad") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.input(1, "mask", "required") \
.input(2, "grad", "required") \
.output(0, "dx", "required") \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(masked_select_grad_op_info)
def _masked_select_grad_aicpu():
"""MaskedSelectGrad AiCPU register"""
return

View File

@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch,
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
EmbeddingLookup, Unique, GatherD, Identity, Range)
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedSelect)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd,

View File

@@ -2194,3 +2194,17 @@ class LRNGrad(PrimitiveWithInfer):
def infer_shape(self, grads, x, y):
return x
class MaskedSelectGrad(PrimitiveWithInfer):
"""Computes gradient for MaskedSelect."""
@prim_attr_register
def __init__(self):
pass
def infer_shape(self, x, mask, grad):
return x
def infer_dtype(self, x, mask, grad):
return x

View File

@@ -27,6 +27,7 @@ import numpy as np
from mindspore import log as logger
from mindspore.common.initializer import Zero
from .._utils import get_broadcast_shape
from .._utils import get_concat_offset
from ..operations.math_ops import _infer_shape_reduce
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
@@ -5299,3 +5300,38 @@ class Range(PrimitiveWithCheck):
delat = np.asscalar(delat_value.asnumpy())
return Tensor(np.arange(start, limit, delat), dtype=start_value.dtype)
return None
class MaskedSelect(PrimitiveWithCheck):
"""
Returns a new 1-D Tensor which indexes the input tensor according to the boolean mask.
Inputs:
- **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
- **mask** (Tensor) - The boolean mask. Its shape must be broadcastable with the shape of `x`.
Outputs:
Tensor, a 1-D Tensor with the same dtype as `x`. Its length equals the number of `True` elements in the broadcast `mask`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `mask` is not bool.
Supported Platforms:
``CPU``
Examples:
>>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int32)
>>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
>>> output = ops.MaskedSelect()(x, mask)
>>> print(output)
[1 3]
"""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['x', 'mask'], outputs=['output'])
def check_shape(self, x_shape, mask_shape):
get_broadcast_shape(x_shape, mask_shape, self.name)
def check_dtype(self, x_dtype, mask_dtype):
validator.check_tensor_dtype_valid('mask', mask_dtype, [mstype.bool_], self.name)

View File

@@ -0,0 +1,68 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.operations import _grad_ops as G
def maskedselect():
x = np.array([1, 2, 3, 4]).astype(np.int32)
mask = np.array([[[0], [1], [0], [1]], [[0], [1], [0], [1]]]).astype(bool)
net = P.MaskedSelect()
return net(Tensor(x), Tensor(mask))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maskedselect():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
y = maskedselect()
expect = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]
assert (y.asnumpy() == expect).all()
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network
def construct(self, x, mask, grad):
gout = self.grad(self.network)(x, mask, grad)
return gout
def masked_select_grad():
x = np.array([1, 2, 3, 4]).astype(np.int32)
mask = np.array([[0], [1], [0], [1]]).astype(bool)
dy = np.arange(8).astype(np.int32)
grad = G.MaskedSelectGrad()
return grad(Tensor(x), Tensor(mask), Tensor(dy))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_masked_select_grad():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
dx = masked_select_grad()
expect = [4, 6, 8, 10]
assert (dx.asnumpy() == expect).all()
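
The forward expectation can be re-derived with plain NumPy (a sanity sketch independent of the kernel; np.broadcast_shapes requires NumPy >= 1.20):

import numpy as np
x = np.array([1, 2, 3, 4], dtype=np.int32)
mask = np.array([[[0], [1], [0], [1]], [[0], [1], [0], [1]]]).astype(bool)
shape = np.broadcast_shapes(x.shape, mask.shape)  # (2, 4, 4)
y = np.broadcast_to(x, shape)[np.broadcast_to(mask, shape)]
assert y.tolist() == [1, 2, 3, 4] * 4  # 16 elements, matching `expect`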