forked from mindspore-Ecosystem/mindspore
fix logical_xor ops shape broadcast bug
This commit is contained in:
parent
b699c4f939
commit
716e803d52
|
@ -123,7 +123,7 @@ class ArithLogicCpuTypeFunc : public CpuKernelFunc {
|
|||
void LessEqual(const T *input1, const T *input2, bool *out);
|
||||
void LogicalAnd(const T *input1, const T *input2, bool *out);
|
||||
void LogicalOr(const T *input1, const T *input2, bool *out);
|
||||
void LogicalXor(const T *input1, const T *input2, bool *out) const;
|
||||
void LogicalXor(const T *input1, const T *input2, bool *out);
|
||||
|
||||
using TypeComputeFunc = std::function<void(ArithLogicCpuTypeFunc *, const T *, const T *, bool *)>;
|
||||
TypeComputeFunc compute_func_{nullptr};
|
||||
|
@ -340,31 +340,8 @@ void ArithLogicCpuTypeFunc<T>::LogicalOr(const T *input1, const T *input2, bool
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithLogicCpuTypeFunc<T>::LogicalXor(const T *input1, const T *input2, bool *out) const {
|
||||
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
||||
auto task = [input1, input2, out, &base_iter](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
iter.SetPos(start);
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = !(common::IsFloatEqual(input1[iter.GetInputPosA()], input2[iter.GetInputPosB()]));
|
||||
iter.GenNextPos();
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same_v<T, double>) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = !(common::IsDoubleEqual(input1[iter.GetInputPosA()], input2[iter.GetInputPosB()]));
|
||||
iter.GenNextPos();
|
||||
}
|
||||
} else {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = input1[iter.GetInputPosA()] != input2[iter.GetInputPosB()];
|
||||
iter.GenNextPos();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, output_size_);
|
||||
void ArithLogicCpuTypeFunc<T>::LogicalXor(const T *input1, const T *input2, bool *out) {
|
||||
BinaryOp(input1, input2, out, std::not_equal_to<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
Loading…
Reference in New Issue