diff --git a/mindspore/ccsrc/operator/cc_implementations.cc b/mindspore/ccsrc/operator/cc_implementations.cc index 62b23b346fc..49dc3ab7910 100644 --- a/mindspore/ccsrc/operator/cc_implementations.cc +++ b/mindspore/ccsrc/operator/cc_implementations.cc @@ -201,6 +201,14 @@ bool InnerScalarGe(T x, U y) { int sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ return MakeValue(sum); \ } \ + if (x->isa() && y->isa()) { \ + float sum = InnerScalar##op_t(IntToFloat(GetValue(x)), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + float sum = InnerScalar##op_t(GetValue(x), IntToFloat(GetValue(y))); \ + return MakeValue(sum); \ + } \ MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \ << ", y: " << y->ToString(); \ } while (0); \ diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index 1512596cb43..b6cb57a9b87 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -445,6 +445,9 @@ AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const } ValuePtr inferred_value = RunImpl(value_list); + if (!(*inferred_value == *kAnyValue)) { + ret_value_type = inferred_value->type(); + } // for comparison primitives , return type shall have be specified to be bool. if (specify_out_type_ != nullptr) { ret_value_type = specify_out_type_; diff --git a/mindspore/ccsrc/utils/convert_utils.h b/mindspore/ccsrc/utils/convert_utils.h index fbd4485a3f6..55f478d5fe1 100644 --- a/mindspore/ccsrc/utils/convert_utils.h +++ b/mindspore/ccsrc/utils/convert_utils.h @@ -81,6 +81,7 @@ inline size_t FloatToSize(float u) { } return static_cast(u); } +inline float IntToFloat(int32_t v) { return static_cast(v); } inline uint32_t IntToUint(int32_t u) { if (u < 0) { diff --git a/tests/ut/python/ops/test_python_operators.py b/tests/ut/python/ops/test_python_operators.py index eb65a7f373e..705774068d4 100644 --- a/tests/ut/python/ops/test_python_operators.py +++ b/tests/ut/python/ops/test_python_operators.py @@ -25,11 +25,13 @@ from ....mindspore_test_framework.mindspore_test import mindspore_test from ....mindspore_test_framework.pipeline.forward.compile_forward \ import pipeline_for_compile_forward_ge_graph_for_case_by_case_config -context.set_context(mode=context.GRAPH_MODE, save_graphs=True) +context.set_context(mode=context.GRAPH_MODE) + class ComparisonOpsNet(nn.Cell): def __init__(self): super(ComparisonOpsNet, self).__init__() + def construct(self, x, y): a = x <= y b = x <= 1.0 @@ -46,22 +48,60 @@ class ComparisonOpsNet(nn.Cell): m = k != l return a or b or c or d or e or f or g or h or i or j or m + +class MathOpsNet(nn.Cell): + def __init__(self): + super(MathOpsNet, self).__init__() + self.relu = P.ReLU() + + def construct(self, x, y): + x = x - (-1) + return self.relu(x) + + +class ScalarCompareNet(nn.Cell): + def __init__(self): + super(ScalarCompareNet, self).__init__() + self.relu = P.ReLU() + + def construct(self, x, y): + t = 0 + if 3 > 3.2: + t = x + y + else: + t = x - y + if 3.1 <= 5: + t = t - x + else: + t = t + x + a = 32.0 * 12 + b = 12/3.0 + if a > b: + t = t * x + else: + t = t / x + return t + + class LogicalNumberOpsNet(nn.Cell): def __init__(self): super(LogicalNumberOpsNet, self).__init__() self.cond = True self.one = 0 self.zero = 0.0 + def construct(self, x, y): if self.cond and self.one or self.zero and not self.one: return x + y return x - y + class LogicalTensorOpsNet(nn.Cell): def __init__(self): """""" super(LogicalTensorOpsNet, self).__init__() self.const_true = Tensor(True, dtype=mstype.bool_) + def construct(self, x, y): ret = x and y and (y or self.const_true) and (not self.const_true) return ret @@ -71,20 +111,29 @@ test_case_ops = [ ('CompareOpsNet', { 'block': ComparisonOpsNet(), 'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32), - Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}), + Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}), + ('MathOpsNet', { + 'block': MathOpsNet(), + 'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32), + Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}), + ('ScalarCompareNet', { + 'block': ScalarCompareNet(), + 'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32), + Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}), ('LogicalNumberOps', { 'block': LogicalNumberOpsNet(), 'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32), - Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}), + Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}), ('LogicalTensorOps', { 'block': LogicalTensorOpsNet(), 'desc_inputs': [Tensor(np.ones([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_), - Tensor(np.zeros([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_)]}), + Tensor(np.zeros([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_)]}), ] test_case_lists = [test_case_ops] test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists) + @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) def test_compile(): - return test_exec_case \ No newline at end of file + return test_exec_case