forked from mindspore-Ecosystem/mindspore
!1804 add tensor compare & len & constexpr feature
Merge pull request !1804 from wangqiuliang/add-tensor-compare-len-consexpr-feature
Commit: 5d397d8404
@@ -531,6 +531,11 @@ py::tuple RunOp(const py::args &args) {
       value_ret[0] = output["value"];
       return value_ret;
     }
+    if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
+      py::tuple value_ret(1);
+      value_ret[0] = "";
+      return value_ret;
+    }
   }
   MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
   mindspore::parse::python_adapter::set_python_env_flag(true);

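The new branch keys off a plain Python attribute, so its effect is visible from the Python side. A small illustrative sketch (the add_one helper is hypothetical, not part of this PR): constexpr-created primitives carry the const_value flag that RunOp now checks, per the primitive.py hunk further down.

# Illustrative sketch; add_one is a hypothetical helper, not from the PR.
from mindspore.ops.primitive import constexpr

@constexpr
def add_one(x):
    # used only to obtain a constexpr primitive instance for inspection
    return x + 1

# With the default get_instance=True, add_one is already a primitive instance,
# and its const_value attribute is what the C++ check above looks for.
print(hasattr(add_one, "const_value"))   # expected: True
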
@@ -71,19 +71,18 @@ class Tensor(Tensor_):
         return str(self.__str__())
 
     def __add__(self, other):
-        check_type('tensor input_data', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__add__')(self, other)
         return out
 
     def __eq__(self, other):
         if not isinstance(other, Tensor):
             return False
-        return Tensor(np.array(self.asnumpy() == other.asnumpy()))
+        return tensor_operator_registry.get('__eq__')(self, other)
 
     def __ne__(self, other):
         if not isinstance(other, Tensor):
             return True
-        return Tensor(np.array(self.asnumpy() != other.asnumpy()))
+        return tensor_operator_registry.get('__ne__')(self, other)
 
     def __hash__(self):
         return hash(id(self))

@@ -93,7 +92,8 @@ class Tensor(Tensor_):
         return out
 
     def __neg__(self):
-        return Tensor(-self.asnumpy())
+        out = tensor_operator_registry.get('__neg__')(self)
+        return out
 
     def __iadd__(self, other):
         out = self.__add__(other)

@@ -120,7 +120,7 @@ class Tensor(Tensor_):
         return out
 
     def __sub__(self, other):
-        out = self.__add__(-other)
+        out = tensor_operator_registry.get('__sub__')(self, other)
         return out
 
     def __isub__(self, other):

@@ -128,9 +128,31 @@ class Tensor(Tensor_):
         return out
 
     def __rsub__(self, other):
-        out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
+        out = tensor_operator_registry.get('__sub__')(other, self)
         return out
 
+    def __lt__(self, other):
+        out = tensor_operator_registry.get('__lt__')(self, other)
+        return out
+
+    def __le__(self, other):
+        out = tensor_operator_registry.get('__le__')(self, other)
+        return out
+
+    def __gt__(self, other):
+        out = tensor_operator_registry.get('__gt__')(self, other)
+        return out
+
+    def __ge__(self, other):
+        out = tensor_operator_registry.get('__ge__')(self, other)
+        return out
+
+    def __len__(self):
+        out = tensor_operator_registry.get('__shape__')(self)
+        if not out:
+            return 1
+        return out[0]
+
     def __str__(self):
         if self.dtype() == mstype.type_none:
             return "Unknown Tensor type!"

@@ -151,7 +151,15 @@ shape_mul = Primitive("shape_mul")
 stop_gradient = Primitive("stop_gradient")
 
 tensor_operator_registry.register('__add__', tensor_add)
 tensor_operator_registry.register('__sub__', tensor_sub)
 tensor_operator_registry.register('__mul__', tensor_mul)
 tensor_operator_registry.register('__div__', tensor_div)
+#ms cannot support Tensor(True) compare
+tensor_operator_registry.register('__eq__', equal)
+tensor_operator_registry.register('__ne__', not_equal)
+tensor_operator_registry.register('__neg__', neg_tensor)
+tensor_operator_registry.register('__lt__', tensor_lt)
+tensor_operator_registry.register('__le__', tensor_le)
+tensor_operator_registry.register('__gt__', tensor_gt)
+tensor_operator_registry.register('__ge__', tensor_ge)
+tensor_operator_registry.register('__shape__', shape)

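The registry indirection used in these hunks lets tensor.py refer to operators only by string key, while functional.py decides at import time which primitive backs each key, so the Tensor class avoids a direct dependency on the ops package. A simplified stand-in (not MindSpore's actual Registry class) that captures the late-binding idea:

# Simplified stand-in for the registry pattern above; this is NOT MindSpore's
# actual Registry class, just an illustration of the late-binding idea.
class SimpleRegistry:
    def __init__(self):
        self._ops = {}

    def register(self, name, fn):
        self._ops[name] = fn

    def get(self, name):
        return self._ops[name]

registry = SimpleRegistry()
registry.register('__lt__', lambda x, y: x < y)   # stands in for a real Less primitive

def tensor_lt(a, b):
    # mirrors Tensor.__lt__ in the diff: the operator is looked up at call time
    return registry.get('__lt__')(a, b)

print(tensor_lt(1, 2))   # True
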
@@ -310,6 +310,7 @@ def constexpr(fn=None, get_instance=True, name=None):
         def __init__(self):
             op_name = name if name else fn.__name__
             PrimitiveWithInfer.__init__(self, op_name)
+            self.const_value = True
 
         def infer_value(self, *args):
             return fn(*args)

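With const_value set, a constexpr primitive is evaluated through infer_value while the graph is being compiled, and its result is folded in as a constant. A hedged sketch of typical use follows; the helper names are illustrative and not taken from the PR.

# Hedged sketch of using constexpr; half_rank and scale are illustrative names.
# The decorated function runs via infer_value at compile time, so its arguments
# must be compile-time constants and its result is embedded as a constant.
import numpy as np
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops.primitive import constexpr

@constexpr
def half_rank(shape):
    # executed during compilation; shape arrives as a constant tuple
    return len(shape) // 2

@ms_function
def scale(x):
    k = half_rank((2, 3, 4, 5))   # folded to the constant 2 when the graph is built
    return x * k

# scale(Tensor(np.ones((2, 2), np.float32))) is expected to multiply by 2,
# assuming a working MindSpore backend.
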
@@ -29,6 +29,7 @@ from mindspore._extends.parse.standard_method import ms_len
 from mindspore.common.api import ms_function
 from mindspore.common.tensor import Tensor
 from mindspore.ops.composite import core
+from mindspore.ops.primitive import constexpr
 from ..ut_filter import non_graph_engine
 
 

@@ -417,3 +418,11 @@ def test_range():
     """ test_range """
     res = range_spec(10, 10)
     return res
+
+def test_expr():
+    """ test const expr """
+    a = (1, 2)
+    @constexpr
+    def tuple_len(x):
+        assert len(x) == 2
+    tuple_len(a)