Fix bug of mindspore dtype equal

This commit is contained in:
fary86 2020-07-16 04:26:44 +08:00
parent c99cc0dfa1
commit 595767b4b5
3 changed files with 15 additions and 3 deletions

View File

@ -36,8 +36,12 @@ REGISTER_PYBIND_DEFINE(
(void)m_sub.def("str_to_type", &StringToType, "string to typeptr");
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
.def("__eq__",
[](const TypePtr &t1, const TypePtr &t2) {
.def("__eq__",
[](const TypePtr &t1, const py::object &other) {
if (!py::isinstance<Type>(other)) {
return false;
}
auto t2 = py::cast<TypePtr>(other);
if (t1 != nullptr && t2 != nullptr) {
return *t1 == *t2;
}

View File

@ -134,3 +134,11 @@ def test_dtype():
with pytest.raises(NotImplementedError):
x = 1.5
dtype.get_py_obj_dtype(type(type(x)))
def test_type_equal():
t1 = (dtype.int32, dtype.int32)
valid_types = [dtype.float16, dtype.float32]
assert t1 not in valid_types
assert dtype.int32 not in valid_types
assert dtype.float32 in valid_types

View File

@ -971,7 +971,7 @@ raise_error_set = [
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsTypeError', {
'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}),
'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)],