change tensor equal bug
This commit is contained in:
parent
00859ae119
commit
e64c755ad6
|
@ -185,14 +185,6 @@ bool Tensor::operator==(const Tensor &tensor) const {
|
||||||
return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
|
return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Tensor::ValueEqualPy(const py::object &other) const {
|
|
||||||
if (!py::isinstance<Tensor>(other)) {
|
|
||||||
MS_LOG(WARNING) << "compare other not a tensor";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return ValueEqual(py::cast<Tensor>(other));
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Tensor::ValueEqual(const Tensor &other) const {
|
bool Tensor::ValueEqual(const Tensor &other) const {
|
||||||
auto equal = [&other, this]() -> bool {
|
auto equal = [&other, this]() -> bool {
|
||||||
auto np = py::module::import("numpy");
|
auto np = py::module::import("numpy");
|
||||||
|
@ -542,7 +534,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
||||||
)mydelimiter")
|
)mydelimiter")
|
||||||
.def("__str__", &Tensor::ToString)
|
.def("__str__", &Tensor::ToString)
|
||||||
.def("__repr__", &Tensor::ToStringRepr)
|
.def("__repr__", &Tensor::ToStringRepr)
|
||||||
.def("__eq__", &Tensor::ValueEqualPy)
|
|
||||||
.def(py::pickle(
|
.def(py::pickle(
|
||||||
[](const Tensor &t) { // __getstate__
|
[](const Tensor &t) { // __getstate__
|
||||||
/* Return a tuple that fully encodes the state of the object */
|
/* Return a tuple that fully encodes the state of the object */
|
||||||
|
|
|
@ -329,9 +329,6 @@ class Tensor : public MetaTensor {
|
||||||
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
|
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
|
||||||
bool ValueEqual(const Tensor &other) const;
|
bool ValueEqual(const Tensor &other) const;
|
||||||
|
|
||||||
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
|
|
||||||
bool ValueEqualPy(const py::object &other) const;
|
|
||||||
|
|
||||||
bool operator==(const Value &other) const override {
|
bool operator==(const Value &other) const override {
|
||||||
if (other.isa<Tensor>()) {
|
if (other.isa<Tensor>()) {
|
||||||
auto other_ = static_cast<const Tensor &>(other);
|
auto other_ = static_cast<const Tensor &>(other);
|
||||||
|
|
|
@ -74,6 +74,17 @@ class Tensor(Tensor_):
|
||||||
out = tensor_operator_registry.get('__add__')(self, other)
|
out = tensor_operator_registry.get('__add__')(self, other)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, Tensor):
|
||||||
|
return False
|
||||||
|
x = self.asnumpy()
|
||||||
|
y = other.asnumpy()
|
||||||
|
out = np.equal(x, y)
|
||||||
|
return Tensor(np.array(out))
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(id(self))
|
||||||
|
|
||||||
def __mul__(self, other):
|
def __mul__(self, other):
|
||||||
check_type('tensor input_data', other, (Tensor, float, int))
|
check_type('tensor input_data', other, (Tensor, float, int))
|
||||||
out = tensor_operator_registry.get('__mul__')(self, other)
|
out = tensor_operator_registry.get('__mul__')(self, other)
|
||||||
|
|
|
@ -144,3 +144,5 @@ stop_gradient = Primitive("stop_gradient")
|
||||||
tensor_operator_registry.register('__add__', tensor_add)
|
tensor_operator_registry.register('__add__', tensor_add)
|
||||||
tensor_operator_registry.register('__mul__', tensor_mul)
|
tensor_operator_registry.register('__mul__', tensor_mul)
|
||||||
tensor_operator_registry.register('__div__', tensor_div)
|
tensor_operator_registry.register('__div__', tensor_div)
|
||||||
|
#ms cannot support Tensor(True) compare
|
||||||
|
tensor_operator_registry.register('__eq__', equal)
|
||||||
|
|
|
@ -172,7 +172,7 @@ def vm_impl_equal(self):
|
||||||
x = x.asnumpy()
|
x = x.asnumpy()
|
||||||
y = y.asnumpy()
|
y = y.asnumpy()
|
||||||
out = vm.equal(x, y)
|
out = vm.equal(x, y)
|
||||||
return Tensor(out)
|
return Tensor(np.array(out))
|
||||||
return vm_impl
|
return vm_impl
|
||||||
|
|
||||||
|
|
||||||
|
@ -183,7 +183,7 @@ def vm_impl_not_equal(self):
|
||||||
x = x.asnumpy()
|
x = x.asnumpy()
|
||||||
y = y.asnumpy()
|
y = y.asnumpy()
|
||||||
out = vm.not_equal(x, y)
|
out = vm.not_equal(x, y)
|
||||||
return Tensor(out)
|
return Tensor(np.array(out))
|
||||||
return vm_impl
|
return vm_impl
|
||||||
|
|
||||||
|
|
||||||
|
@ -194,7 +194,7 @@ def vm_impl_greater(self):
|
||||||
x = x.asnumpy()
|
x = x.asnumpy()
|
||||||
y = y.asnumpy()
|
y = y.asnumpy()
|
||||||
out = vm.greater(x, y)
|
out = vm.greater(x, y)
|
||||||
return Tensor(out)
|
return Tensor(np.array(out))
|
||||||
return vm_impl
|
return vm_impl
|
||||||
|
|
||||||
@vm_impl_getters.register(P.Maximum)
|
@vm_impl_getters.register(P.Maximum)
|
||||||
|
@ -219,17 +219,17 @@ def vm_impl_minimum(self):
|
||||||
return vm_impl
|
return vm_impl
|
||||||
|
|
||||||
@vm_impl_getters.register(P.Less)
|
@vm_impl_getters.register(P.Less)
|
||||||
def vm_impl_greater(self):
|
def vm_impl_less(self):
|
||||||
"""Generate vm_impl function for Less"""
|
"""Generate vm_impl function for Less"""
|
||||||
def vm_impl(x, y):
|
def vm_impl(x, y):
|
||||||
x = x.asnumpy()
|
x = x.asnumpy()
|
||||||
y = y.asnumpy()
|
y = y.asnumpy()
|
||||||
out = vm.less(x, y)
|
out = vm.less(x, y)
|
||||||
return Tensor(out)
|
return Tensor(np.array(out))
|
||||||
return vm_impl
|
return vm_impl
|
||||||
|
|
||||||
@vm_impl_getters.register(P.ScalarCast)
|
@vm_impl_getters.register(P.ScalarCast)
|
||||||
def vm_impl_greater(self):
|
def vm_impl_scalar_cast(self):
|
||||||
"""Generate vm_impl function for ScalarCast"""
|
"""Generate vm_impl function for ScalarCast"""
|
||||||
def vm_impl(x, t):
|
def vm_impl(x, t):
|
||||||
np_type = dtype_to_nptype(t)
|
np_type = dtype_to_nptype(t)
|
||||||
|
|
Loading…
Reference in New Issue