From e64c755ad69a41f48d899a709e2bc196dcc34d82 Mon Sep 17 00:00:00 2001 From: kpy Date: Fri, 24 Apr 2020 17:51:55 +0800 Subject: [PATCH] change tensor equal bug --- mindspore/ccsrc/ir/meta_tensor.cc | 9 --------- mindspore/ccsrc/ir/meta_tensor.h | 3 --- mindspore/common/tensor.py | 11 +++++++++++ mindspore/ops/functional.py | 2 ++ tests/vm_impl/math_ops_vm_impl.py | 12 ++++++------ 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/mindspore/ccsrc/ir/meta_tensor.cc b/mindspore/ccsrc/ir/meta_tensor.cc index fe41abcef42..af6b4f7ffcf 100644 --- a/mindspore/ccsrc/ir/meta_tensor.cc +++ b/mindspore/ccsrc/ir/meta_tensor.cc @@ -185,14 +185,6 @@ bool Tensor::operator==(const Tensor &tensor) const { return (MetaTensor::operator==(tensor) && data_ == tensor.data_); } -bool Tensor::ValueEqualPy(const py::object &other) const { - if (!py::isinstance(other)) { - MS_LOG(WARNING) << "compare other not a tensor"; - return false; - } - return ValueEqual(py::cast(other)); -} - bool Tensor::ValueEqual(const Tensor &other) const { auto equal = [&other, this]() -> bool { auto np = py::module::import("numpy"); @@ -542,7 +534,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { )mydelimiter") .def("__str__", &Tensor::ToString) .def("__repr__", &Tensor::ToStringRepr) - .def("__eq__", &Tensor::ValueEqualPy) .def(py::pickle( [](const Tensor &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ diff --git a/mindspore/ccsrc/ir/meta_tensor.h b/mindspore/ccsrc/ir/meta_tensor.h index 1f6c866f11c..ff76a1d4f9d 100644 --- a/mindspore/ccsrc/ir/meta_tensor.h +++ b/mindspore/ccsrc/ir/meta_tensor.h @@ -329,9 +329,6 @@ class Tensor : public MetaTensor { // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. bool ValueEqual(const Tensor &other) const; - // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. - bool ValueEqualPy(const py::object &other) const; - bool operator==(const Value &other) const override { if (other.isa()) { auto other_ = static_cast(other); diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 70b8b169ca1..5504f2b4839 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -74,6 +74,17 @@ class Tensor(Tensor_): out = tensor_operator_registry.get('__add__')(self, other) return out + def __eq__(self, other): + if not isinstance(other, Tensor): + return False + x = self.asnumpy() + y = other.asnumpy() + out = np.equal(x, y) + return Tensor(np.array(out)) + + def __hash__(self): + return hash(id(self)) + def __mul__(self, other): check_type('tensor input_data', other, (Tensor, float, int)) out = tensor_operator_registry.get('__mul__')(self, other) diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 4135133e855..a2473fe7095 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -144,3 +144,5 @@ stop_gradient = Primitive("stop_gradient") tensor_operator_registry.register('__add__', tensor_add) tensor_operator_registry.register('__mul__', tensor_mul) tensor_operator_registry.register('__div__', tensor_div) +#ms cannot support Tensor(True) compare +tensor_operator_registry.register('__eq__', equal) diff --git a/tests/vm_impl/math_ops_vm_impl.py b/tests/vm_impl/math_ops_vm_impl.py index 01df0b824e2..e42ba92d5e8 100644 --- a/tests/vm_impl/math_ops_vm_impl.py +++ b/tests/vm_impl/math_ops_vm_impl.py @@ -172,7 +172,7 @@ def vm_impl_equal(self): x = x.asnumpy() y = y.asnumpy() out = vm.equal(x, y) - return Tensor(out) + return Tensor(np.array(out)) return vm_impl @@ -183,7 +183,7 @@ def vm_impl_not_equal(self): x = x.asnumpy() y = y.asnumpy() out = vm.not_equal(x, y) - return Tensor(out) + return Tensor(np.array(out)) return vm_impl @@ -194,7 +194,7 @@ def vm_impl_greater(self): x = x.asnumpy() y = y.asnumpy() out = vm.greater(x, y) - return Tensor(out) + return Tensor(np.array(out)) return vm_impl @vm_impl_getters.register(P.Maximum) @@ -219,17 +219,17 @@ def vm_impl_minimum(self): return vm_impl @vm_impl_getters.register(P.Less) -def vm_impl_greater(self): +def vm_impl_less(self): """Generate vm_impl function for Less""" def vm_impl(x, y): x = x.asnumpy() y = y.asnumpy() out = vm.less(x, y) - return Tensor(out) + return Tensor(np.array(out)) return vm_impl @vm_impl_getters.register(P.ScalarCast) -def vm_impl_greater(self): +def vm_impl_scalar_cast(self): """Generate vm_impl function for ScalarCast""" def vm_impl(x, t): np_type = dtype_to_nptype(t)