!47009 remove bool exponentiation support of Pow op

Merge pull request !47009 from panshaowu/master
This commit is contained in:
i-robot 2022-12-20 11:36:32 +00:00 committed by Gitee
commit 51d3553d2a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 2 additions and 30 deletions

View File

@ -996,8 +996,6 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
SpecializeArithFunc<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
SpecializeArithFunc<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithFunc<bool>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)

View File

@ -198,23 +198,6 @@ POW_INTEGER_IMPL(int16_t)
POW_INTEGER_IMPL(int32_t)
POW_INTEGER_IMPL(int64_t)
template <>
struct PowerFunc<bool> {
__device__ __host__ __forceinline__ bool operator()(const bool &lhs, const bool &rhs) {
bool ret = true;
bool base = lhs;
bool exp = rhs;
while (exp) {
if (exp & 1) {
ret = ret && base;
}
base = base && base;
exp /= 2;
}
return ret;
}
};
template <typename T>
__device__ __host__ T abs_complex(const Complex<T> &x) {
double res = 0.0;

View File

@ -64,7 +64,7 @@ TypePtr PowInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr
std::map<std::string, TypePtr> types;
(void)types.emplace("x1", x1_type);
(void)types.emplace("x2", x2_type);
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex_and_bool, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex, prim->name());
return x1_type;
}
} // namespace

View File

@ -49,8 +49,6 @@ def test_net():
y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float64)
x3_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float64)
y3_np = np.array(3).astype(np.float64)
x4_np = np.array([[0, 0, 1], [1, 0, 1]]).astype(bool)
y4_np = np.array([[1, 0, 0], [0, 1, 1]]).astype(bool)
x0 = Tensor(x0_np)
y0 = Tensor(y0_np)
@ -60,8 +58,6 @@ def test_net():
y2 = Tensor(y2_np)
x3 = Tensor(x3_np)
y3 = Tensor(y3_np)
x4 = Tensor(x4_np)
y4 = Tensor(y4_np)
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net()
@ -85,11 +81,6 @@ def test_net():
assert np.all(out == expect)
assert out.shape == expect.shape
out = net(x4, y4).asnumpy()
expect = np.power(x4_np, y4_np)
assert np.all(out == expect)
assert out.shape == expect.shape
@pytest.mark.level0
@pytest.mark.platform_x86_cpu

View File

@ -43,7 +43,7 @@ def test_real_datatypes():
"""
real_datatypes = (np.uint8, np.uint16, np.uint32, np.uint64,
np.int8, np.int16, np.int32, np.int64,
np.float16, np.float32, np.float64, bool)
np.float16, np.float32, np.float64)
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net()
for datatype in real_datatypes: