change tensor equal bug

This commit is contained in:
kpy 2020-04-24 17:51:55 +08:00
parent 00859ae119
commit e64c755ad6
5 changed files with 19 additions and 18 deletions

View File

@ -185,14 +185,6 @@ bool Tensor::operator==(const Tensor &tensor) const {
return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
}
bool Tensor::ValueEqualPy(const py::object &other) const {
if (!py::isinstance<Tensor>(other)) {
MS_LOG(WARNING) << "compare other not a tensor";
return false;
}
return ValueEqual(py::cast<Tensor>(other));
}
bool Tensor::ValueEqual(const Tensor &other) const {
auto equal = [&other, this]() -> bool {
auto np = py::module::import("numpy");
@ -542,7 +534,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
)mydelimiter")
.def("__str__", &Tensor::ToString)
.def("__repr__", &Tensor::ToStringRepr)
.def("__eq__", &Tensor::ValueEqualPy)
.def(py::pickle(
[](const Tensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */

View File

@ -329,9 +329,6 @@ class Tensor : public MetaTensor {
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool ValueEqual(const Tensor &other) const;
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool ValueEqualPy(const py::object &other) const;
bool operator==(const Value &other) const override {
if (other.isa<Tensor>()) {
auto other_ = static_cast<const Tensor &>(other);

View File

@ -74,6 +74,17 @@ class Tensor(Tensor_):
out = tensor_operator_registry.get('__add__')(self, other)
return out
def __eq__(self, other):
if not isinstance(other, Tensor):
return False
x = self.asnumpy()
y = other.asnumpy()
out = np.equal(x, y)
return Tensor(np.array(out))
def __hash__(self):
return hash(id(self))
def __mul__(self, other):
check_type('tensor input_data', other, (Tensor, float, int))
out = tensor_operator_registry.get('__mul__')(self, other)

View File

@ -144,3 +144,5 @@ stop_gradient = Primitive("stop_gradient")
tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry.register('__mul__', tensor_mul)
tensor_operator_registry.register('__div__', tensor_div)
#ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)

View File

@ -172,7 +172,7 @@ def vm_impl_equal(self):
x = x.asnumpy()
y = y.asnumpy()
out = vm.equal(x, y)
return Tensor(out)
return Tensor(np.array(out))
return vm_impl
@ -183,7 +183,7 @@ def vm_impl_not_equal(self):
x = x.asnumpy()
y = y.asnumpy()
out = vm.not_equal(x, y)
return Tensor(out)
return Tensor(np.array(out))
return vm_impl
@ -194,7 +194,7 @@ def vm_impl_greater(self):
x = x.asnumpy()
y = y.asnumpy()
out = vm.greater(x, y)
return Tensor(out)
return Tensor(np.array(out))
return vm_impl
@vm_impl_getters.register(P.Maximum)
@ -219,17 +219,17 @@ def vm_impl_minimum(self):
return vm_impl
@vm_impl_getters.register(P.Less)
def vm_impl_greater(self):
def vm_impl_less(self):
"""Generate vm_impl function for Less"""
def vm_impl(x, y):
x = x.asnumpy()
y = y.asnumpy()
out = vm.less(x, y)
return Tensor(out)
return Tensor(np.array(out))
return vm_impl
@vm_impl_getters.register(P.ScalarCast)
def vm_impl_greater(self):
def vm_impl_scalar_cast(self):
"""Generate vm_impl function for ScalarCast"""
def vm_impl(x, t):
np_type = dtype_to_nptype(t)