forked from mindspore-Ecosystem/mindspore
Fix bug of mindspore dtype equal
This commit is contained in:
parent
c99cc0dfa1
commit
595767b4b5
|
@ -36,8 +36,12 @@ REGISTER_PYBIND_DEFINE(
|
||||||
(void)m_sub.def("str_to_type", &StringToType, "string to typeptr");
|
(void)m_sub.def("str_to_type", &StringToType, "string to typeptr");
|
||||||
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
|
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
|
||||||
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
|
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
|
||||||
.def("__eq__",
|
.def("__eq__",
|
||||||
[](const TypePtr &t1, const TypePtr &t2) {
|
[](const TypePtr &t1, const py::object &other) {
|
||||||
|
if (!py::isinstance<Type>(other)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto t2 = py::cast<TypePtr>(other);
|
||||||
if (t1 != nullptr && t2 != nullptr) {
|
if (t1 != nullptr && t2 != nullptr) {
|
||||||
return *t1 == *t2;
|
return *t1 == *t2;
|
||||||
}
|
}
|
||||||
|
|
|
@ -134,3 +134,11 @@ def test_dtype():
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
x = 1.5
|
x = 1.5
|
||||||
dtype.get_py_obj_dtype(type(type(x)))
|
dtype.get_py_obj_dtype(type(type(x)))
|
||||||
|
|
||||||
|
|
||||||
|
def test_type_equal():
|
||||||
|
t1 = (dtype.int32, dtype.int32)
|
||||||
|
valid_types = [dtype.float16, dtype.float32]
|
||||||
|
assert t1 not in valid_types
|
||||||
|
assert dtype.int32 not in valid_types
|
||||||
|
assert dtype.float32 in valid_types
|
||||||
|
|
|
@ -971,7 +971,7 @@ raise_error_set = [
|
||||||
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)],
|
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)],
|
||||||
}),
|
}),
|
||||||
('TensorGetItemByMixedTensorsTypeError', {
|
('TensorGetItemByMixedTensorsTypeError', {
|
||||||
'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}),
|
'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': IndexError}),
|
||||||
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
|
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
|
||||||
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||||
Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)],
|
Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)],
|
||||||
|
|
Loading…
Reference in New Issue