diff --git a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc index 90b9a6a5a2d..fbdf6c4cd46 100644 --- a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc +++ b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc @@ -206,7 +206,7 @@ bool CompareShape(const std::vector &x_shape, const std::vectorsize(); std::set axis_set; auto axis_data = axis_value_ptr->value(); @@ -348,17 +348,17 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); } - // Axis can be scalar, tuple or None - AbstractTuplePtr axis = nullptr; + // Axis can be scalar, tuple or list + AbstractSequeuePtr axis = nullptr; if (args_spec_list[1]->isa()) { MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar"; AbstractBasePtrList axis_list = {dyn_cast(args_spec_list[1])}; axis = std::make_shared(axis_list); - } else if (args_spec_list[1]->isa()) { - MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple"; - axis = args_spec_list[1]->cast(); + } else if (args_spec_list[1]->isa()) { + MS_LOG(DEBUG) << op_name << " evaluator second parameter is sequeue"; + axis = args_spec_list[1]->cast(); } else { - MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got " + MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple or list, but got " << args_spec_list[1]->ToString(); } @@ -367,7 +367,7 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP MS_LOG(EXCEPTION) << op_name << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); } - auto axis_value_ptr = axis_value->cast(); + auto axis_value_ptr = axis_value->cast(); MS_EXCEPTION_IF_NULL(axis_value_ptr); return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 56388abbcdb..4af684a9c60 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -261,17 +261,19 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe auto obj = out_args[i]; if (py::isinstance(obj)) { auto arg = py::cast(obj); - if (arg->data_type() == it->second) { + TypeId arg_type_id = arg->data_type(); + if (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second) { continue; } if (signature[i].rw == SignatureEnumRW::kRWWrite) { - prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg->data_type()), + prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id), TypeIdToMsTypeStr(it->second)); } } if (!py::isinstance(obj) && !py::isinstance(obj) && !py::isinstance(obj)) { - MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i << "th input is a not support type: " + MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i + << "th input is a not support implicit conversion type: " << py::cast(obj.attr("__class__").attr("__name__")) << ", and the value is " << py::cast(obj) << "."; } diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 0a975802f6c..733eae4a2da 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -239,7 +239,8 @@ class Tensor(Tensor_): Check all array elements along a given axis evaluate to True. Args: - axis (Union[None, int, tuple(int)): Dimensions of reduction. + axis (Union[None, int, tuple(int)): Dimensions of reduction, + when axis is None or empty tuple, reduce all dimensions. Default: (), reduce all dimensions. keep_dims (bool): Whether to keep the reduced dimensions. Default : False, don't keep these reduced dimensions. @@ -257,7 +258,8 @@ class Tensor(Tensor_): Check any array element along a given axis evaluate to True. Args: - axis (Union[None, int, tuple(int)): Dimensions of reduction. + axis (Union[None, int, tuple(int)): Dimensions of reduction, + when axis is None or empty tuple, reduce all dimensions. Default: (), reduce all dimensions. keep_dims (bool): Whether to keep the reduced dimensions. Default : False, don't keep these reduced dimensions. diff --git a/mindspore/core/ir/dtype_extends.cc b/mindspore/core/ir/dtype_extends.cc index 26bdaf00ab2..5144aff490e 100644 --- a/mindspore/core/ir/dtype_extends.cc +++ b/mindspore/core/ir/dtype_extends.cc @@ -338,21 +338,21 @@ TypePtr FunctionStrToType(const std::string &type_name) { TypePtr StringToType(const std::string &type_name) { TypePtr type = nullptr; - if (type_name.compare("None") == 0) { + if (type_name == "None") { type = std::make_shared(); - } else if (type_name.compare("Ellipsis") == 0) { + } else if (type_name == "Ellipsis") { type = std::make_shared(); - } else if (type_name.compare("TypeType") == 0) { + } else if (type_name == "TypeType") { type = std::make_shared(); - } else if (type_name.compare("SymbolicKeyType") == 0) { + } else if (type_name == "SymbolicKeyType") { type = std::make_shared(); - } else if (type_name.compare("RefKeyType") == 0) { + } else if (type_name == "RefKeyType") { type = std::make_shared(); - } else if (type_name.compare("EnvType") == 0) { + } else if (type_name == "EnvType") { type = std::make_shared(); - } else if (type_name.compare("Number") == 0) { + } else if (type_name == "Number") { type = std::make_shared(); - } else if (type_name.compare("Bool") == 0) { + } else if (type_name == "Bool") { type = std::make_shared(); } else if (type_name.compare(0, strlen("Int"), "Int") == 0) { type = StringToNumberType(type_name, "Int"); @@ -372,16 +372,18 @@ TypePtr StringToType(const std::string &type_name) { type = ListStrToType(type_name); } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) { type = TupleStrToType(type_name); - } else if (type_name.compare("Slice") == 0) { + } else if (type_name == "Slice") { type = std::make_shared(); - } else if (type_name.compare("Dictionary") == 0) { + } else if (type_name == "Dictionary") { type = std::make_shared(); - } else if (type_name.compare("String") == 0) { + } else if (type_name == "String") { type = std::make_shared(); - } else if (type_name.compare("Problem") == 0) { + } else if (type_name == "Problem") { type = std::make_shared(); } else if (type_name.compare(0, strlen("Function"), "Function") == 0) { type = FunctionStrToType(type_name); + } else if (type_name == "mstype") { + type = std::make_shared(); } else { // - unsupported to convert // Class @@ -389,7 +391,6 @@ TypePtr StringToType(const std::string &type_name) { // JTagged // Anything // External - // Problem MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!"; } return type; @@ -403,10 +404,7 @@ bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type) { if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) { return false; } - if (base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type()) { - return true; - } - return false; + return base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type(); } bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc index 0e6b73201bf..010c931afa8 100644 --- a/mindspore/core/ir/func_graph_cloner.cc +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -206,8 +206,8 @@ void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNode MS_EXCEPTION_IF_NULL(func_graph); auto &old_params = func_graph->parameters(); if (old_params.size() != params.size()) { - MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; - return; + MS_EXCEPTION(TypeError) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() + << "]"; } for (size_t i = 0; i < old_params.size(); ++i) { repl_node_[old_params[i]] = params[i]; diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 88a900407ac..8bc337010a5 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -762,3 +762,10 @@ def get_stride_info_from_tuple(data_shape, index_tuple): end_strides.append(data_shape[item]) step_strides.append(1) return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis + + +@constexpr +def mstype_eq(x, y): + if x == y: + return True + return False diff --git a/mindspore/ops/composite/multitype_ops/equal_impl.py b/mindspore/ops/composite/multitype_ops/equal_impl.py index ff54c34fad3..97e647bc7c1 100644 --- a/mindspore/ops/composite/multitype_ops/equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/equal_impl.py @@ -14,7 +14,7 @@ # ============================================================================ """equal_impl""" - +from . import _constexpr_utils as const_utils from ...composite import base from ... import functional as F @@ -32,8 +32,8 @@ def _equal_scalar(x, y): Determine if two numbers are equal. Args: - x (Number): x - y (NUmber): y + x (Number): first input number. + y (NUmber): second input number. Returns: bool, if x == y return true, x != y return false. @@ -41,14 +41,29 @@ def _equal_scalar(x, y): return F.scalar_eq(x, y) +@equal.register("mstype", "mstype") +def _equal_mstype(x, y): + """ + Determine if two mindspore types are equal. + + Args: + x (mstype): first input mindspore type. + y (mstype): second input mindspore type. + + Returns: + bool, if x == y return true, x != y return false. + """ + return const_utils.mstype_eq(x, y) + + @equal.register("String", "String") def _equal_string(x, y): """ Determine if two strings are equal. Args: - x: str - y: str + x (str): first input string. + y (str): second input string. Returns: bool, if x == y return true, x != y return false. @@ -62,8 +77,8 @@ def _string_equal_none(x, y): Determine if string equals none. Args: - x: str. - y: None. + x (str): first input string. + y (None) second input None. Returns: bool, return false. @@ -77,8 +92,8 @@ def _none_equal_string(x, y): Determine if string equals none. Args: - x: None. - y: str. + x (None): first input None. + y (str): second input string. Returns: bool, return false. @@ -92,8 +107,8 @@ def _none_equal_none(x, y): Determine if none equals none. Args: - x: None. - y: None. + x (None): first input None. + y (None): second input None. Returns: bool, return true. @@ -107,8 +122,8 @@ def _scalar_equal_none(x, y): Determine if number equals none. Args: - x: Number. - y: None. + x (Number): first input number. + y (None): second input None. Returns: bool, return false. @@ -122,8 +137,8 @@ def _none_equal_scalar(x, y): Determine if number equals none. Args: - x: None. - y: NUmber. + x (None): first input None. + y (Number): second input Number. Returns: bool, return false. @@ -137,8 +152,8 @@ def _euqal_tuple(x, y): Determine if two tuples are equal by element. Args: - x (tuple): x - y (tuple): y + x (tuple): first input tuple. + y (tuple): second input tuple. Returns: bool, if x and y are equal by element return true, else return false. @@ -152,8 +167,8 @@ def _euqal_list(x, y): Determine if two lists are equal by element. Args: - x (list): x - y (list): y + x (list): first input list. + y (list): second input list. Returns: bool, if x and y are equal by element return true, else return false. @@ -167,8 +182,8 @@ def _tuple_euqal_none(x, y): Determine if tuple element equals none element. Args: - x: Tuple. - y: None. + x(tuple): first input tuple. + y (None): second input None. Returns: bool, return false. @@ -182,8 +197,8 @@ def _none_equal_tuple(x, y): Determine if tuple element equals none element. Args: - x: None. - y: Tuple. + x (None): first input None. + y (tuple): second input tuple. Returns: bool, return false. @@ -199,8 +214,8 @@ def _tensor_equal_tensor(x, y): Determine if two tensors are equal. Args: - x : Tensor. - y : Tensor. + x (Tensor): first input tensor. + y (Tensor): second input tensor. Returns: bool, if x == y return true, x != y return false. @@ -214,8 +229,8 @@ def _tensor_equal_none(x, y): Determine if tensor equal none. Args: - x : Tensor. - y : None. + x (Tensor): first input tensor. + y (None): second input None. Returns: bool, return false. @@ -229,8 +244,8 @@ def _none_equal_tensor(x, y): Determine if tensor equal none. Args: - x : None. - y : Tensor. + x (None): first input None. + y (Tensor): second input tensor. Returns: bool, return false. @@ -245,7 +260,7 @@ def _list_equal_none(x, y): Args: x (list): The first input which is a list. - y (none): The second input which is none. + y (None): The second input which is none. Returns: bool, return false. @@ -259,7 +274,7 @@ def _none_equal_list(x, y): Determine if none equal list. Args: - x (none): The first input which is none. + x (None): The first input which is none. y (list): The second input which is a list. Returns: diff --git a/tests/st/pynative/test_implicit_conversion.py b/tests/st/pynative/test_implicit_conversion.py index fce6c24cbb8..ab3d46f3b29 100644 --- a/tests/st/pynative/test_implicit_conversion.py +++ b/tests/st/pynative/test_implicit_conversion.py @@ -49,6 +49,7 @@ def test_bool_and_int_tensor_add(): ret_expect = Tensor(np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32)) assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + def test_float_tensor_and_int_tensor_add(): x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) diff --git a/tests/ut/python/pipeline/parse/test_equal_two_mstype.py b/tests/ut/python/pipeline/parse/test_equal_two_mstype.py new file mode 100644 index 00000000000..0f4a72d0eec --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_equal_two_mstype.py @@ -0,0 +1,65 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test enumerate""" +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE) + + +def test_equal_two_const_mstype(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.type_base = mstype.float32 + self.type_0 = mstype.float32 + self.type_1 = mstype.float16 + self.type_2 = mstype.int32 + self.type_3 = mstype.tuple_ + + def construct(self): + ret_0 = self.type_0 == self.type_base + ret_1 = self.type_1 == self.type_base + ret_2 = self.type_2 == self.type_base + ret_3 = self.type_3 == self.type_base + return ret_0, ret_1, ret_2, ret_3 + + net = Net() + assert net() == (True, False, False, False) + + +def test_equal_two_tensor_mstype(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x, y, z): + ret_x = x.dtype == mstype.float32 + ret_y = y.dtype == mstype.int32 + ret_z = z.dtype == mstype.bool_ + ret_xy = x.dtype == y.dtype + ret_xz = x.dtype == z.dtype + ret_yz = y.dtype == z.dtype + return ret_x, ret_y, ret_z, ret_xy, ret_xz, ret_yz + + net = Net() + x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32) + y = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.int32) + z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.bool_) + assert net(x, y, z) == (True, True, True, False, False, False) diff --git a/tests/ut/python/pynative_mode/test_implicit_conversion.py b/tests/ut/python/pynative_mode/test_implicit_conversion.py index 3a19732462c..1cf65807ddd 100644 --- a/tests/ut/python/pynative_mode/test_implicit_conversion.py +++ b/tests/ut/python/pynative_mode/test_implicit_conversion.py @@ -96,7 +96,7 @@ def test_float_tensor_and_str_add(): y = "ok" with pytest.raises(TypeError) as er: ret = x + y - assert "For 'TensorAdd', the 1th input is a not support type: str" in str(er.value) + assert "For 'TensorAdd', the 1th input is a not support implicit conversion type: str" in str(er.value) def test_float_tensor_and_tuple_add(): @@ -104,7 +104,7 @@ def test_float_tensor_and_tuple_add(): y = (1, 2, 3) with pytest.raises(TypeError) as er: ret = x + y - assert "For 'TensorAdd', the 1th input is a not support type: tuple" in str(er.value) + assert "For 'TensorAdd', the 1th input is a not support implicit conversion type: tuple" in str(er.value) def test_float_tensor_and_list_add(): @@ -112,7 +112,7 @@ def test_float_tensor_and_list_add(): y = [1, 2, 3] with pytest.raises(TypeError) as er: ret = x + y - assert "For 'TensorAdd', the 1th input is a not support type: list" in str(er.value) + assert "For 'TensorAdd', the 1th input is a not support implicit conversion type: list" in str(er.value) def test_float_tensor_and_bool_tensors_add_grad():