support bool tensor and bool do equal

This commit is contained in:
buxue 2020-08-12 18:12:33 +08:00
parent 363fbb7a79
commit 9c7cdcf1e9
1 changed files with 3 additions and 1 deletions

View File

@ -85,7 +85,9 @@ class Tensor(Tensor_):
return False
# bool type is not supported for `Equal` operator in backend.
if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
return Tensor(np.array(self.asnumpy() == other.asnumpy()))
if isinstance(other, Tensor):
return Tensor(np.array(self.asnumpy() == other.asnumpy()))
return Tensor(np.array(self.asnumpy() == other))
return tensor_operator_registry.get('__eq__')(self, other)
def __ne__(self, other):