!18657 support int32 scalar + int64 scalar in graph mode

Merge pull request !18657 from 杨林枫/support_int32_int64_scalar_add
This commit is contained in:
i-robot 2021-06-22 08:14:03 +00:00 committed by Gitee
commit e7ea93dacd
2 changed files with 10 additions and 0 deletions

View File

@ -237,6 +237,10 @@ bool InnerScalarGe(T x, U y) {
double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y))); \
return MakeValue(sum); \
} \
if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y))); \
return MakeValue(sum); \
} \
if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) { \
double sum = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y))); \
return MakeValue(sum); \
@ -245,6 +249,10 @@ bool InnerScalarGe(T x, U y) {
double sum = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y))); \
return MakeValue(sum); \
} \
if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) { \
int64_t sum = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y)); \
return MakeValue(sum); \
} \
MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \
<< ", y: " << y->ToString(); \
} while (0); \

View File

@ -114,6 +114,8 @@ inline int32_t LongToInt(int64_t u) {
return static_cast<int32_t>(u);
}
inline int64_t IntToLong(int32_t v) { return static_cast<int64_t>(v); }
inline int64_t UlongToLong(uint64_t u) {
if (u > static_cast<uint64_t>((std::numeric_limits<int64_t>::max)())) {
MS_LOG(EXCEPTION) << "The uint64_t value(" << u << ") exceeds the maximum value of int64_t.";