diff --git a/mindspore/core/ir/dtype_py.cc b/mindspore/core/ir/dtype_py.cc index 66bd8ba5f6f..b1e2151b6dd 100644 --- a/mindspore/core/ir/dtype_py.cc +++ b/mindspore/core/ir/dtype_py.cc @@ -36,8 +36,12 @@ REGISTER_PYBIND_DEFINE( (void)m_sub.def("str_to_type", &StringToType, "string to typeptr"); (void)py::class_>(m_sub, "Type") .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) - .def("__eq__", - [](const TypePtr &t1, const TypePtr &t2) { + .def("__eq__", + [](const TypePtr &t1, const py::object &other) { + if (!py::isinstance(other)) { + return false; + } + auto t2 = py::cast(other); if (t1 != nullptr && t2 != nullptr) { return *t1 == *t2; } diff --git a/tests/ut/python/ir/test_dtype.py b/tests/ut/python/ir/test_dtype.py index 1523a77ea39..49f834092e0 100644 --- a/tests/ut/python/ir/test_dtype.py +++ b/tests/ut/python/ir/test_dtype.py @@ -134,3 +134,11 @@ def test_dtype(): with pytest.raises(NotImplementedError): x = 1.5 dtype.get_py_obj_dtype(type(type(x))) + + +def test_type_equal(): + t1 = (dtype.int32, dtype.int32) + valid_types = [dtype.float16, dtype.float32] + assert t1 not in valid_types + assert dtype.int32 not in valid_types + assert dtype.float32 in valid_types diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 66590945da7..ec8b9957a22 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -971,7 +971,7 @@ raise_error_set = [ Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)], }), ('TensorGetItemByMixedTensorsTypeError', { - 'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}), + 'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': IndexError}), 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)],