forked from mindspore-Ecosystem/mindspore
LogicalXor
This commit is contained in:
parent
564f6089c6
commit
28a08b3cf6
|
@ -104,6 +104,20 @@ void ArithmeticLogicCpuKernelMod<T>::LogicalOr(const T *input1, const T *input2,
|
|||
BinaryOp(input1, input2, out, std::logical_or<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::LogicalXor(const T *input1, const T *input2, bool *out) {
|
||||
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
||||
auto task = [input1, input2, out, &base_iter](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
iter.SetPos(start);
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = input1[iter.GetInputPosA()] != input2[iter.GetInputPosB()];
|
||||
iter.GenNextPos();
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, output_size_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::Greater(const T *input1, const T *input2, bool *out) {
|
||||
BinaryOp(input1, input2, out, std::greater<T>());
|
||||
|
@ -127,6 +141,7 @@ void ArithmeticLogicCpuKernelMod<T>::InitComputeFunc() {
|
|||
{prim::kPrimLogicalAnd->name(), &ArithmeticLogicCpuKernelMod<T>::LogicalAnd},
|
||||
{prim::kPrimLessEqual->name(), &ArithmeticLogicCpuKernelMod<T>::LessEqual},
|
||||
{prim::kPrimLogicalOr->name(), &ArithmeticLogicCpuKernelMod<T>::LogicalOr},
|
||||
{prim::kPrimLogicalXor->name(), &ArithmeticLogicCpuKernelMod<T>::LogicalXor},
|
||||
{prim::kPrimLess->name(), &ArithmeticLogicCpuKernelMod<T>::Less},
|
||||
{prim::kPrimNotEqual->name(), &ArithmeticLogicCpuKernelMod<T>::NotEqual},
|
||||
{prim::kPrimEqual->name(), &ArithmeticLogicCpuKernelMod<T>::Equal}};
|
||||
|
|
|
@ -50,6 +50,7 @@ class ArithmeticLogicCpuKernelMod : public NativeCpuKernelMod {
|
|||
void LessEqual(const T *input1, const T *input2, bool *out);
|
||||
void LogicalAnd(const T *input1, const T *input2, bool *out);
|
||||
void LogicalOr(const T *input1, const T *input2, bool *out);
|
||||
void LogicalXor(const T *input1, const T *input2, bool *out);
|
||||
|
||||
using TypeComputeFunc = std::function<void(ArithmeticLogicCpuKernelMod *, const T *, const T *, bool *)>;
|
||||
TypeComputeFunc compute_func_{nullptr};
|
||||
|
@ -195,6 +196,9 @@ MS_REG_CPU_KERNEL_T(
|
|||
MS_REG_CPU_KERNEL_T(
|
||||
LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, bool);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
LogicalXor, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, bool);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -62,6 +62,7 @@ constexpr auto kACosGrad = "ACosGrad";
|
|||
constexpr auto kRealDiv = "RealDiv";
|
||||
constexpr auto kReciprocal = "Reciprocal";
|
||||
constexpr auto kLog = "Log";
|
||||
constexpr auto kLogicalXor = "LogicalXor";
|
||||
constexpr auto kSelect = "Select";
|
||||
constexpr auto kAdd = "Add";
|
||||
constexpr auto kBiasAdd = "BiasAdd";
|
||||
|
@ -217,6 +218,7 @@ MS_CORE_API inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive
|
|||
MS_CORE_API inline const PrimitivePtr kPrimLogicalAnd = std::make_shared<Primitive>("LogicalAnd");
|
||||
MS_CORE_API inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalOr");
|
||||
MS_CORE_API inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");
|
||||
MS_CORE_API inline const PrimitivePtr kPrimLogicalXor = std::make_shared<Primitive>(kLogicalXor);
|
||||
MS_CORE_API inline const PrimitivePtr kPrimEqualCount = std::make_shared<Primitive>("EqualCount");
|
||||
MS_CORE_API inline const PrimitivePtr kPrimApproximateEqual = std::make_shared<Primitive>("ApproximateEqual");
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,10 +14,39 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/logical_xor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
REGISTER_PRIMITIVE_C(kNameLogicalXor, LogicalXor);
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
std::map<std::string, TypePtr> types;
|
||||
const std::set<TypePtr> valid_types = {kBool};
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr LogicalXorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LogicalXor, prim::kPrimLogicalXor, LogicalXorInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,6 +16,8 @@
|
|||
|
||||
#ifndef MINDSPORE_CORE_OPS_LOGICAL_XOR_H_
|
||||
#define MINDSPORE_CORE_OPS_LOGICAL_XOR_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
@ -28,13 +30,17 @@ constexpr auto kNameLogicalXor = "LogicalXor";
|
|||
class MS_CORE_API LogicalXor : public PrimitiveC {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
LogicalXor() : PrimitiveC(kNameLogicalXor) {}
|
||||
LogicalXor() : PrimitiveC(kNameLogicalXor) { InitIOName({"x", "y"}, {"output"}); }
|
||||
/// \brief Destructor.
|
||||
~LogicalXor() = default;
|
||||
MS_DECLARE_PARENT(LogicalXor, PrimitiveC);
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.numpy.logical_xor for the inputs.
|
||||
void Init() const {}
|
||||
};
|
||||
|
||||
AbstractBasePtr LogicalXorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimLogicalXorPtr = std::shared_ptr<LogicalXor>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -38,6 +38,7 @@ from .dynamic_stitch import _dynamic_stitch_aicpu
|
|||
from .get_next import _get_next_aicpu
|
||||
from .print_tensor import _print_aicpu
|
||||
from .topk import _top_k_aicpu
|
||||
from .logical_xor import _logical_xor_aicpu
|
||||
from .is_finite import _is_finite_aicpu
|
||||
from .is_inf import _is_inf_aicpu
|
||||
from .is_nan import _is_nan_aicpu
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""LogicalXor op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
logical_xor_op_info = AiCPURegOp("LogicalXor") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "y", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(logical_xor_op_info)
|
||||
def _logical_xor_aicpu():
|
||||
"""LogicalXor AiCPU register"""
|
||||
return
|
|
@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
|
|||
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cdist, ReduceAny,
|
||||
Cos, Cross, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod,
|
||||
Ceil, Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
|
||||
LogicalNot, LogicalOr, LpNorm, MatMul, Maximum, MulNoNan,
|
||||
LogicalNot, LogicalOr, LogicalXor, LpNorm, MatMul, Maximum, MulNoNan,
|
||||
MatrixDeterminant, LogMatrixDeterminant, Minimum, Mul, Neg, NMSWithMask, NotEqual,
|
||||
NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace, Einsum,
|
||||
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
|
||||
|
@ -368,6 +368,7 @@ __all__ = [
|
|||
'LogicalNot',
|
||||
'LogicalAnd',
|
||||
'LogicalOr',
|
||||
'LogicalXor',
|
||||
'Size',
|
||||
'DepthwiseConv2dNative',
|
||||
'UnsortedSegmentSum',
|
||||
|
|
|
@ -4124,6 +4124,44 @@ class LogicalOr(_LogicBinaryOp):
|
|||
"""
|
||||
|
||||
|
||||
class LogicalXor(Primitive):
|
||||
r"""
|
||||
Computes the "logical XOR" of two tensors element-wise.
|
||||
|
||||
.. math::
|
||||
|
||||
out_{i} = x_{i} \oplus y_{i}
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The first input is a tensor whose data type is bool.
|
||||
- **y** (Tensor) - The second input is a the tensor to compute XOR with the first input.
|
||||
Datatype must be bool.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the one after broadcasting, and the data type is bool.
|
||||
|
||||
Raises:
|
||||
TypeError: If neither `x` nor `y` is a Tensor whose data type is bool.
|
||||
ValueError: If the shape of two inputs cannot be broadcast.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([True, False, True]), mindspore.bool_)
|
||||
>>> y = Tensor(np.array([True, True, False]), mindspore.bool_)
|
||||
>>> logical_xor = ops.LogicalXor()
|
||||
>>> output = logical_xor(x, y)
|
||||
>>> print(output)
|
||||
[ False True True]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize LogicalXor"""
|
||||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||
|
||||
|
||||
class IsNan(Primitive):
|
||||
r"""
|
||||
Determines which elements are NaN for each position.
|
||||
|
|
|
@ -1615,6 +1615,11 @@ test_case_math_ops = [
|
|||
'block': P.LogicalOr(),
|
||||
'desc_inputs': [Tensor(np.zeros((3, 4, 5), np.bool_)), Tensor(np.ones((3, 1, 1), np.bool_))],
|
||||
'desc_bprop': [Tensor(np.zeros((3, 4, 5), np.bool_))]}),
|
||||
('LogicalXor', {
|
||||
'block': P.LogicalXor(),
|
||||
'desc_inputs': [Tensor(np.zeros((3, 4, 5), np.bool_)), Tensor(np.ones((3, 1, 1), np.bool_))],
|
||||
'desc_bprop': [Tensor(np.zeros((3, 4, 5), np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
('NpuAllocFloatStatus', {
|
||||
'block': P.NPUAllocFloatStatus(),
|
||||
'desc_inputs': [],
|
||||
|
|
Loading…
Reference in New Issue