forked from mindspore-Ecosystem/mindspore
Add MaskedSelect CPU Operation
parent 2c0dd077f5
commit 5ad5e427d0
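Note: MaskedSelect behaves like NumPy boolean-mask indexing after broadcasting; a minimal reference sketch of the intended semantics (plain NumPy, not part of this commit):

import numpy as np

def masked_select_reference(x, mask):
    # Broadcast x and mask to a common shape, then keep the elements where mask is True.
    x_b, mask_b = np.broadcast_arrays(x, mask)
    return x_b[mask_b]  # always a 1-D array

print(masked_select_reference(np.array([1, 2, 3, 4]), np.array([True, False, True, False])))  # [1 3]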
@@ -58,7 +58,9 @@ constexpr auto kPadAndShift = "PadAndShift";
constexpr auto kCustRunApi = "RunCpuKernel";
constexpr auto kDropout2D = "Dropout2D";
constexpr auto kDropout3D = "Dropout3D";
const std::set<std::string> kCustAiCpuKernelOps{kIdentity};
constexpr auto kMaskedSelect = "MaskedSelect";
constexpr auto kMaskedSelectGrad = "MaskedSelectGrad";
const std::set<std::string> kCustAiCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter,
                                            kPadAndShift, kDropout3D, kDropout2D};
@@ -236,5 +236,36 @@ void TransposeIterator::GenNextPos() {
    }
  }
}

std::vector<size_t> CPUKernelUtils::GetBroadcastShape(const std::vector<size_t> &x, const std::vector<size_t> &y) {
  size_t x_len = x.size();
  size_t y_len = y.size();
  size_t length = x_len < y_len ? x_len : y_len;
  std::vector<size_t> broadcast_shape;
  std::vector<size_t> broadcast_shape_back;
  for (int i = -length; i < 0; ++i) {
    if (x[x_len + i] == 1) {
      broadcast_shape_back.push_back(y[y_len + i]);
    } else if (y[y_len + i] == 1) {
      broadcast_shape_back.push_back(x[x_len + i]);
    } else if (x[x_len + i] == y[y_len + i]) {
      broadcast_shape_back.push_back(x[x_len + i]);
    }
  }
  if (length == x_len) {
    for (size_t i = 0; i < y_len - length; ++i) {
      broadcast_shape.push_back(y[i]);
    }
  } else {
    for (size_t i = 0; i < x_len - length; ++i) {
      broadcast_shape.push_back(x[i]);
    }
  }
  for (size_t i = 0; i < length; ++i) {
    broadcast_shape.push_back(broadcast_shape_back[i]);
  }
  return broadcast_shape;
}

}  // namespace kernel
}  // namespace mindspore
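Note: GetBroadcastShape above follows NumPy-style broadcasting (trailing dimensions must be equal or 1); a quick cross-check against NumPy (illustrative only, not part of this commit):

import numpy as np

# Same result as GetBroadcastShape({4}, {2, 4, 1}) in the C++ above.
print(np.broadcast(np.empty((4,)), np.empty((2, 4, 1))).shape)  # (2, 4, 4)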
@@ -153,6 +153,7 @@ class CPUKernelUtils {
  static void GetElementNumEveryDim(const std::vector<size_t> &shape, std::vector<size_t> *element_num);
  static void ParallelFor(const CTask &task, size_t count);
  static std::vector<size_t> FlatShapeByAxis(const std::vector<size_t> &shape, int axis);
  static std::vector<size_t> GetBroadcastShape(const std::vector<size_t> &x, const std::vector<size_t> &y);
};

class BroadcastIterator {
@@ -0,0 +1,83 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/masked_select_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {

template <typename T>
void MaskedSelectCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 2) {
    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but MaskedSelectCPUKernel needs 2 inputs.";
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != 1) {
    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but MaskedSelectCPUKernel needs 1 output.";
  }
  input_shape_a_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  input_shape_b_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  output_shape_ = CPUKernelUtils::GetBroadcastShape(input_shape_a_, input_shape_b_);
  for (const uint64_t &d : output_shape_) {
    tensor_size_ *= d;
  }
  node_wpt_ = kernel_node;
}

template <typename T>
bool MaskedSelectCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> &,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  auto x = reinterpret_cast<T *>(inputs[0]->addr);
  auto mask = reinterpret_cast<bool *>(inputs[1]->addr);
  auto y = reinterpret_cast<T *>(outputs[0]->addr);
  uint64_t j = 0;
  if (input_shape_a_ == input_shape_b_) {
    for (uint64_t i = 0; i < tensor_size_; ++i) {
      if (mask[i]) {
        y[j++] = x[i];
      }
    }
  } else {  // Broadcast
    BroadcastIterator iter(input_shape_a_, input_shape_b_, output_shape_);
    iter.SetPos(0);
    for (uint64_t i = 0; i < tensor_size_; ++i) {
      if (mask[iter.GetInputPosB()]) {
        y[j++] = x[iter.GetInputPosA()];
      }
      iter.GenNextPos();
    }
  }
  if (!node_wpt_.expired()) {
    auto node_ = node_wpt_.lock();
    if (!node_) {
      MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
    }
    std::vector<size_t> out_shape;
    out_shape.emplace_back(j);
    std::vector<TypeId> dtypes;
    size_t output_num = AnfAlgo::GetOutputTensorNum(node_);
    for (size_t i = 0; i < output_num; i++) {
      dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
    }
    AnfAlgo::SetOutputInferTypeAndShape(dtypes, {out_shape}, node_.get());
  }
  return true;
}
}  // namespace kernel
}  // namespace mindspore
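Note: the broadcast branch of Launch, together with the runtime reshape of the output (out_shape = {j}), amounts to the following NumPy sketch, where the output length is only known once the mask has been evaluated (illustrative names, not part of this commit):

import numpy as np

def masked_select_launch(x, mask):
    # Iterate the broadcast shape and copy x wherever the broadcast mask is True.
    x_b, mask_b = np.broadcast_arrays(x, mask)
    y = x_b[mask_b]    # 1-D result
    return y, y.shape  # length j is determined at runtime, as in SetOutputInferTypeAndShape

x = np.array([1, 2, 3, 4], dtype=np.int32)
mask = np.array([[0], [1], [0], [1]], dtype=bool)
print(masked_select_launch(x, mask))  # (array([1, 2, 3, 4, 1, 2, 3, 4], dtype=int32), (8,))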
@@ -0,0 +1,57 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T>
class MaskedSelectCPUKernel : public CPUKernel {
 public:
  MaskedSelectCPUKernel() = default;
  ~MaskedSelectCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  std::vector<size_t> input_shape_a_;
  std::vector<size_t> input_shape_b_;
  std::vector<size_t> output_shape_;
  uint64_t tensor_size_ = 1;
  CNodeWeakPtr node_wpt_;
};

MS_REG_CPU_KERNEL_T(
  MaskedSelect,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32),
  MaskedSelectCPUKernel, float);

MS_REG_CPU_KERNEL_T(
  MaskedSelect,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
  MaskedSelectCPUKernel, int);
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_CPU_KERNEL_H_
@@ -0,0 +1,70 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/masked_select_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {

template <typename T>
void MaskedSelectGradCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 3) {
    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but MaskedSelectGradCPUKernel needs 3 inputs.";
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != 1) {
    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but MaskedSelectGradCPUKernel needs 1 output.";
  }
  input_shape_a_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  input_shape_b_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  grad_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
  output_shape_ = CPUKernelUtils::GetBroadcastShape(input_shape_a_, input_shape_b_);
  for (const uint64_t &d : output_shape_) {
    tensor_size_ *= d;
  }
}

template <typename T>
bool MaskedSelectGradCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                          const std::vector<kernel::AddressPtr> &,
                                          const std::vector<kernel::AddressPtr> &outputs) {
  auto mask = reinterpret_cast<bool *>(inputs[1]->addr);
  auto grad = reinterpret_cast<T *>(inputs[2]->addr);
  auto dx = reinterpret_cast<T *>(outputs[0]->addr);
  uint64_t output_size = outputs[0]->size / sizeof(T);
  uint64_t j = 0;
  if (input_shape_a_ == input_shape_b_) {
    for (uint64_t i = 0; i < output_size; ++i) {
      if (mask[i]) {
        dx[i] += grad[j++];
      }
    }
  } else {
    BroadcastIterator iter(input_shape_a_, input_shape_b_, output_shape_);
    iter.SetPos(0);
    for (uint64_t i = 0; i < tensor_size_; ++i) {
      if (mask[iter.GetInputPosB()]) {
        dx[iter.GetInputPosA()] += grad[j++];
      }
      iter.GenNextPos();
    }
  }
  return true;
}
}  // namespace kernel
}  // namespace mindspore
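Note: the backward kernel scatter-adds each element of the incoming 1-D gradient back to the x position selected by the broadcast mask; in NumPy terms (illustrative names, not part of this commit):

import numpy as np

def masked_select_grad_reference(x, mask, dout):
    # dout holds one value per selected element, in row-major order over the broadcast shape.
    x_b, mask_b = np.broadcast_arrays(x, mask)
    # Flat index of the x element read at each broadcast position.
    idx = np.broadcast_to(np.arange(x.size).reshape(x.shape), x_b.shape)[mask_b]
    dx = np.zeros(x.size, dtype=dout.dtype)
    np.add.at(dx, idx, dout)  # accumulate, matching dx[...] += grad[j++] above
    return dx.reshape(x.shape)

x = np.array([1, 2, 3, 4], dtype=np.int32)
mask = np.array([[0], [1], [0], [1]], dtype=bool)
dout = np.arange(8, dtype=np.int32)
print(masked_select_grad_reference(x, mask, dout))  # [ 4  6  8 10]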
@@ -0,0 +1,63 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_GRAD_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T>
class MaskedSelectGradCPUKernel : public CPUKernel {
 public:
  MaskedSelectGradCPUKernel() = default;
  ~MaskedSelectGradCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  std::vector<size_t> input_shape_a_;
  std::vector<size_t> input_shape_b_;
  std::vector<size_t> grad_shape_;
  std::vector<size_t> output_shape_;
  uint64_t tensor_size_ = 1;
};

MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
                    KernelAttr()
                      .AddInputAttr(kNumberTypeFloat32)
                      .AddInputAttr(kNumberTypeBool)
                      .AddInputAttr(kNumberTypeFloat32)
                      .AddOutputAttr(kNumberTypeFloat32),
                    MaskedSelectGradCPUKernel, float);

MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
                    KernelAttr()
                      .AddInputAttr(kNumberTypeInt32)
                      .AddInputAttr(kNumberTypeBool)
                      .AddInputAttr(kNumberTypeInt32)
                      .AddOutputAttr(kNumberTypeInt32),
                    MaskedSelectGradCPUKernel, int);
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_GRAD_CPU_KERNEL_H_
@@ -32,6 +32,7 @@ namespace mindspore {
// Op names. Ops that do not exist in operator/ops.h have their names defined here.
constexpr auto kConcatOpName = "Concat";
constexpr auto kUniqueOpName = "Unique";
constexpr auto kMaskedSelectOpName = "MaskedSelect";
constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder";
constexpr auto kFour2FiveOpName = "Four2Five";
@@ -555,7 +556,8 @@ const std::set<std::string> kHWSpecialFormatSet = {
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};

const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName,
                                              kPadAndShiftOpName, kCTCGreedyDecoderOpName, kDropoutGenMaskOpName};
                                              kPadAndShiftOpName, kCTCGreedyDecoderOpName, kDropoutGenMaskOpName,
                                              kMaskedSelectOpName};

const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
                                            kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
@@ -280,6 +280,8 @@ AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &pri
                              const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tuple or list or dict.
@@ -15,7 +15,9 @@
 */

#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
@@ -1150,5 +1152,25 @@ AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &pri
  return std::make_shared<AbstractTuple>(result);
}

AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  AbstractTensorPtr x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  AbstractTensorPtr mask = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);

  auto x_shape = x->shape();
  auto mask_shape = mask->shape();
  auto broadcast_shape = BroadcastShape(x_shape->shape(), mask_shape->shape());
  ShapeVector y_shape = {Shape::SHP_ANY};
  ShapeVector min_shape = {1};
  int64_t max_size = std::accumulate(broadcast_shape.begin(), broadcast_shape.end(), 1, std::multiplies<int64_t>());
  ShapeVector max_shape = {max_size};
  if (max_shape.empty()) {
    max_shape = x_shape->shape();
  }
  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(y_shape, min_shape, max_shape));
}

}  // namespace abstract
}  // namespace mindspore
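Note: InferImplMaskedSelect declares a dynamic 1-D output: the length is unknown at compile time (Shape::SHP_ANY), bounded below by 1 element and above by the element count of the broadcast shape. A small Python illustration of those bounds (this mirrors the C++ above; it is not MindSpore API):

from functools import reduce
from operator import mul

def masked_select_shape_bounds(x_shape, mask_shape):
    # Align shapes on the right and broadcast (dimensions must be equal or 1).
    ndim = max(len(x_shape), len(mask_shape))
    a = (1,) * (ndim - len(x_shape)) + tuple(x_shape)
    b = (1,) * (ndim - len(mask_shape)) + tuple(mask_shape)
    broadcast = tuple(max(i, j) for i, j in zip(a, b))
    max_size = reduce(mul, broadcast, 1)
    return {"min_shape": [1], "max_shape": [max_size]}

print(masked_select_shape_bounds((4,), (2, 4, 1)))  # {'min_shape': [1], 'max_shape': [32]}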
@@ -96,6 +96,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimSplit, {InferImplSplit, nullptr, true}},
    {prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}},
    {prim::kPrimSort, {InferImplSort, nullptr, true}},
    {prim::kPrimMaskedSelect, {InferImplMaskedSelect, nullptr, true}},
    // Structure
    {prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}},
    {prim::kPrimMakeList, {InferImplMakeList, nullptr, true}},
@@ -212,6 +212,7 @@ inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank");
inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear");
inline const PrimitivePtr kPrimResizeGrad = std::make_shared<Primitive>("ResizeGrad");
inline const PrimitivePtr kPrimSort = std::make_shared<Primitive>("Sort");
inline const PrimitivePtr kPrimMaskedSelect = std::make_shared<Primitive>("MaskedSelect");

// NN
inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam");
@@ -1055,3 +1055,14 @@ def get_bprop_unique(self):
        return (dx,)

    return bprop


@bprop_getters.register(P.MaskedSelect)
def get_bprop_masked_select(self):
    """Generate bprop for MaskedSelect"""
    op = G.MaskedSelectGrad()

    def bprop(x, mask, out, dout):
        dx = op(x, mask, dout)
        return (dx, zeros_like(mask))

    return bprop
@@ -50,6 +50,8 @@ from .rnnt_loss import _rnnt_loss_aicpu
from .random_categorical import _random_categorical_aicpu
from .cast import _cast_aicpu
from .mirror_pad import _mirror_pad_aicpu
from .masked_select import _masked_select_aicpu
from .masked_select_grad import _masked_select_grad_aicpu
from .mirror_pad_grad import _mirror_pad_grad_aicpu
from .standard_normal import _standard_normal_aicpu
from .gamma import _gamma_aicpu
@@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""MaskedSelect op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

masked_select_op_info = AiCPURegOp("MaskedSelect") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "mask", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.I32_Default, DataType.BOOL_Default, DataType.I32_Default) \
    .dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(masked_select_op_info)
def _masked_select_aicpu():
    """MaskedSelect AiCPU register"""
    return
@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""MaskedSelectGrad op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

masked_select_grad_op_info = AiCPURegOp("MaskedSelectGrad") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "mask", "required") \
    .input(2, "grad", "required") \
    .output(0, "dx", "required") \
    .dtype_format(DataType.I32_Default, DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(masked_select_grad_op_info)
def _masked_select_grad_aicpu():
    """MaskedSelectGrad AiCPU register"""
    return
@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
                        Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
                        UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch,
                        BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
                        EmbeddingLookup, Unique, GatherD, Identity, Range)
                        EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedSelect)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                       _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
                       _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd,
@@ -2194,3 +2194,17 @@ class LRNGrad(PrimitiveWithInfer):

    def infer_shape(self, grads, x, y):
        return x


class MaskedSelectGrad(PrimitiveWithInfer):
    """Computes gradient for MaskedSelect."""

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, x, mask, grad):
        return x

    def infer_dtype(self, x, mask, grad):
        return x
@@ -27,6 +27,7 @@ import numpy as np

from mindspore import log as logger
from mindspore.common.initializer import Zero
from .._utils import get_broadcast_shape
from .._utils import get_concat_offset
from ..operations.math_ops import _infer_shape_reduce
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
@@ -5299,3 +5300,38 @@ class Range(PrimitiveWithCheck):
            delat = np.asscalar(delat_value.asnumpy())
            return Tensor(np.arange(start, limit, delat), dtype=start_value.dtype)
        return None


class MaskedSelect(PrimitiveWithCheck):
    """
    Returns a new 1-D Tensor which indexes the input tensor according to the boolean mask.

    Inputs:
        - **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
        - **mask** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        A 1-D Tensor with the same dtype as `x`. Its length equals the number of True elements
        in the broadcast `mask`.

    Raises:
        TypeError: If `x` is not a Tensor.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
        >>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
        >>> output = ops.MaskedSelect()(x, mask)
        >>> print(output)
        [1 3]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x', 'mask'], outputs=['output'])

    def check_shape(self, x_shape, mask_shape):
        get_broadcast_shape(x_shape, mask_shape, self.name)

    def check_dtype(self, x_dtype, mask_dtype):
        validator.check_tensor_dtype_valid('mask', mask_dtype, [mstype.bool_], self.name)
@@ -0,0 +1,68 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.operations import _grad_ops as G


def maskedselect():
    x = np.array([1, 2, 3, 4]).astype(np.int32)
    mask = np.array([[[0], [1], [0], [1]], [[0], [1], [0], [1]]]).astype(np.bool_)
    net = P.MaskedSelect()
    return net(Tensor(x), Tensor(mask))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maskedselect():
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU", save_graphs=True)
    y = maskedselect()
    expect = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]
    assert (y.asnumpy() == expect).all()


class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True)
        self.network = network

    def construct(self, x, mask, grad):
        gout = self.grad(self.network)(x, mask, grad)
        return gout


def masked_select_grad():
    x = np.array([1, 2, 3, 4]).astype(np.int32)
    mask = np.array([[0], [1], [0], [1]]).astype(np.bool_)
    dy = np.array([i for i in range(8)]).astype(np.int32)
    grad = G.MaskedSelectGrad()
    return grad(Tensor(x), Tensor(mask), Tensor(dy))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_masked_select_grad():
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU", save_graphs=True)
    dx = masked_select_grad()
    expect = [4, 6, 8, 10]
    assert (dx.asnumpy() == expect).all()
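Note: the expected values in the forward test above can be reproduced with plain NumPy (an illustrative check, not part of this commit):

import numpy as np

x = np.array([1, 2, 3, 4], dtype=np.int32)
mask = np.array([[[0], [1], [0], [1]], [[0], [1], [0], [1]]], dtype=bool)
# Broadcasting (4,) against (2, 4, 1) gives shape (2, 4, 4); each of the four True rows
# contributes a full copy of [1, 2, 3, 4], i.e. 16 selected elements.
x_b, mask_b = np.broadcast_arrays(x, mask)
print(x_b[mask_b])  # [1 2 3 4 1 2 3 4 1 2 3 4 1 2 3 4]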