!48503 scalar算子支持bool

Merge pull request !48503 from huoxinyou/0207scalarops
This commit is contained in:
i-robot 2023-02-09 07:05:50 +00:00 committed by Gitee
commit bb8d835b26
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 72 additions and 15 deletions

View File

@ -272,18 +272,27 @@ std::vector<std::pair<KernelAttr, ScalarArithmeticCpuKernelMod::ScalarArithmetic
ADD_KERNEL(Float32, Float64, Float64, float, double, double),
ADD_KERNEL(Float32, Int32, Float32, float, int32_t, float),
ADD_KERNEL(Float32, Int64, Float32, float, int64_t, float),
ADD_KERNEL(Float32, Bool, Float32, float, bool, float),
ADD_KERNEL(Float64, Float64, Float64, double, double, double),
ADD_KERNEL(Float64, Float32, Float64, double, float, double),
ADD_KERNEL(Float64, Int64, Float64, double, int64_t, double),
ADD_KERNEL(Float64, Int32, Float64, double, int32_t, double),
ADD_KERNEL(Float64, Bool, Float64, double, bool, double),
ADD_KERNEL(Int32, Float32, Float32, int32_t, float, float),
ADD_KERNEL(Int32, Float64, Float64, int32_t, double, double),
ADD_KERNEL(Int32, Int32, Int32, int32_t, int32_t, int32_t),
ADD_KERNEL(Int32, Int64, Int64, int32_t, int64_t, int64_t),
ADD_KERNEL(Int32, Bool, Int64, int32_t, bool, int32_t),
ADD_KERNEL(Int64, Float64, Float64, int64_t, double, double),
ADD_KERNEL(Int64, Float32, Float32, int64_t, float, float),
ADD_KERNEL(Int64, Int64, Int64, int64_t, int64_t, int64_t),
ADD_KERNEL(Int64, Int32, Int64, int64_t, int32_t, int64_t),
ADD_KERNEL(Int64, Bool, Int64, int64_t, bool, int64_t),
ADD_KERNEL(Bool, Float32, Float32, bool, float, float),
ADD_KERNEL(Bool, Float64, Float64, bool, double, double),
ADD_KERNEL(Bool, Int32, Float32, bool, int32_t, int32_t),
ADD_KERNEL(Bool, Int64, Float32, bool, int64_t, int32_t),
ADD_KERNEL(Bool, Bool, Int32, bool, bool, int32_t),
};
std::vector<std::pair<KernelAttr, ScalarArithmeticCpuKernelMod::ScalarArithmeticFunc>>
@ -292,31 +301,55 @@ std::vector<std::pair<KernelAttr, ScalarArithmeticCpuKernelMod::ScalarArithmetic
ADD_KERNEL(Float32, Float64, Float32, float, double, float),
ADD_KERNEL(Float32, Int32, Float32, float, int32_t, float),
ADD_KERNEL(Float32, Int64, Float32, float, int64_t, float),
ADD_KERNEL(Float32, Bool, Float32, float, bool, float),
ADD_KERNEL(Float64, Float64, Float32, double, double, float),
ADD_KERNEL(Float64, Float32, Float32, double, float, float),
ADD_KERNEL(Float64, Int64, Float32, double, int64_t, float),
ADD_KERNEL(Float64, Int32, Float32, double, int32_t, float),
ADD_KERNEL(Float64, Bool, Float32, double, bool, float),
ADD_KERNEL(Int32, Float32, Float32, int32_t, float, float),
ADD_KERNEL(Int32, Float64, Float32, int32_t, double, float),
ADD_KERNEL(Int32, Int32, Float32, int32_t, int32_t, float),
ADD_KERNEL(Int32, Int64, Float32, int32_t, int64_t, float),
ADD_KERNEL(Int32, Bool, Float32, int32_t, bool, float),
ADD_KERNEL(Int64, Float64, Float32, int64_t, double, float),
ADD_KERNEL(Int64, Float32, Float32, int64_t, float, float),
ADD_KERNEL(Int64, Int64, Float32, int64_t, int64_t, float),
ADD_KERNEL(Int64, Int32, Float32, int64_t, int32_t, float),
ADD_KERNEL(Int64, Bool, Float32, int64_t, bool, float),
ADD_KERNEL(Bool, Float64, Float32, bool, double, float),
ADD_KERNEL(Bool, Float32, Float32, bool, float, float),
ADD_KERNEL(Bool, Int64, Float32, bool, int64_t, float),
ADD_KERNEL(Bool, Int32, Float32, bool, int32_t, float),
ADD_KERNEL(Bool, Bool, Float32, bool, bool, float),
};
std::vector<std::pair<KernelAttr, ScalarArithmeticCpuKernelMod::ScalarArithmeticFunc>>
ScalarArithmeticCpuKernelMod::logic_func_list_ = {
ADD_KERNEL(Float32, Float32, Bool, float, float, bool), ADD_KERNEL(Float32, Float64, Bool, float, double, bool),
ADD_KERNEL(Float32, Int32, Bool, float, int32_t, bool), ADD_KERNEL(Float32, Int64, Bool, float, int64_t, bool),
ADD_KERNEL(Float64, Float64, Bool, double, double, bool), ADD_KERNEL(Float64, Float32, Bool, double, float, bool),
ADD_KERNEL(Float64, Int64, Bool, double, int64_t, bool), ADD_KERNEL(Float64, Int32, Bool, double, int32_t, bool),
ADD_KERNEL(Int32, Float32, Bool, int32_t, float, bool), ADD_KERNEL(Int32, Float64, Bool, int32_t, double, bool),
ADD_KERNEL(Int32, Int32, Bool, int32_t, int32_t, bool), ADD_KERNEL(Int32, Int64, Bool, int32_t, int64_t, bool),
ADD_KERNEL(Int64, Float64, Bool, int64_t, double, bool), ADD_KERNEL(Int64, Float32, Bool, int64_t, float, bool),
ADD_KERNEL(Int64, Int64, Bool, int64_t, int64_t, bool), ADD_KERNEL(Int64, Int32, Bool, int64_t, int32_t, bool),
};
ScalarArithmeticCpuKernelMod::logic_func_list_ = {ADD_KERNEL(Float32, Float32, Bool, float, float, bool),
ADD_KERNEL(Float32, Float64, Bool, float, double, bool),
ADD_KERNEL(Float32, Int32, Bool, float, int32_t, bool),
ADD_KERNEL(Float32, Int64, Bool, float, int64_t, bool),
ADD_KERNEL(Float32, Bool, Bool, float, bool, bool),
ADD_KERNEL(Float64, Bool, Bool, double, bool, bool),
ADD_KERNEL(Float64, Float64, Bool, double, double, bool),
ADD_KERNEL(Float64, Float32, Bool, double, float, bool),
ADD_KERNEL(Float64, Int64, Bool, double, int64_t, bool),
ADD_KERNEL(Float64, Int32, Bool, double, int32_t, bool),
ADD_KERNEL(Int32, Float32, Bool, int32_t, float, bool),
ADD_KERNEL(Int32, Float64, Bool, int32_t, double, bool),
ADD_KERNEL(Int32, Int32, Bool, int32_t, int32_t, bool),
ADD_KERNEL(Int32, Int64, Bool, int32_t, int64_t, bool),
ADD_KERNEL(Int32, Bool, Bool, int32_t, bool, bool),
ADD_KERNEL(Int64, Bool, Bool, int64_t, bool, bool),
ADD_KERNEL(Int64, Float64, Bool, int64_t, double, bool),
ADD_KERNEL(Int64, Float32, Bool, int64_t, float, bool),
ADD_KERNEL(Int64, Int64, Bool, int64_t, int64_t, bool),
ADD_KERNEL(Int64, Int32, Bool, int64_t, int32_t, bool),
ADD_KERNEL(Bool, Float64, Bool, bool, double, bool),
ADD_KERNEL(Bool, Float32, Bool, bool, float, bool),
ADD_KERNEL(Bool, Int64, Bool, bool, int64_t, bool),
ADD_KERNEL(Bool, Int32, Bool, bool, int32_t, bool),
ADD_KERNEL(Bool, Bool, Bool, bool, bool, bool)};
std::vector<KernelAttr> ScalarArithmeticCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;

View File

@ -164,12 +164,14 @@ template <typename T>
ValuePtr EqImpl(const ValuePtr &x_value, const ValuePtr &y_value, const std::string &op_name) {
MS_EXCEPTION_IF_NULL(x_value);
MS_EXCEPTION_IF_NULL(y_value);
auto x = GetScalarValue<T>(op_name, x_value);
auto y = GetScalarValue<T>(op_name, y_value);
if (std::isinf(static_cast<double>(x)) && std::isinf(static_cast<double>(y))) {
auto x_tmp = GetScalarValue<T>(op_name, x_value);
auto y_tmp = GetScalarValue<T>(op_name, y_value);
auto x = static_cast<double>(x_tmp);
auto y = static_cast<double>(y_tmp);
if (std::isinf(x) && std::isinf(y)) {
return MakeValue((x > 0 && y > 0) || (x < 0 && y < 0));
}
double error = static_cast<double>(x) - static_cast<double>(y);
double error = x - y;
error = fabs(error);
return MakeValue(error < DBL_EPSILON);
}
@ -255,7 +257,7 @@ class ScalarArithmeticInfer : public abstract::OpInferBase {
auto prim_name = primitive->name();
auto x_type = input_args[0]->BuildType();
auto y_type = input_args[kIndex1]->BuildType();
std::set<TypePtr> check_types = {kInt32, kInt64, kFloat32, kFloat64};
std::set<TypePtr> check_types = {kInt32, kInt64, kFloat32, kFloat64, kBool};
std::set<std::string> compare_ops = {prim::kScalarEq, prim::kScalarGe, prim::kScalarGt, prim::kScalarLt,
prim::kScalarLe};
(void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_type, check_types, prim_name);
@ -317,6 +319,11 @@ class ScalarArithmeticInfer : public abstract::OpInferBase {
result = func(x_value, y_value, op_name);
break;
}
case kNumberTypeBool: {
auto func = ChooseFunc<bool>(op_name);
result = func(x_value, y_value, op_name);
break;
}
default: {
MS_EXCEPTION(TypeError) << "For '" << op_name
<< "', the supported type is in the list: [int32, int64, float32, float64], but got "

View File

@ -39,6 +39,23 @@ def test_constant_scalar_div_and_mod():
assert np.abs(ret2 - 1) < tol
def test_constant_scalar_div_and_mod_bool():
"""
Feature: Constant scalar div and mod operation.
Description:
Expectation: No exception.
"""
@jit
def foo():
return False/True, True%True
ret1, ret2 = foo()
tol = 1e-6
assert np.abs(ret1 - 0) < tol
assert np.abs(ret2 - 0) < tol
def test_constant_scalar_bitwise():
"""
Feature: Constant scalar bitwise operation.