forked from mindspore-Ecosystem/mindspore
!18657 support int32 scalar + int64 scalar in graph mode
Merge pull request !18657 from 杨林枫/support_int32_int64_scalar_add
This commit is contained in:
commit
e7ea93dacd
|
@ -237,6 +237,10 @@ bool InnerScalarGe(T x, U y) {
|
||||||
double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y))); \
|
double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y))); \
|
||||||
return MakeValue(sum); \
|
return MakeValue(sum); \
|
||||||
} \
|
} \
|
||||||
|
if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
|
||||||
|
int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y))); \
|
||||||
|
return MakeValue(sum); \
|
||||||
|
} \
|
||||||
if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) { \
|
if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) { \
|
||||||
double sum = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y))); \
|
double sum = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y))); \
|
||||||
return MakeValue(sum); \
|
return MakeValue(sum); \
|
||||||
|
@ -245,6 +249,10 @@ bool InnerScalarGe(T x, U y) {
|
||||||
double sum = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y))); \
|
double sum = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y))); \
|
||||||
return MakeValue(sum); \
|
return MakeValue(sum); \
|
||||||
} \
|
} \
|
||||||
|
if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) { \
|
||||||
|
int64_t sum = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y)); \
|
||||||
|
return MakeValue(sum); \
|
||||||
|
} \
|
||||||
MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \
|
MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \
|
||||||
<< ", y: " << y->ToString(); \
|
<< ", y: " << y->ToString(); \
|
||||||
} while (0); \
|
} while (0); \
|
||||||
|
|
|
@ -114,6 +114,8 @@ inline int32_t LongToInt(int64_t u) {
|
||||||
return static_cast<int32_t>(u);
|
return static_cast<int32_t>(u);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline int64_t IntToLong(int32_t v) { return static_cast<int64_t>(v); }
|
||||||
|
|
||||||
inline int64_t UlongToLong(uint64_t u) {
|
inline int64_t UlongToLong(uint64_t u) {
|
||||||
if (u > static_cast<uint64_t>((std::numeric_limits<int64_t>::max)())) {
|
if (u > static_cast<uint64_t>((std::numeric_limits<int64_t>::max)())) {
|
||||||
MS_LOG(EXCEPTION) << "The uint64_t value(" << u << ") exceeds the maximum value of int64_t.";
|
MS_LOG(EXCEPTION) << "The uint64_t value(" << u << ") exceeds the maximum value of int64_t.";
|
||||||
|
|
Loading…
Reference in New Issue