forked from mindspore-Ecosystem/mindspore
Fix bug of mindspore dtype equal
This commit is contained in:
parent
c99cc0dfa1
commit
595767b4b5
|
@ -36,8 +36,12 @@ REGISTER_PYBIND_DEFINE(
|
|||
(void)m_sub.def("str_to_type", &StringToType, "string to typeptr");
|
||||
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
|
||||
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
|
||||
.def("__eq__",
|
||||
[](const TypePtr &t1, const TypePtr &t2) {
|
||||
.def("__eq__",
|
||||
[](const TypePtr &t1, const py::object &other) {
|
||||
if (!py::isinstance<Type>(other)) {
|
||||
return false;
|
||||
}
|
||||
auto t2 = py::cast<TypePtr>(other);
|
||||
if (t1 != nullptr && t2 != nullptr) {
|
||||
return *t1 == *t2;
|
||||
}
|
||||
|
|
|
@ -134,3 +134,11 @@ def test_dtype():
|
|||
with pytest.raises(NotImplementedError):
|
||||
x = 1.5
|
||||
dtype.get_py_obj_dtype(type(type(x)))
|
||||
|
||||
|
||||
def test_type_equal():
|
||||
t1 = (dtype.int32, dtype.int32)
|
||||
valid_types = [dtype.float16, dtype.float32]
|
||||
assert t1 not in valid_types
|
||||
assert dtype.int32 not in valid_types
|
||||
assert dtype.float32 in valid_types
|
||||
|
|
|
@ -971,7 +971,7 @@ raise_error_set = [
|
|||
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsTypeError', {
|
||||
'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}),
|
||||
'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)],
|
||||
|
|
Loading…
Reference in New Issue