forked from mindspore-Ecosystem/mindspore
!48503 scalar算子支持bool
Merge pull request !48503 from huoxinyou/0207scalarops
This commit is contained in:
commit
bb8d835b26
|
@ -272,18 +272,27 @@ std::vector<std::pair<KernelAttr, ScalarArithmeticCpuKernelMod::ScalarArithmetic
|
|||
ADD_KERNEL(Float32, Float64, Float64, float, double, double),
|
||||
ADD_KERNEL(Float32, Int32, Float32, float, int32_t, float),
|
||||
ADD_KERNEL(Float32, Int64, Float32, float, int64_t, float),
|
||||
ADD_KERNEL(Float32, Bool, Float32, float, bool, float),
|
||||
ADD_KERNEL(Float64, Float64, Float64, double, double, double),
|
||||
ADD_KERNEL(Float64, Float32, Float64, double, float, double),
|
||||
ADD_KERNEL(Float64, Int64, Float64, double, int64_t, double),
|
||||
ADD_KERNEL(Float64, Int32, Float64, double, int32_t, double),
|
||||
ADD_KERNEL(Float64, Bool, Float64, double, bool, double),
|
||||
ADD_KERNEL(Int32, Float32, Float32, int32_t, float, float),
|
||||
ADD_KERNEL(Int32, Float64, Float64, int32_t, double, double),
|
||||
ADD_KERNEL(Int32, Int32, Int32, int32_t, int32_t, int32_t),
|
||||
ADD_KERNEL(Int32, Int64, Int64, int32_t, int64_t, int64_t),
|
||||
ADD_KERNEL(Int32, Bool, Int64, int32_t, bool, int32_t),
|
||||
ADD_KERNEL(Int64, Float64, Float64, int64_t, double, double),
|
||||
ADD_KERNEL(Int64, Float32, Float32, int64_t, float, float),
|
||||
ADD_KERNEL(Int64, Int64, Int64, int64_t, int64_t, int64_t),
|
||||
ADD_KERNEL(Int64, Int32, Int64, int64_t, int32_t, int64_t),
|
||||
ADD_KERNEL(Int64, Bool, Int64, int64_t, bool, int64_t),
|
||||
ADD_KERNEL(Bool, Float32, Float32, bool, float, float),
|
||||
ADD_KERNEL(Bool, Float64, Float64, bool, double, double),
|
||||
ADD_KERNEL(Bool, Int32, Float32, bool, int32_t, int32_t),
|
||||
ADD_KERNEL(Bool, Int64, Float32, bool, int64_t, int32_t),
|
||||
ADD_KERNEL(Bool, Bool, Int32, bool, bool, int32_t),
|
||||
};
|
||||
|
||||
std::vector<std::pair<KernelAttr, ScalarArithmeticCpuKernelMod::ScalarArithmeticFunc>>
|
||||
|
@ -292,31 +301,55 @@ std::vector<std::pair<KernelAttr, ScalarArithmeticCpuKernelMod::ScalarArithmetic
|
|||
ADD_KERNEL(Float32, Float64, Float32, float, double, float),
|
||||
ADD_KERNEL(Float32, Int32, Float32, float, int32_t, float),
|
||||
ADD_KERNEL(Float32, Int64, Float32, float, int64_t, float),
|
||||
ADD_KERNEL(Float32, Bool, Float32, float, bool, float),
|
||||
ADD_KERNEL(Float64, Float64, Float32, double, double, float),
|
||||
ADD_KERNEL(Float64, Float32, Float32, double, float, float),
|
||||
ADD_KERNEL(Float64, Int64, Float32, double, int64_t, float),
|
||||
ADD_KERNEL(Float64, Int32, Float32, double, int32_t, float),
|
||||
ADD_KERNEL(Float64, Bool, Float32, double, bool, float),
|
||||
ADD_KERNEL(Int32, Float32, Float32, int32_t, float, float),
|
||||
ADD_KERNEL(Int32, Float64, Float32, int32_t, double, float),
|
||||
ADD_KERNEL(Int32, Int32, Float32, int32_t, int32_t, float),
|
||||
ADD_KERNEL(Int32, Int64, Float32, int32_t, int64_t, float),
|
||||
ADD_KERNEL(Int32, Bool, Float32, int32_t, bool, float),
|
||||
ADD_KERNEL(Int64, Float64, Float32, int64_t, double, float),
|
||||
ADD_KERNEL(Int64, Float32, Float32, int64_t, float, float),
|
||||
ADD_KERNEL(Int64, Int64, Float32, int64_t, int64_t, float),
|
||||
ADD_KERNEL(Int64, Int32, Float32, int64_t, int32_t, float),
|
||||
ADD_KERNEL(Int64, Bool, Float32, int64_t, bool, float),
|
||||
ADD_KERNEL(Bool, Float64, Float32, bool, double, float),
|
||||
ADD_KERNEL(Bool, Float32, Float32, bool, float, float),
|
||||
ADD_KERNEL(Bool, Int64, Float32, bool, int64_t, float),
|
||||
ADD_KERNEL(Bool, Int32, Float32, bool, int32_t, float),
|
||||
ADD_KERNEL(Bool, Bool, Float32, bool, bool, float),
|
||||
};
|
||||
|
||||
std::vector<std::pair<KernelAttr, ScalarArithmeticCpuKernelMod::ScalarArithmeticFunc>>
|
||||
ScalarArithmeticCpuKernelMod::logic_func_list_ = {
|
||||
ADD_KERNEL(Float32, Float32, Bool, float, float, bool), ADD_KERNEL(Float32, Float64, Bool, float, double, bool),
|
||||
ADD_KERNEL(Float32, Int32, Bool, float, int32_t, bool), ADD_KERNEL(Float32, Int64, Bool, float, int64_t, bool),
|
||||
ADD_KERNEL(Float64, Float64, Bool, double, double, bool), ADD_KERNEL(Float64, Float32, Bool, double, float, bool),
|
||||
ADD_KERNEL(Float64, Int64, Bool, double, int64_t, bool), ADD_KERNEL(Float64, Int32, Bool, double, int32_t, bool),
|
||||
ADD_KERNEL(Int32, Float32, Bool, int32_t, float, bool), ADD_KERNEL(Int32, Float64, Bool, int32_t, double, bool),
|
||||
ADD_KERNEL(Int32, Int32, Bool, int32_t, int32_t, bool), ADD_KERNEL(Int32, Int64, Bool, int32_t, int64_t, bool),
|
||||
ADD_KERNEL(Int64, Float64, Bool, int64_t, double, bool), ADD_KERNEL(Int64, Float32, Bool, int64_t, float, bool),
|
||||
ADD_KERNEL(Int64, Int64, Bool, int64_t, int64_t, bool), ADD_KERNEL(Int64, Int32, Bool, int64_t, int32_t, bool),
|
||||
};
|
||||
ScalarArithmeticCpuKernelMod::logic_func_list_ = {ADD_KERNEL(Float32, Float32, Bool, float, float, bool),
|
||||
ADD_KERNEL(Float32, Float64, Bool, float, double, bool),
|
||||
ADD_KERNEL(Float32, Int32, Bool, float, int32_t, bool),
|
||||
ADD_KERNEL(Float32, Int64, Bool, float, int64_t, bool),
|
||||
ADD_KERNEL(Float32, Bool, Bool, float, bool, bool),
|
||||
ADD_KERNEL(Float64, Bool, Bool, double, bool, bool),
|
||||
ADD_KERNEL(Float64, Float64, Bool, double, double, bool),
|
||||
ADD_KERNEL(Float64, Float32, Bool, double, float, bool),
|
||||
ADD_KERNEL(Float64, Int64, Bool, double, int64_t, bool),
|
||||
ADD_KERNEL(Float64, Int32, Bool, double, int32_t, bool),
|
||||
ADD_KERNEL(Int32, Float32, Bool, int32_t, float, bool),
|
||||
ADD_KERNEL(Int32, Float64, Bool, int32_t, double, bool),
|
||||
ADD_KERNEL(Int32, Int32, Bool, int32_t, int32_t, bool),
|
||||
ADD_KERNEL(Int32, Int64, Bool, int32_t, int64_t, bool),
|
||||
ADD_KERNEL(Int32, Bool, Bool, int32_t, bool, bool),
|
||||
ADD_KERNEL(Int64, Bool, Bool, int64_t, bool, bool),
|
||||
ADD_KERNEL(Int64, Float64, Bool, int64_t, double, bool),
|
||||
ADD_KERNEL(Int64, Float32, Bool, int64_t, float, bool),
|
||||
ADD_KERNEL(Int64, Int64, Bool, int64_t, int64_t, bool),
|
||||
ADD_KERNEL(Int64, Int32, Bool, int64_t, int32_t, bool),
|
||||
ADD_KERNEL(Bool, Float64, Bool, bool, double, bool),
|
||||
ADD_KERNEL(Bool, Float32, Bool, bool, float, bool),
|
||||
ADD_KERNEL(Bool, Int64, Bool, bool, int64_t, bool),
|
||||
ADD_KERNEL(Bool, Int32, Bool, bool, int32_t, bool),
|
||||
ADD_KERNEL(Bool, Bool, Bool, bool, bool, bool)};
|
||||
|
||||
std::vector<KernelAttr> ScalarArithmeticCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
|
|
|
@ -164,12 +164,14 @@ template <typename T>
|
|||
ValuePtr EqImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
|
||||
MS_EXCEPTION_IF_NULL(x_value);
|
||||
MS_EXCEPTION_IF_NULL(y_value);
|
||||
auto x = GetScalarValue<T>(op_name, x_value);
|
||||
auto y = GetScalarValue<T>(op_name, y_value);
|
||||
if (std::isinf(static_cast<double>(x)) && std::isinf(static_cast<double>(y))) {
|
||||
auto x_tmp = GetScalarValue<T>(op_name, x_value);
|
||||
auto y_tmp = GetScalarValue<T>(op_name, y_value);
|
||||
auto x = static_cast<double>(x_tmp);
|
||||
auto y = static_cast<double>(y_tmp);
|
||||
if (std::isinf(x) && std::isinf(y)) {
|
||||
return MakeValue((x > 0 && y > 0) || (x < 0 && y < 0));
|
||||
}
|
||||
double error = static_cast<double>(x) - static_cast<double>(y);
|
||||
double error = x - y;
|
||||
error = fabs(error);
|
||||
return MakeValue(error < DBL_EPSILON);
|
||||
}
|
||||
|
@ -255,7 +257,7 @@ class ScalarArithmeticInfer : public abstract::OpInferBase {
|
|||
auto prim_name = primitive->name();
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
auto y_type = input_args[kIndex1]->BuildType();
|
||||
std::set<TypePtr> check_types = {kInt32, kInt64, kFloat32, kFloat64};
|
||||
std::set<TypePtr> check_types = {kInt32, kInt64, kFloat32, kFloat64, kBool};
|
||||
std::set<std::string> compare_ops = {prim::kScalarEq, prim::kScalarGe, prim::kScalarGt, prim::kScalarLt,
|
||||
prim::kScalarLe};
|
||||
(void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_type, check_types, prim_name);
|
||||
|
@ -317,6 +319,11 @@ class ScalarArithmeticInfer : public abstract::OpInferBase {
|
|||
result = func(x_value, y_value, op_name);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeBool: {
|
||||
auto func = ChooseFunc<bool>(op_name);
|
||||
result = func(x_value, y_value, op_name);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
MS_EXCEPTION(TypeError) << "For '" << op_name
|
||||
<< "', the supported type is in the list: [int32, int64, float32, float64], but got "
|
||||
|
|
|
@ -39,6 +39,23 @@ def test_constant_scalar_div_and_mod():
|
|||
assert np.abs(ret2 - 1) < tol
|
||||
|
||||
|
||||
def test_constant_scalar_div_and_mod_bool():
|
||||
"""
|
||||
Feature: Constant scalar div and mod operation.
|
||||
Description:
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
return False/True, True%True
|
||||
|
||||
ret1, ret2 = foo()
|
||||
tol = 1e-6
|
||||
assert np.abs(ret1 - 0) < tol
|
||||
assert np.abs(ret2 - 0) < tol
|
||||
|
||||
|
||||
def test_constant_scalar_bitwise():
|
||||
"""
|
||||
Feature: Constant scalar bitwise operation.
|
||||
|
|
Loading…
Reference in New Issue