!1804 add tensor compare & len & constexpr feature

Merge pull request !1804 from wangqiuliang/add-tensor-compare-len-consexpr-feature
This commit is contained in:
mindspore-ci-bot 2020-06-03 15:41:00 +08:00 committed by Gitee
commit 5d397d8404
5 changed files with 51 additions and 6 deletions

View File

@ -531,6 +531,11 @@ py::tuple RunOp(const py::args &args) {
value_ret[0] = output["value"];
return value_ret;
}
if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
py::tuple value_ret(1);
value_ret[0] = "";
return value_ret;
}
}
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
mindspore::parse::python_adapter::set_python_env_flag(true);

View File

@ -71,19 +71,18 @@ class Tensor(Tensor_):
return str(self.__str__())
def __add__(self, other):
check_type('tensor input_data', other, (Tensor, float, int))
out = tensor_operator_registry.get('__add__')(self, other)
return out
def __eq__(self, other):
if not isinstance(other, Tensor):
return False
return Tensor(np.array(self.asnumpy() == other.asnumpy()))
return tensor_operator_registry.get('__eq__')(self, other)
def __ne__(self, other):
if not isinstance(other, Tensor):
return True
return Tensor(np.array(self.asnumpy() != other.asnumpy()))
return tensor_operator_registry.get('__ne__')(self, other)
def __hash__(self):
return hash(id(self))
@ -93,7 +92,8 @@ class Tensor(Tensor_):
return out
def __neg__(self):
return Tensor(-self.asnumpy())
out = tensor_operator_registry.get('__neg__')(self)
return out
def __iadd__(self, other):
out = self.__add__(other)
@ -120,7 +120,7 @@ class Tensor(Tensor_):
return out
def __sub__(self, other):
out = self.__add__(-other)
out = tensor_operator_registry.get('__sub__')(self, other)
return out
def __isub__(self, other):
@ -128,9 +128,31 @@ class Tensor(Tensor_):
return out
def __rsub__(self, other):
out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
out = tensor_operator_registry.get('__sub__')(other, self)
return out
def __lt__(self, other):
out = tensor_operator_registry.get('__lt__')(self, other)
return out
def __le__(self, other):
out = tensor_operator_registry.get('__le__')(self, other)
return out
def __gt__(self, other):
out = tensor_operator_registry.get('__gt__')(self, other)
return out
def __ge__(self, other):
out = tensor_operator_registry.get('__ge__')(self, other)
return out
def __len__(self):
out = tensor_operator_registry.get('__shape__')(self)
if not out:
return 1
return out[0]
def __str__(self):
if self.dtype() == mstype.type_none:
return "Unknown Tensor type!"

View File

@ -151,7 +151,15 @@ shape_mul = Primitive("shape_mul")
stop_gradient = Primitive("stop_gradient")
tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry.register('__sub__', tensor_sub)
tensor_operator_registry.register('__mul__', tensor_mul)
tensor_operator_registry.register('__div__', tensor_div)
#ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
tensor_operator_registry.register('__neg__', neg_tensor)
tensor_operator_registry.register('__lt__', tensor_lt)
tensor_operator_registry.register('__le__', tensor_le)
tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('__shape__', shape)

View File

@ -310,6 +310,7 @@ def constexpr(fn=None, get_instance=True, name=None):
def __init__(self):
op_name = name if name else fn.__name__
PrimitiveWithInfer.__init__(self, op_name)
self.const_value = True
def infer_value(self, *args):
return fn(*args)

View File

@ -29,6 +29,7 @@ from mindspore._extends.parse.standard_method import ms_len
from mindspore.common.api import ms_function
from mindspore.common.tensor import Tensor
from mindspore.ops.composite import core
from mindspore.ops.primitive import constexpr
from ..ut_filter import non_graph_engine
@ -417,3 +418,11 @@ def test_range():
""" test_range """
res = range_spec(10, 10)
return res
def test_expr():
""" test const expr """
a = (1, 2)
@constexpr
def tuple_len(x):
assert len(x) == 2
tuple_len(a)