fix logical_xor ops shape broadcast bug

This commit is contained in:
w00517672 2023-02-09 09:55:43 +08:00
parent b699c4f939
commit 716e803d52
1 changed files with 3 additions and 26 deletions

View File

@ -123,7 +123,7 @@ class ArithLogicCpuTypeFunc : public CpuKernelFunc {
void LessEqual(const T *input1, const T *input2, bool *out);
void LogicalAnd(const T *input1, const T *input2, bool *out);
void LogicalOr(const T *input1, const T *input2, bool *out);
void LogicalXor(const T *input1, const T *input2, bool *out) const;
void LogicalXor(const T *input1, const T *input2, bool *out);
using TypeComputeFunc = std::function<void(ArithLogicCpuTypeFunc *, const T *, const T *, bool *)>;
TypeComputeFunc compute_func_{nullptr};
@ -340,31 +340,8 @@ void ArithLogicCpuTypeFunc<T>::LogicalOr(const T *input1, const T *input2, bool
}
template <typename T>
void ArithLogicCpuTypeFunc<T>::LogicalXor(const T *input1, const T *input2, bool *out) const {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [input1, input2, out, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
if constexpr (std::is_same_v<T, float>) {
for (size_t i = start; i < end; i++) {
out[i] = !(common::IsFloatEqual(input1[iter.GetInputPosA()], input2[iter.GetInputPosB()]));
iter.GenNextPos();
}
} else {
if constexpr (std::is_same_v<T, double>) {
for (size_t i = start; i < end; i++) {
out[i] = !(common::IsDoubleEqual(input1[iter.GetInputPosA()], input2[iter.GetInputPosB()]));
iter.GenNextPos();
}
} else {
for (size_t i = start; i < end; i++) {
out[i] = input1[iter.GetInputPosA()] != input2[iter.GetInputPosB()];
iter.GenNextPos();
}
}
}
};
CPUKernelUtils::ParallelFor(task, output_size_);
void ArithLogicCpuTypeFunc<T>::LogicalXor(const T *input1, const T *input2, bool *out) {
BinaryOp(input1, input2, out, std::not_equal_to<T>());
}
template <typename T>