!27151 [ME][Fallback] [ME][Fallback] Support built-in type: complex and reversed.

Merge pull request !27151 from Margaret_wangrui/fallback
This commit is contained in:
i-robot 2021-12-07 01:04:14 +00:00 committed by Gitee
commit d723cc417d
3 changed files with 13 additions and 9 deletions

View File

@ -623,7 +623,7 @@ class Parser:
def is_unsupported_builtin_type(self, value_type):
"""To check if not supported builtin type"""
unsupported_builtin_type = (list, tuple, set, dict, slice, bool, int, float, str)
unsupported_builtin_type = (list, tuple, set, dict, slice, bool, int, float, str, complex, reversed)
is_unsupported = value_type in unsupported_builtin_type
logger.debug(f"value_type: {value_type}, unsupported builtin type: {is_unsupported}.")
return is_unsupported

View File

@ -144,6 +144,9 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
type_proto->set_data_type(irpb::DT_UNDEFINED);
return;
}
if (type->isa<External>()) {
return;
}
CheckIfValidType(type);
if (type->isa<Number>()) {
type_proto->set_data_type(GetNumberDataType(type));

View File

@ -14,8 +14,8 @@
# ============================================================================
""" test graph fallback """
import math
import pytest
from mindspore import ms_function, context
import numpy as np
from mindspore import ms_function, context, Tensor
context.set_context(mode=context.GRAPH_MODE)
@ -99,7 +99,6 @@ def test_fallback_chr():
assert foo() == 'a'
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_fallback_complex():
"""
Feature: JIT Fallback
@ -109,8 +108,11 @@ def test_fallback_complex():
@ms_function
def foo():
x = complex(1, 2)
return x
assert foo() == (1 + 2j)
return Tensor(x)
res = foo()
expect_res = np.array(1 + 2j)
assert isinstance(res, Tensor)
assert np.all(res.asnumpy() == expect_res)
def test_fallback_dict():
@ -258,7 +260,6 @@ def test_fallback_ord():
assert foo() == 97
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_fallback_reversed():
"""
Feature: JIT Fallback
@ -268,8 +269,8 @@ def test_fallback_reversed():
@ms_function
def foo():
x = reversed([1, 2, 3])
return x
assert list(foo()) == [3, 2, 1]
return list(x)
assert foo() == (3, 2, 1)
def test_fallback_round():