diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_logic_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_logic_cpu_kernel.cc index bba56c0e15e..1c68ec61da1 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_logic_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_logic_cpu_kernel.cc @@ -123,7 +123,7 @@ class ArithLogicCpuTypeFunc : public CpuKernelFunc { void LessEqual(const T *input1, const T *input2, bool *out); void LogicalAnd(const T *input1, const T *input2, bool *out); void LogicalOr(const T *input1, const T *input2, bool *out); - void LogicalXor(const T *input1, const T *input2, bool *out) const; + void LogicalXor(const T *input1, const T *input2, bool *out); using TypeComputeFunc = std::function; TypeComputeFunc compute_func_{nullptr}; @@ -340,31 +340,8 @@ void ArithLogicCpuTypeFunc::LogicalOr(const T *input1, const T *input2, bool } template -void ArithLogicCpuTypeFunc::LogicalXor(const T *input1, const T *input2, bool *out) const { - BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); - auto task = [input1, input2, out, &base_iter](size_t start, size_t end) { - auto iter = base_iter; - iter.SetPos(start); - if constexpr (std::is_same_v) { - for (size_t i = start; i < end; i++) { - out[i] = !(common::IsFloatEqual(input1[iter.GetInputPosA()], input2[iter.GetInputPosB()])); - iter.GenNextPos(); - } - } else { - if constexpr (std::is_same_v) { - for (size_t i = start; i < end; i++) { - out[i] = !(common::IsDoubleEqual(input1[iter.GetInputPosA()], input2[iter.GetInputPosB()])); - iter.GenNextPos(); - } - } else { - for (size_t i = start; i < end; i++) { - out[i] = input1[iter.GetInputPosA()] != input2[iter.GetInputPosB()]; - iter.GenNextPos(); - } - } - } - }; - CPUKernelUtils::ParallelFor(task, output_size_); +void ArithLogicCpuTypeFunc::LogicalXor(const T *input1, const T *input2, bool *out) { + BinaryOp(input1, input2, out, std::not_equal_to()); } template