forked from mindspore-Ecosystem/mindspore
Optimize scalar operator implementations
Use compiler built-in functions for overflow checking.
This commit is contained in:
parent
91675add3e
commit
e3602df91e
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -48,76 +48,41 @@ DataType InferType(const AnyPtrList &list) {
|
|||
return DataType::kUnknown;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool IsAddOverflow(const T &x, const T &y, const T &max, const T &min) {
|
||||
return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool IsSubOverflow(const T &x, const T &y, const T &max, const T &min) {
|
||||
return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool IsMulOverflow(const T &x, const T &y, const T &max, const T &min) {
|
||||
return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) || (x > 0 && y < 0 && (min / y) < x) ||
|
||||
(x < 0 && y > 0 && (min / y) > x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool IsDivOverflow(const T &x, const T &y, const T &min) {
|
||||
return (x == min && static_cast<int64_t>(y) == -1);
|
||||
}
|
||||
|
||||
enum class OpType { ADD, SUB, MUL, DIV, MOD };
|
||||
|
||||
template <typename T>
|
||||
bool IsSignedIntOverflow(T x, T y, OpType opType) {
|
||||
auto max = std::numeric_limits<T>::max();
|
||||
auto min = std::numeric_limits<T>::min();
|
||||
|
||||
if (opType == OpType::ADD) {
|
||||
return IsAddOverflow<T>(x, y, max, min);
|
||||
}
|
||||
|
||||
if (opType == OpType::SUB) {
|
||||
return IsSubOverflow<T>(x, y, max, min);
|
||||
}
|
||||
|
||||
if (opType == OpType::MUL) {
|
||||
return IsMulOverflow<T>(x, y, max, min);
|
||||
}
|
||||
|
||||
if (opType == OpType::DIV || opType == OpType::MOD) {
|
||||
return IsDivOverflow<T>(x, y, min);
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Unsupported operation type.";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T InnerScalarAdd(T x, T y) {
|
||||
if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::ADD)) {
|
||||
MS_LOG(EXCEPTION) << "Overflow of the sum of two signed number x: " << std::to_string(x)
|
||||
<< ", y: " << std::to_string(y) << ".";
|
||||
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
|
||||
T res;
|
||||
if (__builtin_add_overflow(x, y, &res)) {
|
||||
MS_LOG(EXCEPTION) << "Overflow of the sum of two signed number x: " << std::to_string(x)
|
||||
<< ", y: " << std::to_string(y) << ".";
|
||||
}
|
||||
return res;
|
||||
}
|
||||
return x + y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T InnerScalarSub(T x, T y) {
|
||||
if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::SUB)) {
|
||||
MS_LOG(EXCEPTION) << "Overflow of the sub of two signed number x: " << std::to_string(x)
|
||||
<< ", y: " << std::to_string(y) << ".";
|
||||
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
|
||||
T res;
|
||||
if (__builtin_sub_overflow(x, y, &res)) {
|
||||
MS_LOG(EXCEPTION) << "Overflow of the sub of two signed number x: " << std::to_string(x)
|
||||
<< ", y: " << std::to_string(y) << ".";
|
||||
}
|
||||
return res;
|
||||
}
|
||||
return x - y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T InnerScalarMul(T x, T y) {
|
||||
if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MUL)) {
|
||||
MS_LOG(EXCEPTION) << "Overflow of the mul of two signed number x: " << std::to_string(x)
|
||||
<< ", y: " << std::to_string(y) << ".";
|
||||
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
|
||||
T res;
|
||||
if (__builtin_mul_overflow(x, y, &res)) {
|
||||
MS_LOG(EXCEPTION) << "Overflow of the mul of two signed number x: " << std::to_string(x)
|
||||
<< ", y: " << std::to_string(y) << ".";
|
||||
}
|
||||
return res;
|
||||
}
|
||||
return x * y;
|
||||
}
|
||||
|
@ -127,9 +92,11 @@ float InnerScalarDiv(T x, T y) {
|
|||
if (y == 0) {
|
||||
MS_LOG(EXCEPTION) << "The divisor could not be zero.";
|
||||
}
|
||||
if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::DIV)) {
|
||||
MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x)
|
||||
<< ", y: " << std::to_string(y) << ".";
|
||||
if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
|
||||
if (x == std::numeric_limits<T>::min() && static_cast<int64_t>(y) == -1) {
|
||||
MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x)
|
||||
<< ", y: " << std::to_string(y) << ".";
|
||||
}
|
||||
}
|
||||
return static_cast<float>(x) / static_cast<float>(y);
|
||||
}
|
||||
|
@ -137,10 +104,7 @@ float InnerScalarDiv(T x, T y) {
|
|||
template <typename T>
|
||||
T InnerScalarFloordiv(T x, T y) {
|
||||
auto ret = std::floor(InnerScalarDiv(x, y));
|
||||
if (std::is_integral<T>::value) {
|
||||
return static_cast<int64_t>(ret);
|
||||
}
|
||||
return ret;
|
||||
return static_cast<T>(ret);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -148,14 +112,16 @@ T InnerScalarMod(T x, T y) {
|
|||
if (y == 0) {
|
||||
MS_LOG(EXCEPTION) << "Could not mod to zero.";
|
||||
}
|
||||
if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MOD)) {
|
||||
MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x)
|
||||
<< ", y: " << std::to_string(y) << ".";
|
||||
if constexpr (!std::is_integral<T>::value) {
|
||||
return x - y * std::floor(x / y);
|
||||
}
|
||||
if (std::is_integral<T>::value) {
|
||||
return static_cast<int64_t>(x) % static_cast<int64_t>(y);
|
||||
if constexpr (std::is_signed<T>::value) {
|
||||
if (x == std::numeric_limits<T>::min() && static_cast<int64_t>(y) == -1) {
|
||||
MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x)
|
||||
<< ", y: " << std::to_string(y) << ".";
|
||||
}
|
||||
}
|
||||
return x - y * std::floor(x / y);
|
||||
return static_cast<int64_t>(x) % static_cast<int64_t>(y);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
|
@ -195,68 +161,66 @@ bool InnerScalarGe(T x, U y) {
|
|||
return x >= y;
|
||||
}
|
||||
|
||||
#define SCALAR_OP(op_t) \
|
||||
ValuePtr Scalar##op_t(const ValuePtrList &list) { \
|
||||
do { \
|
||||
if (list.size() < 2) { \
|
||||
MS_LOG(EXCEPTION) << "The length of input list for Scalar" << #op_t << " is less than 2."; \
|
||||
} \
|
||||
ValuePtr x = list[0]; \
|
||||
ValuePtr y = list[1]; \
|
||||
MS_EXCEPTION_IF_NULL(x); \
|
||||
MS_EXCEPTION_IF_NULL(y); \
|
||||
if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
|
||||
double sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) { \
|
||||
float sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) { \
|
||||
int sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
|
||||
float sum = InnerScalar##op_t(IntToFloat(GetValue<int>(x)), GetValue<float>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) { \
|
||||
float sum = InnerScalar##op_t(GetValue<float>(x), IntToFloat(GetValue<int>(y))); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) { \
|
||||
int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) { \
|
||||
double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), GetValue<double>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) { \
|
||||
double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y))); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
|
||||
int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y))); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) { \
|
||||
double sum = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y))); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) { \
|
||||
double sum = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y))); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) { \
|
||||
int64_t sum = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
MS_LOG(EXCEPTION) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name() \
|
||||
<< ", value of x:" << x->ToString() << ", type of y:" << y->type_name() \
|
||||
<< ", value of y:" << y->ToString(); \
|
||||
} while (0); \
|
||||
#define SCALAR_OP(op_t) \
|
||||
ValuePtr Scalar##op_t(const ValuePtrList &list) { \
|
||||
if (list.size() < 2) { \
|
||||
MS_LOG(EXCEPTION) << "The length of input list for Scalar" << #op_t << " is less than 2."; \
|
||||
} \
|
||||
const ValuePtr &x = list[0]; \
|
||||
const ValuePtr &y = list[1]; \
|
||||
MS_EXCEPTION_IF_NULL(x); \
|
||||
MS_EXCEPTION_IF_NULL(y); \
|
||||
if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
|
||||
double sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) { \
|
||||
float sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) { \
|
||||
int sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
|
||||
float sum = InnerScalar##op_t(IntToFloat(GetValue<int>(x)), GetValue<float>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) { \
|
||||
float sum = InnerScalar##op_t(GetValue<float>(x), IntToFloat(GetValue<int>(y))); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) { \
|
||||
int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) { \
|
||||
double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), GetValue<double>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) { \
|
||||
double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y))); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
|
||||
int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y))); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) { \
|
||||
double sum = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y))); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) { \
|
||||
double sum = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y))); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) { \
|
||||
int64_t sum = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y)); \
|
||||
return MakeValue(sum); \
|
||||
} \
|
||||
MS_LOG(EXCEPTION) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name() \
|
||||
<< ", value of x:" << x->ToString() << ", type of y:" << y->type_name() \
|
||||
<< ", value of y:" << y->ToString(); \
|
||||
}
|
||||
|
||||
SCALAR_OP(Add)
|
||||
|
@ -273,8 +237,8 @@ SCALAR_OP(Floordiv)
|
|||
if (list.size() < kListInputSize) { \
|
||||
MS_LOG(EXCEPTION) << "The length of input list for Scalar" << #op_t << " is less than 2."; \
|
||||
} \
|
||||
ValuePtr x = list[0]; \
|
||||
ValuePtr y = list[1]; \
|
||||
const ValuePtr &x = list[0]; \
|
||||
const ValuePtr &y = list[1]; \
|
||||
MS_EXCEPTION_IF_NULL(x); \
|
||||
MS_EXCEPTION_IF_NULL(y); \
|
||||
if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
|
||||
|
|
Loading…
Reference in New Issue