!27151 [ME][Fallback] [ME][Fallback] Support built-in type: complex and reversed.
Merge pull request !27151 from Margaret_wangrui/fallback
This commit is contained in:
commit
d723cc417d
|
@ -623,7 +623,7 @@ class Parser:
|
|||
|
||||
def is_unsupported_builtin_type(self, value_type):
|
||||
"""To check if not supported builtin type"""
|
||||
unsupported_builtin_type = (list, tuple, set, dict, slice, bool, int, float, str)
|
||||
unsupported_builtin_type = (list, tuple, set, dict, slice, bool, int, float, str, complex, reversed)
|
||||
is_unsupported = value_type in unsupported_builtin_type
|
||||
logger.debug(f"value_type: {value_type}, unsupported builtin type: {is_unsupported}.")
|
||||
return is_unsupported
|
||||
|
|
|
@ -144,6 +144,9 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
|
|||
type_proto->set_data_type(irpb::DT_UNDEFINED);
|
||||
return;
|
||||
}
|
||||
if (type->isa<External>()) {
|
||||
return;
|
||||
}
|
||||
CheckIfValidType(type);
|
||||
if (type->isa<Number>()) {
|
||||
type_proto->set_data_type(GetNumberDataType(type));
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import math
|
||||
import pytest
|
||||
from mindspore import ms_function, context
|
||||
import numpy as np
|
||||
from mindspore import ms_function, context, Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -99,7 +99,6 @@ def test_fallback_chr():
|
|||
assert foo() == 'a'
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_fallback_complex():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -109,8 +108,11 @@ def test_fallback_complex():
|
|||
@ms_function
|
||||
def foo():
|
||||
x = complex(1, 2)
|
||||
return x
|
||||
assert foo() == (1 + 2j)
|
||||
return Tensor(x)
|
||||
res = foo()
|
||||
expect_res = np.array(1 + 2j)
|
||||
assert isinstance(res, Tensor)
|
||||
assert np.all(res.asnumpy() == expect_res)
|
||||
|
||||
|
||||
def test_fallback_dict():
|
||||
|
@ -258,7 +260,6 @@ def test_fallback_ord():
|
|||
assert foo() == 97
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_fallback_reversed():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -268,8 +269,8 @@ def test_fallback_reversed():
|
|||
@ms_function
|
||||
def foo():
|
||||
x = reversed([1, 2, 3])
|
||||
return x
|
||||
assert list(foo()) == [3, 2, 1]
|
||||
return list(x)
|
||||
assert foo() == (3, 2, 1)
|
||||
|
||||
|
||||
def test_fallback_round():
|
||||
|
|
Loading…
Reference in New Issue