LogicalXor

This commit is contained in:
lvxinyu 2021-09-29 12:20:41 +08:00
parent 564f6089c6
commit 28a08b3cf6
10 changed files with 136 additions and 5 deletions

View File

@ -104,6 +104,20 @@ void ArithmeticLogicCpuKernelMod<T>::LogicalOr(const T *input1, const T *input2,
BinaryOp(input1, input2, out, std::logical_or<T>());
}
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::LogicalXor(const T *input1, const T *input2, bool *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [input1, input2, out, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
out[i] = input1[iter.GetInputPosA()] != input2[iter.GetInputPosB()];
iter.GenNextPos();
}
};
CPUKernelUtils::ParallelFor(task, output_size_);
}
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::Greater(const T *input1, const T *input2, bool *out) {
BinaryOp(input1, input2, out, std::greater<T>());
@ -127,6 +141,7 @@ void ArithmeticLogicCpuKernelMod<T>::InitComputeFunc() {
{prim::kPrimLogicalAnd->name(), &ArithmeticLogicCpuKernelMod<T>::LogicalAnd},
{prim::kPrimLessEqual->name(), &ArithmeticLogicCpuKernelMod<T>::LessEqual},
{prim::kPrimLogicalOr->name(), &ArithmeticLogicCpuKernelMod<T>::LogicalOr},
{prim::kPrimLogicalXor->name(), &ArithmeticLogicCpuKernelMod<T>::LogicalXor},
{prim::kPrimLess->name(), &ArithmeticLogicCpuKernelMod<T>::Less},
{prim::kPrimNotEqual->name(), &ArithmeticLogicCpuKernelMod<T>::NotEqual},
{prim::kPrimEqual->name(), &ArithmeticLogicCpuKernelMod<T>::Equal}};

View File

@ -50,6 +50,7 @@ class ArithmeticLogicCpuKernelMod : public NativeCpuKernelMod {
void LessEqual(const T *input1, const T *input2, bool *out);
void LogicalAnd(const T *input1, const T *input2, bool *out);
void LogicalOr(const T *input1, const T *input2, bool *out);
void LogicalXor(const T *input1, const T *input2, bool *out);
using TypeComputeFunc = std::function<void(ArithmeticLogicCpuKernelMod *, const T *, const T *, bool *)>;
TypeComputeFunc compute_func_{nullptr};
@ -195,6 +196,9 @@ MS_REG_CPU_KERNEL_T(
MS_REG_CPU_KERNEL_T(
LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, bool);
MS_REG_CPU_KERNEL_T(
LogicalXor, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, bool);
} // namespace kernel
} // namespace mindspore

View File

@ -62,6 +62,7 @@ constexpr auto kACosGrad = "ACosGrad";
constexpr auto kRealDiv = "RealDiv";
constexpr auto kReciprocal = "Reciprocal";
constexpr auto kLog = "Log";
constexpr auto kLogicalXor = "LogicalXor";
constexpr auto kSelect = "Select";
constexpr auto kAdd = "Add";
constexpr auto kBiasAdd = "BiasAdd";
@ -217,6 +218,7 @@ MS_CORE_API inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive
MS_CORE_API inline const PrimitivePtr kPrimLogicalAnd = std::make_shared<Primitive>("LogicalAnd");
MS_CORE_API inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalOr");
MS_CORE_API inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");
MS_CORE_API inline const PrimitivePtr kPrimLogicalXor = std::make_shared<Primitive>(kLogicalXor);
MS_CORE_API inline const PrimitivePtr kPrimEqualCount = std::make_shared<Primitive>("EqualCount");
MS_CORE_API inline const PrimitivePtr kPrimApproximateEqual = std::make_shared<Primitive>("ApproximateEqual");

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,10 +14,39 @@
* limitations under the License.
*/
#include <map>
#include <string>
#include <set>
#include "ops/op_utils.h"
#include "ops/logical_xor.h"
namespace mindspore {
namespace ops {
REGISTER_PRIMITIVE_C(kNameLogicalXor, LogicalXor);
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
return BroadCastInferShape(op_name, input_args);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
std::map<std::string, TypePtr> types;
const std::set<TypePtr> valid_types = {kBool};
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
AbstractBasePtr LogicalXorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = InferType(primitive, input_args);
auto infer_shape = InferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(LogicalXor, prim::kPrimLogicalXor, LogicalXorInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,8 @@
#ifndef MINDSPORE_CORE_OPS_LOGICAL_XOR_H_
#define MINDSPORE_CORE_OPS_LOGICAL_XOR_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
@ -28,13 +30,17 @@ constexpr auto kNameLogicalXor = "LogicalXor";
class MS_CORE_API LogicalXor : public PrimitiveC {
public:
/// \brief Constructor.
LogicalXor() : PrimitiveC(kNameLogicalXor) {}
LogicalXor() : PrimitiveC(kNameLogicalXor) { InitIOName({"x", "y"}, {"output"}); }
/// \brief Destructor.
~LogicalXor() = default;
MS_DECLARE_PARENT(LogicalXor, PrimitiveC);
/// \brief Init. Refer to the parameters of Python API @ref mindspore.numpy.logical_xor for the inputs.
void Init() const {}
};
AbstractBasePtr LogicalXorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimLogicalXorPtr = std::shared_ptr<LogicalXor>;
} // namespace ops
} // namespace mindspore

View File

@ -38,6 +38,7 @@ from .dynamic_stitch import _dynamic_stitch_aicpu
from .get_next import _get_next_aicpu
from .print_tensor import _print_aicpu
from .topk import _top_k_aicpu
from .logical_xor import _logical_xor_aicpu
from .is_finite import _is_finite_aicpu
from .is_inf import _is_inf_aicpu
from .is_nan import _is_nan_aicpu

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LogicalXor op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
logical_xor_op_info = AiCPURegOp("LogicalXor") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.input(1, "y", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(logical_xor_op_info)
def _logical_xor_aicpu():
"""LogicalXor AiCPU register"""
return

View File

@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cdist, ReduceAny,
Cos, Cross, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod,
Ceil, Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
LogicalNot, LogicalOr, LpNorm, MatMul, Maximum, MulNoNan,
LogicalNot, LogicalOr, LogicalXor, LpNorm, MatMul, Maximum, MulNoNan,
MatrixDeterminant, LogMatrixDeterminant, Minimum, Mul, Neg, NMSWithMask, NotEqual,
NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace, Einsum,
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
@ -368,6 +368,7 @@ __all__ = [
'LogicalNot',
'LogicalAnd',
'LogicalOr',
'LogicalXor',
'Size',
'DepthwiseConv2dNative',
'UnsortedSegmentSum',

View File

@ -4124,6 +4124,44 @@ class LogicalOr(_LogicBinaryOp):
"""
class LogicalXor(Primitive):
r"""
Computes the "logical XOR" of two tensors element-wise.
.. math::
out_{i} = x_{i} \oplus y_{i}
Inputs:
- **x** (Tensor) - The first input is a tensor whose data type is bool.
- **y** (Tensor) - The second input is a the tensor to compute XOR with the first input.
Datatype must be bool.
Outputs:
Tensor, the shape is the same as the one after broadcasting, and the data type is bool.
Raises:
TypeError: If neither `x` nor `y` is a Tensor whose data type is bool.
ValueError: If the shape of two inputs cannot be broadcast.
Supported Platforms:
``CPU``
Examples:
>>> x = Tensor(np.array([True, False, True]), mindspore.bool_)
>>> y = Tensor(np.array([True, True, False]), mindspore.bool_)
>>> logical_xor = ops.LogicalXor()
>>> output = logical_xor(x, y)
>>> print(output)
[ False True True]
"""
@prim_attr_register
def __init__(self):
"""Initialize LogicalXor"""
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
class IsNan(Primitive):
r"""
Determines which elements are NaN for each position.

View File

@ -1615,6 +1615,11 @@ test_case_math_ops = [
'block': P.LogicalOr(),
'desc_inputs': [Tensor(np.zeros((3, 4, 5), np.bool_)), Tensor(np.ones((3, 1, 1), np.bool_))],
'desc_bprop': [Tensor(np.zeros((3, 4, 5), np.bool_))]}),
('LogicalXor', {
'block': P.LogicalXor(),
'desc_inputs': [Tensor(np.zeros((3, 4, 5), np.bool_)), Tensor(np.ones((3, 1, 1), np.bool_))],
'desc_bprop': [Tensor(np.zeros((3, 4, 5), np.bool_))],
'skip': ['backward']}),
('NpuAllocFloatStatus', {
'block': P.NPUAllocFloatStatus(),
'desc_inputs': [],