From b0762483ce240dc807527a3957fa81d7fffc0929 Mon Sep 17 00:00:00 2001 From: yanglf1121 Date: Mon, 21 Jun 2021 15:30:46 +0800 Subject: [PATCH] support int32 add int64 scalar op --- mindspore/ccsrc/frontend/operator/cc_implementations.cc | 8 ++++++++ mindspore/core/utils/convert_utils_base.h | 2 ++ 2 files changed, 10 insertions(+) diff --git a/mindspore/ccsrc/frontend/operator/cc_implementations.cc b/mindspore/ccsrc/frontend/operator/cc_implementations.cc index 95fe71b356e..ca5a5596c1e 100644 --- a/mindspore/ccsrc/frontend/operator/cc_implementations.cc +++ b/mindspore/ccsrc/frontend/operator/cc_implementations.cc @@ -237,6 +237,10 @@ bool InnerScalarGe(T x, U y) { double sum = InnerScalar##op_t(LongToDouble(GetValue(x)), FloatToDouble(GetValue(y))); \ return MakeValue(sum); \ } \ + if (x->isa() && y->isa()) { \ + int64_t sum = InnerScalar##op_t(GetValue(x), IntToLong(GetValue(y))); \ + return MakeValue(sum); \ + } \ if (x->isa() && y->isa()) { \ double sum = InnerScalar##op_t(FloatToDouble(GetValue(x)), LongToDouble(GetValue(y))); \ return MakeValue(sum); \ @@ -245,6 +249,10 @@ bool InnerScalarGe(T x, U y) { double sum = InnerScalar##op_t(GetValue(x), LongToDouble(GetValue(y))); \ return MakeValue(sum); \ } \ + if (x->isa() && y->isa()) { \ + int64_t sum = InnerScalar##op_t(IntToLong(GetValue(x)), GetValue(y)); \ + return MakeValue(sum); \ + } \ MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \ << ", y: " << y->ToString(); \ } while (0); \ diff --git a/mindspore/core/utils/convert_utils_base.h b/mindspore/core/utils/convert_utils_base.h index bcdfb299a7f..12daef2ecb1 100644 --- a/mindspore/core/utils/convert_utils_base.h +++ b/mindspore/core/utils/convert_utils_base.h @@ -114,6 +114,8 @@ inline int32_t LongToInt(int64_t u) { return static_cast(u); } +inline int64_t IntToLong(int32_t v) { return static_cast(v); } + inline int64_t UlongToLong(uint64_t u) { if (u > static_cast((std::numeric_limits::max)())) { MS_LOG(EXCEPTION) << "The uint64_t value(" << u << ") exceeds the maximum value of int64_t.";