Optimize scalar operator implementations

Use compiler built-in functions for overflow checking.
This commit is contained in:
He Wei 2021-10-28 16:52:42 +08:00
parent 91675add3e
commit e3602df91e
1 changed file with 98 additions and 134 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -48,76 +48,41 @@ DataType InferType(const AnyPtrList &list) {
return DataType::kUnknown;
}
// Reports whether x + y would fall outside the closed range [min, max].
// The bound is shifted by y and compared with x, so the sum itself is never formed.
template <typename T>
bool IsAddOverflow(const T &x, const T &y, const T &max, const T &min) {
  if (y > 0) {
    // A positive addend can only push the result past the upper bound.
    return x > (max - y);
  }
  if (y < 0) {
    // A negative addend can only push the result below the lower bound.
    return x < (min - y);
  }
  return false;
}
// Reports whether x - y would fall outside the closed range [min, max].
// Each bound is only shifted by y when the sign of y guarantees the shift itself stays in range.
template <typename T>
bool IsSubOverflow(const T &x, const T &y, const T &max, const T &min) {
  // Subtracting a negative value increases x, so only the upper bound can be crossed.
  const bool above_max = (y < 0) && (x > (max + y));
  // Subtracting a positive value decreases x, so only the lower bound can be crossed.
  const bool below_min = (y > 0) && (x < (min + y));
  return above_max || below_min;
}
// Reports whether x * y would fall outside the closed range [min, max].
//
// The product is never formed; instead one operand is compared against a bound divided by the other.
// Every division below is chosen so that it cannot itself overflow: in particular the mixed-sign case
// with x > 0 divides `min` by the positive operand x, because `min / y` with y == -1 is undefined
// behaviour for two's-complement types (and traps on common hardware).
//
// Returns false when either operand is zero.
template <typename T>
bool IsMulOverflow(const T &x, const T &y, const T &max, const T &min) {
  if (x > 0 && y > 0) {
    return x > max / y;
  }
  if (x < 0 && y < 0) {
    // max / y is safe here even for y == -1, since -max is representable.
    return x < max / y;
  }
  if (x > 0 && y < 0) {
    // Divide by the positive operand so that min / -1 is never evaluated.
    return y < min / x;
  }
  if (x < 0 && y > 0) {
    return x < min / y;
  }
  return false;
}
// Reports whether x / y hits the single overflowing quotient: the minimum value divided by -1.
// The divisor is widened to int64_t before being compared with -1.
template <typename T>
bool IsDivOverflow(const T &x, const T &y, const T &min) {
  if (x == min) {
    const auto divisor = static_cast<int64_t>(y);
    return divisor == -1;
  }
  return false;
}
// Arithmetic operation selector consumed by IsSignedIntOverflow; DIV and MOD share one overflow check there.
enum class OpType { ADD, SUB, MUL, DIV, MOD };
// Dispatches to the overflow predicate matching opType, using the numeric limits of T as bounds.
// DIV and MOD share the same predicate. Any other operation value is reported through MS_LOG(EXCEPTION).
template <typename T>
bool IsSignedIntOverflow(T x, T y, OpType opType) {
  const T upper = std::numeric_limits<T>::max();
  const T lower = std::numeric_limits<T>::min();
  switch (opType) {
    case OpType::ADD:
      return IsAddOverflow<T>(x, y, upper, lower);
    case OpType::SUB:
      return IsSubOverflow<T>(x, y, upper, lower);
    case OpType::MUL:
      return IsMulOverflow<T>(x, y, upper, lower);
    case OpType::DIV:
    case OpType::MOD:
      return IsDivOverflow<T>(x, y, lower);
    default:
      break;
  }
  MS_LOG(EXCEPTION) << "Unsupported operation type.";
}
// Returns x + y. For signed integral T, a sum that does not fit in T is reported through
// MS_LOG(EXCEPTION) instead of being returned; every other T uses the plain `x + y`.
// NOTE(review): this span appears to hold both sides of the commit without +/- markers (the
// IsSignedIntOverflow check and the __builtin_add_overflow check) — confirm against the real file.
template <typename T>
T InnerScalarAdd(T x, T y) {
if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::ADD)) {
MS_LOG(EXCEPTION) << "Overflow of the sum of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
// `if constexpr` discards this branch for other T, so the builtin is only instantiated for signed integers.
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
T res;
// The builtin writes the sum into res and returns true when it did not fit.
if (__builtin_add_overflow(x, y, &res)) {
MS_LOG(EXCEPTION) << "Overflow of the sum of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
}
return res;
}
return x + y;
}
// Returns x - y. For signed integral T, a difference that does not fit in T is reported through
// MS_LOG(EXCEPTION) instead of being returned; every other T uses the plain `x - y`.
// NOTE(review): this span appears to hold both sides of the commit without +/- markers (the
// IsSignedIntOverflow check and the __builtin_sub_overflow check) — confirm against the real file.
template <typename T>
T InnerScalarSub(T x, T y) {
if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::SUB)) {
MS_LOG(EXCEPTION) << "Overflow of the sub of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
// `if constexpr` discards this branch for other T, so the builtin is only instantiated for signed integers.
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
T res;
// The builtin writes the difference into res and returns true when it did not fit.
if (__builtin_sub_overflow(x, y, &res)) {
MS_LOG(EXCEPTION) << "Overflow of the sub of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
}
return res;
}
return x - y;
}
// Returns x * y. For signed integral T, a product that does not fit in T is reported through
// MS_LOG(EXCEPTION) instead of being returned; every other T uses the plain `x * y`.
// NOTE(review): this span appears to hold both sides of the commit without +/- markers (the
// IsSignedIntOverflow check and the __builtin_mul_overflow check) — confirm against the real file.
template <typename T>
T InnerScalarMul(T x, T y) {
if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MUL)) {
MS_LOG(EXCEPTION) << "Overflow of the mul of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
// `if constexpr` discards this branch for other T, so the builtin is only instantiated for signed integers.
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
T res;
// The builtin writes the product into res and returns true when it did not fit.
if (__builtin_mul_overflow(x, y, &res)) {
MS_LOG(EXCEPTION) << "Overflow of the mul of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
}
return res;
}
return x * y;
}
@ -127,9 +92,11 @@ float InnerScalarDiv(T x, T y) {
// A zero divisor is rejected up front for every T.
if (y == 0) {
MS_LOG(EXCEPTION) << "The divisor could not be zero.";
}
// NOTE(review): both the IsSignedIntOverflow form and the inlined `if constexpr` form of the overflow
// check are visible below — presumably the before/after sides of the commit; confirm against the real file.
if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::DIV)) {
MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
// The minimum value divided by -1 is the one signed quotient that is rejected as overflow.
if (x == std::numeric_limits<T>::min() && static_cast<int64_t>(y) == -1) {
MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
}
}
// The quotient is always computed in float, whatever T is.
return static_cast<float>(x) / static_cast<float>(y);
}
@ -137,10 +104,7 @@ float InnerScalarDiv(T x, T y) {
// Floor division: takes std::floor of InnerScalarDiv(x, y) and converts the result to T.
// Zero-divisor and overflow reporting are inherited from InnerScalarDiv.
// NOTE(review): the quotient passes through float (InnerScalarDiv's return type), so wide integer
// operands may lose precision before the floor is applied — confirm whether callers rely on exactness.
// NOTE(review): both the `is_integral` branch and the single static_cast<T> return are visible below —
// presumably the before/after sides of the commit; confirm against the real file.
template <typename T>
T InnerScalarFloordiv(T x, T y) {
auto ret = std::floor(InnerScalarDiv(x, y));
if (std::is_integral<T>::value) {
return static_cast<int64_t>(ret);
}
return ret;
return static_cast<T>(ret);
}
template <typename T>
@ -148,14 +112,16 @@ T InnerScalarMod(T x, T y) {
// A zero divisor is rejected up front for every T.
if (y == 0) {
MS_LOG(EXCEPTION) << "Could not mod to zero.";
}
// NOTE(review): the statements below mix two orderings of the same checks (one keyed on
// IsSignedIntOverflow, one on `if constexpr`) — presumably the before/after sides of the commit;
// confirm against the real file.
if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MOD)) {
MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
// Non-integral T: remainder computed as x - y * floor(x / y).
if constexpr (!std::is_integral<T>::value) {
return x - y * std::floor(x / y);
}
if (std::is_integral<T>::value) {
return static_cast<int64_t>(x) % static_cast<int64_t>(y);
if constexpr (std::is_signed<T>::value) {
// The minimum value paired with a divisor of -1 is rejected as overflow, matching the division check.
if (x == std::numeric_limits<T>::min() && static_cast<int64_t>(y) == -1) {
MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << ".";
}
}
return x - y * std::floor(x / y);
// Integral T: both operands are widened to int64_t and the built-in % is applied.
return static_cast<int64_t>(x) % static_cast<int64_t>(y);
}
template <typename T, typename U>
@ -195,68 +161,66 @@ bool InnerScalarGe(T x, U y) {
return x >= y;
}
// SCALAR_OP(op_t) defines `ValuePtr Scalar<op_t>(const ValuePtrList &list)`.
// The generated function requires at least two list elements and null-checks the first two. It then
// dispatches on their immediate types (FP64Imm / FP32Imm / Int32Imm / Int64Imm) to InnerScalar<op_t>,
// converting mixed operands to a common type first, and wraps the result with MakeValue.
// Type pairs not listed are reported through MS_LOG(EXCEPTION) together with both operands' types and values.
// NOTE(review): two copies of the macro appear back to back here — a `do { ... } while (0)` form that
// copies the ValuePtr operands and a flat form that binds them by const reference — presumably the
// before/after sides of the commit; confirm against the real file.
#define SCALAR_OP(op_t) \
ValuePtr Scalar##op_t(const ValuePtrList &list) { \
do { \
if (list.size() < 2) { \
MS_LOG(EXCEPTION) << "The length of input list for Scalar" << #op_t << " is less than 2."; \
} \
ValuePtr x = list[0]; \
ValuePtr y = list[1]; \
MS_EXCEPTION_IF_NULL(x); \
MS_EXCEPTION_IF_NULL(y); \
if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
double sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y)); \
return MakeValue(sum); \
} \
if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) { \
float sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y)); \
return MakeValue(sum); \
} \
if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) { \
int sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y)); \
return MakeValue(sum); \
} \
if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
float sum = InnerScalar##op_t(IntToFloat(GetValue<int>(x)), GetValue<float>(y)); \
return MakeValue(sum); \
} \
if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) { \
float sum = InnerScalar##op_t(GetValue<float>(x), IntToFloat(GetValue<int>(y))); \
return MakeValue(sum); \
} \
if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) { \
int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y)); \
return MakeValue(sum); \
} \
if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) { \
double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), GetValue<double>(y)); \
return MakeValue(sum); \
} \
if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) { \
double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y))); \
return MakeValue(sum); \
} \
if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y))); \
return MakeValue(sum); \
} \
if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) { \
double sum = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y))); \
return MakeValue(sum); \
} \
if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) { \
double sum = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y))); \
return MakeValue(sum); \
} \
if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) { \
int64_t sum = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y)); \
return MakeValue(sum); \
} \
MS_LOG(EXCEPTION) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name() \
<< ", value of x:" << x->ToString() << ", type of y:" << y->type_name() \
<< ", value of y:" << y->ToString(); \
} while (0); \
#define SCALAR_OP(op_t) \
ValuePtr Scalar##op_t(const ValuePtrList &list) { \
if (list.size() < 2) { \
MS_LOG(EXCEPTION) << "The length of input list for Scalar" << #op_t << " is less than 2."; \
} \
const ValuePtr &x = list[0]; \
const ValuePtr &y = list[1]; \
MS_EXCEPTION_IF_NULL(x); \
MS_EXCEPTION_IF_NULL(y); \
if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
double sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y)); \
return MakeValue(sum); \
} \
if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) { \
float sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y)); \
return MakeValue(sum); \
} \
if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) { \
int sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y)); \
return MakeValue(sum); \
} \
if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
float sum = InnerScalar##op_t(IntToFloat(GetValue<int>(x)), GetValue<float>(y)); \
return MakeValue(sum); \
} \
if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) { \
float sum = InnerScalar##op_t(GetValue<float>(x), IntToFloat(GetValue<int>(y))); \
return MakeValue(sum); \
} \
if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) { \
int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y)); \
return MakeValue(sum); \
} \
if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) { \
double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), GetValue<double>(y)); \
return MakeValue(sum); \
} \
if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) { \
double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y))); \
return MakeValue(sum); \
} \
if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y))); \
return MakeValue(sum); \
} \
if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) { \
double sum = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y))); \
return MakeValue(sum); \
} \
if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) { \
double sum = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y))); \
return MakeValue(sum); \
} \
if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) { \
int64_t sum = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y)); \
return MakeValue(sum); \
} \
MS_LOG(EXCEPTION) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name() \
<< ", value of x:" << x->ToString() << ", type of y:" << y->type_name() \
<< ", value of y:" << y->ToString(); \
}
// Expands SCALAR_OP to define ScalarAdd(const ValuePtrList &), which forwards to InnerScalarAdd.
SCALAR_OP(Add)
@ -273,8 +237,8 @@ SCALAR_OP(Floordiv)
if (list.size() < kListInputSize) { \
MS_LOG(EXCEPTION) << "The length of input list for Scalar" << #op_t << " is less than 2."; \
} \
ValuePtr x = list[0]; \
ValuePtr y = list[1]; \
const ValuePtr &x = list[0]; \
const ValuePtr &y = list[1]; \
MS_EXCEPTION_IF_NULL(x); \
MS_EXCEPTION_IF_NULL(y); \
if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \