forked from mindspore-Ecosystem/mindspore
change tensor equal bug
This commit is contained in:
parent
00859ae119
commit
e64c755ad6
|
@ -185,14 +185,6 @@ bool Tensor::operator==(const Tensor &tensor) const {
|
|||
return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
|
||||
}
|
||||
|
||||
bool Tensor::ValueEqualPy(const py::object &other) const {
|
||||
if (!py::isinstance<Tensor>(other)) {
|
||||
MS_LOG(WARNING) << "compare other not a tensor";
|
||||
return false;
|
||||
}
|
||||
return ValueEqual(py::cast<Tensor>(other));
|
||||
}
|
||||
|
||||
bool Tensor::ValueEqual(const Tensor &other) const {
|
||||
auto equal = [&other, this]() -> bool {
|
||||
auto np = py::module::import("numpy");
|
||||
|
@ -542,7 +534,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
|||
)mydelimiter")
|
||||
.def("__str__", &Tensor::ToString)
|
||||
.def("__repr__", &Tensor::ToStringRepr)
|
||||
.def("__eq__", &Tensor::ValueEqualPy)
|
||||
.def(py::pickle(
|
||||
[](const Tensor &t) { // __getstate__
|
||||
/* Return a tuple that fully encodes the state of the object */
|
||||
|
|
|
@ -329,9 +329,6 @@ class Tensor : public MetaTensor {
|
|||
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
|
||||
bool ValueEqual(const Tensor &other) const;
|
||||
|
||||
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
|
||||
bool ValueEqualPy(const py::object &other) const;
|
||||
|
||||
bool operator==(const Value &other) const override {
|
||||
if (other.isa<Tensor>()) {
|
||||
auto other_ = static_cast<const Tensor &>(other);
|
||||
|
|
|
@ -74,6 +74,17 @@ class Tensor(Tensor_):
|
|||
out = tensor_operator_registry.get('__add__')(self, other)
|
||||
return out
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Tensor):
|
||||
return False
|
||||
x = self.asnumpy()
|
||||
y = other.asnumpy()
|
||||
out = np.equal(x, y)
|
||||
return Tensor(np.array(out))
|
||||
|
||||
def __hash__(self):
|
||||
return hash(id(self))
|
||||
|
||||
def __mul__(self, other):
|
||||
check_type('tensor input_data', other, (Tensor, float, int))
|
||||
out = tensor_operator_registry.get('__mul__')(self, other)
|
||||
|
|
|
@ -144,3 +144,5 @@ stop_gradient = Primitive("stop_gradient")
|
|||
tensor_operator_registry.register('__add__', tensor_add)
|
||||
tensor_operator_registry.register('__mul__', tensor_mul)
|
||||
tensor_operator_registry.register('__div__', tensor_div)
|
||||
#ms cannot support Tensor(True) compare
|
||||
tensor_operator_registry.register('__eq__', equal)
|
||||
|
|
|
@ -172,7 +172,7 @@ def vm_impl_equal(self):
|
|||
x = x.asnumpy()
|
||||
y = y.asnumpy()
|
||||
out = vm.equal(x, y)
|
||||
return Tensor(out)
|
||||
return Tensor(np.array(out))
|
||||
return vm_impl
|
||||
|
||||
|
||||
|
@ -183,7 +183,7 @@ def vm_impl_not_equal(self):
|
|||
x = x.asnumpy()
|
||||
y = y.asnumpy()
|
||||
out = vm.not_equal(x, y)
|
||||
return Tensor(out)
|
||||
return Tensor(np.array(out))
|
||||
return vm_impl
|
||||
|
||||
|
||||
|
@ -194,7 +194,7 @@ def vm_impl_greater(self):
|
|||
x = x.asnumpy()
|
||||
y = y.asnumpy()
|
||||
out = vm.greater(x, y)
|
||||
return Tensor(out)
|
||||
return Tensor(np.array(out))
|
||||
return vm_impl
|
||||
|
||||
@vm_impl_getters.register(P.Maximum)
|
||||
|
@ -219,17 +219,17 @@ def vm_impl_minimum(self):
|
|||
return vm_impl
|
||||
|
||||
@vm_impl_getters.register(P.Less)
|
||||
def vm_impl_greater(self):
|
||||
def vm_impl_less(self):
|
||||
"""Generate vm_impl function for Less"""
|
||||
def vm_impl(x, y):
|
||||
x = x.asnumpy()
|
||||
y = y.asnumpy()
|
||||
out = vm.less(x, y)
|
||||
return Tensor(out)
|
||||
return Tensor(np.array(out))
|
||||
return vm_impl
|
||||
|
||||
@vm_impl_getters.register(P.ScalarCast)
|
||||
def vm_impl_greater(self):
|
||||
def vm_impl_scalar_cast(self):
|
||||
"""Generate vm_impl function for ScalarCast"""
|
||||
def vm_impl(x, t):
|
||||
np_type = dtype_to_nptype(t)
|
||||
|
|
Loading…
Reference in New Issue