forked from mindspore-Ecosystem/mindspore
!4265 support two mstypes do equal.
Merge pull request !4265 from zhangbuxue/support_two_mstypes_do_equal
This commit is contained in:
commit
735fc98c78
|
@ -206,7 +206,7 @@ bool CompareShape(const std::vector<ValuePtr> &x_shape, const std::vector<ValueP
|
|||
}
|
||||
|
||||
AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
|
||||
const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) {
|
||||
const ValueSequeuePtr &axis_value_ptr, const PrimitivePtr &primitive) {
|
||||
size_t x_rank = x_shape->size();
|
||||
std::set<int> axis_set;
|
||||
auto axis_data = axis_value_ptr->value();
|
||||
|
@ -348,17 +348,17 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP
|
|||
<< " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
|
||||
}
|
||||
|
||||
// Axis can be scalar, tuple or None
|
||||
AbstractTuplePtr axis = nullptr;
|
||||
// Axis can be scalar, tuple or list
|
||||
AbstractSequeuePtr axis = nullptr;
|
||||
if (args_spec_list[1]->isa<AbstractScalar>()) {
|
||||
MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar";
|
||||
AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])};
|
||||
axis = std::make_shared<AbstractTuple>(axis_list);
|
||||
} else if (args_spec_list[1]->isa<AbstractTuple>()) {
|
||||
MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple";
|
||||
axis = args_spec_list[1]->cast<AbstractTuplePtr>();
|
||||
} else if (args_spec_list[1]->isa<AbstractSequeue>()) {
|
||||
MS_LOG(DEBUG) << op_name << " evaluator second parameter is sequeue";
|
||||
axis = args_spec_list[1]->cast<AbstractSequeuePtr>();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got "
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple or list, but got "
|
||||
<< args_spec_list[1]->ToString();
|
||||
}
|
||||
|
||||
|
@ -367,7 +367,7 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP
|
|||
MS_LOG(EXCEPTION) << op_name
|
||||
<< " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
|
||||
}
|
||||
auto axis_value_ptr = axis_value->cast<ValueTuplePtr>();
|
||||
auto axis_value_ptr = axis_value->cast<ValueSequeuePtr>();
|
||||
MS_EXCEPTION_IF_NULL(axis_value_ptr);
|
||||
|
||||
return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
|
||||
|
|
|
@ -261,17 +261,19 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe
|
|||
auto obj = out_args[i];
|
||||
if (py::isinstance<tensor::Tensor>(obj)) {
|
||||
auto arg = py::cast<tensor::TensorPtr>(obj);
|
||||
if (arg->data_type() == it->second) {
|
||||
TypeId arg_type_id = arg->data_type();
|
||||
if (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second) {
|
||||
continue;
|
||||
}
|
||||
if (signature[i].rw == SignatureEnumRW::kRWWrite) {
|
||||
prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg->data_type()),
|
||||
prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id),
|
||||
TypeIdToMsTypeStr(it->second));
|
||||
}
|
||||
}
|
||||
|
||||
if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) {
|
||||
MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i << "th input is a not support type: "
|
||||
MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i
|
||||
<< "th input is a not support implicit conversion type: "
|
||||
<< py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
|
||||
<< py::cast<py::str>(obj) << ".";
|
||||
}
|
||||
|
|
|
@ -239,7 +239,8 @@ class Tensor(Tensor_):
|
|||
Check all array elements along a given axis evaluate to True.
|
||||
|
||||
Args:
|
||||
axis (Union[None, int, tuple(int)): Dimensions of reduction.
|
||||
axis (Union[None, int, tuple(int)): Dimensions of reduction,
|
||||
when axis is None or empty tuple, reduce all dimensions.
|
||||
Default: (), reduce all dimensions.
|
||||
keep_dims (bool): Whether to keep the reduced dimensions.
|
||||
Default : False, don't keep these reduced dimensions.
|
||||
|
@ -257,7 +258,8 @@ class Tensor(Tensor_):
|
|||
Check any array element along a given axis evaluate to True.
|
||||
|
||||
Args:
|
||||
axis (Union[None, int, tuple(int)): Dimensions of reduction.
|
||||
axis (Union[None, int, tuple(int)): Dimensions of reduction,
|
||||
when axis is None or empty tuple, reduce all dimensions.
|
||||
Default: (), reduce all dimensions.
|
||||
keep_dims (bool): Whether to keep the reduced dimensions.
|
||||
Default : False, don't keep these reduced dimensions.
|
||||
|
|
|
@ -338,21 +338,21 @@ TypePtr FunctionStrToType(const std::string &type_name) {
|
|||
|
||||
TypePtr StringToType(const std::string &type_name) {
|
||||
TypePtr type = nullptr;
|
||||
if (type_name.compare("None") == 0) {
|
||||
if (type_name == "None") {
|
||||
type = std::make_shared<TypeNone>();
|
||||
} else if (type_name.compare("Ellipsis") == 0) {
|
||||
} else if (type_name == "Ellipsis") {
|
||||
type = std::make_shared<TypeEllipsis>();
|
||||
} else if (type_name.compare("TypeType") == 0) {
|
||||
} else if (type_name == "TypeType") {
|
||||
type = std::make_shared<TypeType>();
|
||||
} else if (type_name.compare("SymbolicKeyType") == 0) {
|
||||
} else if (type_name == "SymbolicKeyType") {
|
||||
type = std::make_shared<SymbolicKeyType>();
|
||||
} else if (type_name.compare("RefKeyType") == 0) {
|
||||
} else if (type_name == "RefKeyType") {
|
||||
type = std::make_shared<RefKeyType>();
|
||||
} else if (type_name.compare("EnvType") == 0) {
|
||||
} else if (type_name == "EnvType") {
|
||||
type = std::make_shared<EnvType>();
|
||||
} else if (type_name.compare("Number") == 0) {
|
||||
} else if (type_name == "Number") {
|
||||
type = std::make_shared<Number>();
|
||||
} else if (type_name.compare("Bool") == 0) {
|
||||
} else if (type_name == "Bool") {
|
||||
type = std::make_shared<Bool>();
|
||||
} else if (type_name.compare(0, strlen("Int"), "Int") == 0) {
|
||||
type = StringToNumberType<Int>(type_name, "Int");
|
||||
|
@ -372,16 +372,18 @@ TypePtr StringToType(const std::string &type_name) {
|
|||
type = ListStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
|
||||
type = TupleStrToType(type_name);
|
||||
} else if (type_name.compare("Slice") == 0) {
|
||||
} else if (type_name == "Slice") {
|
||||
type = std::make_shared<Slice>();
|
||||
} else if (type_name.compare("Dictionary") == 0) {
|
||||
} else if (type_name == "Dictionary") {
|
||||
type = std::make_shared<Dictionary>();
|
||||
} else if (type_name.compare("String") == 0) {
|
||||
} else if (type_name == "String") {
|
||||
type = std::make_shared<String>();
|
||||
} else if (type_name.compare("Problem") == 0) {
|
||||
} else if (type_name == "Problem") {
|
||||
type = std::make_shared<Problem>();
|
||||
} else if (type_name.compare(0, strlen("Function"), "Function") == 0) {
|
||||
type = FunctionStrToType(type_name);
|
||||
} else if (type_name == "mstype") {
|
||||
type = std::make_shared<TypeType>();
|
||||
} else {
|
||||
// - unsupported to convert
|
||||
// Class
|
||||
|
@ -389,7 +391,6 @@ TypePtr StringToType(const std::string &type_name) {
|
|||
// JTagged
|
||||
// Anything
|
||||
// External
|
||||
// Problem
|
||||
MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!";
|
||||
}
|
||||
return type;
|
||||
|
@ -403,10 +404,7 @@ bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type) {
|
|||
if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
|
||||
return false;
|
||||
}
|
||||
if (base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
return base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type();
|
||||
}
|
||||
|
||||
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
|
||||
|
|
|
@ -206,8 +206,8 @@ void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNode
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto &old_params = func_graph->parameters();
|
||||
if (old_params.size() != params.size()) {
|
||||
MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]";
|
||||
return;
|
||||
MS_EXCEPTION(TypeError) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size()
|
||||
<< "]";
|
||||
}
|
||||
for (size_t i = 0; i < old_params.size(); ++i) {
|
||||
repl_node_[old_params[i]] = params[i];
|
||||
|
|
|
@ -762,3 +762,10 @@ def get_stride_info_from_tuple(data_shape, index_tuple):
|
|||
end_strides.append(data_shape[item])
|
||||
step_strides.append(1)
|
||||
return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis
|
||||
|
||||
|
||||
@constexpr
|
||||
def mstype_eq(x, y):
|
||||
if x == y:
|
||||
return True
|
||||
return False
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
|
||||
"""equal_impl"""
|
||||
|
||||
from . import _constexpr_utils as const_utils
|
||||
from ...composite import base
|
||||
from ... import functional as F
|
||||
|
||||
|
@ -32,8 +32,8 @@ def _equal_scalar(x, y):
|
|||
Determine if two numbers are equal.
|
||||
|
||||
Args:
|
||||
x (Number): x
|
||||
y (NUmber): y
|
||||
x (Number): first input number.
|
||||
y (NUmber): second input number.
|
||||
|
||||
Returns:
|
||||
bool, if x == y return true, x != y return false.
|
||||
|
@ -41,14 +41,29 @@ def _equal_scalar(x, y):
|
|||
return F.scalar_eq(x, y)
|
||||
|
||||
|
||||
@equal.register("mstype", "mstype")
|
||||
def _equal_mstype(x, y):
|
||||
"""
|
||||
Determine if two mindspore types are equal.
|
||||
|
||||
Args:
|
||||
x (mstype): first input mindspore type.
|
||||
y (mstype): second input mindspore type.
|
||||
|
||||
Returns:
|
||||
bool, if x == y return true, x != y return false.
|
||||
"""
|
||||
return const_utils.mstype_eq(x, y)
|
||||
|
||||
|
||||
@equal.register("String", "String")
|
||||
def _equal_string(x, y):
|
||||
"""
|
||||
Determine if two strings are equal.
|
||||
|
||||
Args:
|
||||
x: str
|
||||
y: str
|
||||
x (str): first input string.
|
||||
y (str): second input string.
|
||||
|
||||
Returns:
|
||||
bool, if x == y return true, x != y return false.
|
||||
|
@ -62,8 +77,8 @@ def _string_equal_none(x, y):
|
|||
Determine if string equals none.
|
||||
|
||||
Args:
|
||||
x: str.
|
||||
y: None.
|
||||
x (str): first input string.
|
||||
y (None) second input None.
|
||||
|
||||
Returns:
|
||||
bool, return false.
|
||||
|
@ -77,8 +92,8 @@ def _none_equal_string(x, y):
|
|||
Determine if string equals none.
|
||||
|
||||
Args:
|
||||
x: None.
|
||||
y: str.
|
||||
x (None): first input None.
|
||||
y (str): second input string.
|
||||
|
||||
Returns:
|
||||
bool, return false.
|
||||
|
@ -92,8 +107,8 @@ def _none_equal_none(x, y):
|
|||
Determine if none equals none.
|
||||
|
||||
Args:
|
||||
x: None.
|
||||
y: None.
|
||||
x (None): first input None.
|
||||
y (None): second input None.
|
||||
|
||||
Returns:
|
||||
bool, return true.
|
||||
|
@ -107,8 +122,8 @@ def _scalar_equal_none(x, y):
|
|||
Determine if number equals none.
|
||||
|
||||
Args:
|
||||
x: Number.
|
||||
y: None.
|
||||
x (Number): first input number.
|
||||
y (None): second input None.
|
||||
|
||||
Returns:
|
||||
bool, return false.
|
||||
|
@ -122,8 +137,8 @@ def _none_equal_scalar(x, y):
|
|||
Determine if number equals none.
|
||||
|
||||
Args:
|
||||
x: None.
|
||||
y: NUmber.
|
||||
x (None): first input None.
|
||||
y (Number): second input Number.
|
||||
|
||||
Returns:
|
||||
bool, return false.
|
||||
|
@ -137,8 +152,8 @@ def _euqal_tuple(x, y):
|
|||
Determine if two tuples are equal by element.
|
||||
|
||||
Args:
|
||||
x (tuple): x
|
||||
y (tuple): y
|
||||
x (tuple): first input tuple.
|
||||
y (tuple): second input tuple.
|
||||
|
||||
Returns:
|
||||
bool, if x and y are equal by element return true, else return false.
|
||||
|
@ -152,8 +167,8 @@ def _euqal_list(x, y):
|
|||
Determine if two lists are equal by element.
|
||||
|
||||
Args:
|
||||
x (list): x
|
||||
y (list): y
|
||||
x (list): first input list.
|
||||
y (list): second input list.
|
||||
|
||||
Returns:
|
||||
bool, if x and y are equal by element return true, else return false.
|
||||
|
@ -167,8 +182,8 @@ def _tuple_euqal_none(x, y):
|
|||
Determine if tuple element equals none element.
|
||||
|
||||
Args:
|
||||
x: Tuple.
|
||||
y: None.
|
||||
x(tuple): first input tuple.
|
||||
y (None): second input None.
|
||||
|
||||
Returns:
|
||||
bool, return false.
|
||||
|
@ -182,8 +197,8 @@ def _none_equal_tuple(x, y):
|
|||
Determine if tuple element equals none element.
|
||||
|
||||
Args:
|
||||
x: None.
|
||||
y: Tuple.
|
||||
x (None): first input None.
|
||||
y (tuple): second input tuple.
|
||||
|
||||
Returns:
|
||||
bool, return false.
|
||||
|
@ -199,8 +214,8 @@ def _tensor_equal_tensor(x, y):
|
|||
Determine if two tensors are equal.
|
||||
|
||||
Args:
|
||||
x : Tensor.
|
||||
y : Tensor.
|
||||
x (Tensor): first input tensor.
|
||||
y (Tensor): second input tensor.
|
||||
|
||||
Returns:
|
||||
bool, if x == y return true, x != y return false.
|
||||
|
@ -214,8 +229,8 @@ def _tensor_equal_none(x, y):
|
|||
Determine if tensor equal none.
|
||||
|
||||
Args:
|
||||
x : Tensor.
|
||||
y : None.
|
||||
x (Tensor): first input tensor.
|
||||
y (None): second input None.
|
||||
|
||||
Returns:
|
||||
bool, return false.
|
||||
|
@ -229,8 +244,8 @@ def _none_equal_tensor(x, y):
|
|||
Determine if tensor equal none.
|
||||
|
||||
Args:
|
||||
x : None.
|
||||
y : Tensor.
|
||||
x (None): first input None.
|
||||
y (Tensor): second input tensor.
|
||||
|
||||
Returns:
|
||||
bool, return false.
|
||||
|
@ -245,7 +260,7 @@ def _list_equal_none(x, y):
|
|||
|
||||
Args:
|
||||
x (list): The first input which is a list.
|
||||
y (none): The second input which is none.
|
||||
y (None): The second input which is none.
|
||||
|
||||
Returns:
|
||||
bool, return false.
|
||||
|
@ -259,7 +274,7 @@ def _none_equal_list(x, y):
|
|||
Determine if none equal list.
|
||||
|
||||
Args:
|
||||
x (none): The first input which is none.
|
||||
x (None): The first input which is none.
|
||||
y (list): The second input which is a list.
|
||||
|
||||
Returns:
|
||||
|
|
|
@ -49,6 +49,7 @@ def test_bool_and_int_tensor_add():
|
|||
ret_expect = Tensor(np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32))
|
||||
assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all()
|
||||
|
||||
|
||||
def test_float_tensor_and_int_tensor_add():
|
||||
x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
||||
y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test enumerate"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_equal_two_const_mstype():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.type_base = mstype.float32
|
||||
self.type_0 = mstype.float32
|
||||
self.type_1 = mstype.float16
|
||||
self.type_2 = mstype.int32
|
||||
self.type_3 = mstype.tuple_
|
||||
|
||||
def construct(self):
|
||||
ret_0 = self.type_0 == self.type_base
|
||||
ret_1 = self.type_1 == self.type_base
|
||||
ret_2 = self.type_2 == self.type_base
|
||||
ret_3 = self.type_3 == self.type_base
|
||||
return ret_0, ret_1, ret_2, ret_3
|
||||
|
||||
net = Net()
|
||||
assert net() == (True, False, False, False)
|
||||
|
||||
|
||||
def test_equal_two_tensor_mstype():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
ret_x = x.dtype == mstype.float32
|
||||
ret_y = y.dtype == mstype.int32
|
||||
ret_z = z.dtype == mstype.bool_
|
||||
ret_xy = x.dtype == y.dtype
|
||||
ret_xz = x.dtype == z.dtype
|
||||
ret_yz = y.dtype == z.dtype
|
||||
return ret_x, ret_y, ret_z, ret_xy, ret_xz, ret_yz
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32)
|
||||
y = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.int32)
|
||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.bool_)
|
||||
assert net(x, y, z) == (True, True, True, False, False, False)
|
|
@ -96,7 +96,7 @@ def test_float_tensor_and_str_add():
|
|||
y = "ok"
|
||||
with pytest.raises(TypeError) as er:
|
||||
ret = x + y
|
||||
assert "For 'TensorAdd', the 1th input is a not support type: str" in str(er.value)
|
||||
assert "For 'TensorAdd', the 1th input is a not support implicit conversion type: str" in str(er.value)
|
||||
|
||||
|
||||
def test_float_tensor_and_tuple_add():
|
||||
|
@ -104,7 +104,7 @@ def test_float_tensor_and_tuple_add():
|
|||
y = (1, 2, 3)
|
||||
with pytest.raises(TypeError) as er:
|
||||
ret = x + y
|
||||
assert "For 'TensorAdd', the 1th input is a not support type: tuple" in str(er.value)
|
||||
assert "For 'TensorAdd', the 1th input is a not support implicit conversion type: tuple" in str(er.value)
|
||||
|
||||
|
||||
def test_float_tensor_and_list_add():
|
||||
|
@ -112,7 +112,7 @@ def test_float_tensor_and_list_add():
|
|||
y = [1, 2, 3]
|
||||
with pytest.raises(TypeError) as er:
|
||||
ret = x + y
|
||||
assert "For 'TensorAdd', the 1th input is a not support type: list" in str(er.value)
|
||||
assert "For 'TensorAdd', the 1th input is a not support implicit conversion type: list" in str(er.value)
|
||||
|
||||
|
||||
def test_float_tensor_and_bool_tensors_add_grad():
|
||||
|
|
Loading…
Reference in New Issue