forked from mindspore-Ecosystem/mindspore
!22986 Add supported namespace for fallback.
Merge pull request !22986 from 张清华/opt_fallback
This commit is contained in:
commit
5cf19e8aa6
|
@ -529,6 +529,9 @@ class Parser:
|
|||
# Used to resolve mindspore builtin ops namespace.
|
||||
self.ms_common_ns = CellNamespace('mindspore.common')
|
||||
self.ms_ops_ns = CellNamespace('mindspore.ops')
|
||||
self.ms_ops_c = CellNamespace('mindspore.ops.composite')
|
||||
self.ms_ops_c_multitype = CellNamespace('mindspore.ops.composite.multitype_ops')
|
||||
self.ms_ops_p = CellNamespace('mindspore.ops.operations')
|
||||
# Used to resolve the function's globals namespace.
|
||||
self.global_namespace = CellNamespace(fn.__module__)
|
||||
self.function_module = fn.__module__
|
||||
|
@ -549,7 +552,7 @@ class Parser:
|
|||
src = dedent(original_src)
|
||||
self.col_offset = \
|
||||
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
|
||||
logger.debug("get source = %s", src)
|
||||
logger.debug("Get source = %s", src)
|
||||
try:
|
||||
ast_tokens = asttokens.ASTTokens(src, parse=True)
|
||||
except IndentationError as idt_err:
|
||||
|
@ -567,10 +570,10 @@ class Parser:
|
|||
def get_namespace_symbol(self, var: str):
|
||||
"""Get symbol type and namespace and symbol."""
|
||||
if var in self.closure_namespace:
|
||||
logger.debug(f"found {var} in closure_namespace {self.closure_namespace.__str__()}")
|
||||
logger.debug(f"Found `{var}` in closure_namespace {self.closure_namespace.__str__()}")
|
||||
return self.closure_namespace, var
|
||||
if var in self.global_namespace:
|
||||
logger.debug(f"found {var} in global_namespace {self.global_namespace.__str__()}")
|
||||
logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}")
|
||||
value = self.global_namespace[var]
|
||||
if isinstance(value, type(abs)) and self.global_namespace[var] not in convert_object_map:
|
||||
error_info = f"The builtin function '{var}' is not supported in graph mode."
|
||||
|
@ -587,43 +590,63 @@ class Parser:
|
|||
|
||||
def is_unsupported_builtin_type(self, value_type):
|
||||
"""To check if not supported builtin type"""
|
||||
logger.debug(f'value_type: {value_type}, {type([])}, {type(())}')
|
||||
logger.debug(f'value_type: {value_type}, {type([])}, {type(())}.')
|
||||
return value_type == type([]) or value_type == type(())
|
||||
|
||||
def is_supported_namespace_module(self, value):
|
||||
"""To check if the module is allowed to support."""
|
||||
# Check `mindspore` namespace.
|
||||
if not hasattr(value, '__name__'):
|
||||
logger.debug(f'`{str(value)}` has no `__name__` attribute.')
|
||||
return True
|
||||
|
||||
name = value.__name__
|
||||
if name == 'mindspore':
|
||||
logger.debug(f'...found {name} in mindspore root namespace')
|
||||
logger.debug(f'Found `{name}` in mindspore root namespace.')
|
||||
return True
|
||||
|
||||
# Check `builtins` namespace.
|
||||
if hasattr(value, '__module__'): # Not types.ModuleType
|
||||
mod = value.__module__
|
||||
if mod == 'builtins':
|
||||
logger.debug(f'Found `{name}` in `builtins` namespace.')
|
||||
return True
|
||||
|
||||
# We suppose it's supported if not a Module.
|
||||
if not isinstance(value, types.ModuleType):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Check supported Module namespace.
|
||||
rightmost_name = name.split('.')[-1]
|
||||
# if rightmost_name in self.ms_common_ns:
|
||||
# logger.error(f'...found {module_name} in common namespace: {self.ms_common_ns.__str__()}')
|
||||
# return True
|
||||
# By now, we don't check `self.ms_common_ns`.
|
||||
if rightmost_name in self.ms_ops_ns:
|
||||
logger.debug(f'...found {name}({rightmost_name}) in ops namespace: {self.ms_ops_ns.__str__()}')
|
||||
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_ns.__str__()}.')
|
||||
return True
|
||||
if rightmost_name in self.ms_ops_c:
|
||||
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_c.__str__()}.')
|
||||
return True
|
||||
if rightmost_name in self.ms_ops_c_multitype:
|
||||
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_c_multitype.__str__()}.')
|
||||
return True
|
||||
if rightmost_name in self.ms_ops_p:
|
||||
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_p.__str__()}.')
|
||||
return True
|
||||
if rightmost_name in trope_ns:
|
||||
logger.debug(f'...found {name}({rightmost_name}) in trope namespace: {self.trope_ns.__str__()}')
|
||||
logger.debug(f'Found `{name}`({rightmost_name}) in trope namespace: {self.trope_ns.__str__()}.')
|
||||
return True
|
||||
|
||||
logger.error(f'Not found `{name}` in mindspore supported namespace.')
|
||||
return False
|
||||
|
||||
def get_builtin_namespace_symbol(self, var: str):
|
||||
"""Get mindspore builtin namespace and symbol."""
|
||||
if var in self.closure_namespace:
|
||||
logger.debug(f"found {var} in closure_namespace {self.closure_namespace.__str__()}")
|
||||
logger.debug(f"Found `{var}` in closure_namespace {self.closure_namespace.__str__()}.")
|
||||
return self.closure_namespace, var
|
||||
if var in self.global_namespace:
|
||||
logger.debug(f"found {var} in global_namespace {self.global_namespace.__str__()}")
|
||||
logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}.")
|
||||
value = self.global_namespace[var]
|
||||
value_str = value.__name__ if hasattr(value, '__name__') else str(value)
|
||||
logger.debug(f"value: {type(value)}, : {value_str}, hasattr(__name__): {hasattr(value, '__name__')}")
|
||||
logger.debug(f"value: {type(value)}, `{value_str}`, hasattr(__name__): {hasattr(value, '__name__')}.")
|
||||
# To check if allowed to support.
|
||||
if self.is_unsupported_builtin_type(value):
|
||||
return self.global_namespace, var, value
|
||||
|
|
|
@ -216,7 +216,7 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
|
|||
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (parser_.support_fallback() == "0" ? false : true);
|
||||
static const auto use_fallback = (parser_.support_fallback() != "1" ? false : true);
|
||||
if (!use_fallback) {
|
||||
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
|
||||
return HandleNamespaceInfo(namespace_info);
|
||||
|
|
|
@ -92,7 +92,7 @@ FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
|
|||
|
||||
Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast) : ast_(ast) {
|
||||
max_for_loop_count_str_ = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
|
||||
support_fallback_ = "0"; // We will open it later by call common::GetEnv("ENV_SUPPORT_FALLBACK")
|
||||
support_fallback_ = common::GetEnv("ENV_SUPPORT_FALLBACK");
|
||||
errcode_ = PARSE_SUCCESS;
|
||||
BuildMethodMap();
|
||||
}
|
||||
|
@ -1792,7 +1792,7 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
|
|||
const py::object &value_object) {
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (support_fallback_ == "0" ? false : true);
|
||||
static const auto use_fallback = (support_fallback() != "1" ? false : true);
|
||||
if (!use_fallback) {
|
||||
return value_node;
|
||||
}
|
||||
|
|
|
@ -56,7 +56,12 @@ static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
|
|||
abstract::AbstractBasePtr ClassType::ToAbstract() {
|
||||
auto abs_scalar =
|
||||
std::make_shared<abstract::AbstractScalar>(shared_from_base<ClassType>(), std::make_shared<TypeType>());
|
||||
if (!IsSupportedCreateInstanceType(obj())) {
|
||||
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
|
||||
static const auto use_fallback = (support_fallback != "1" ? false : true);
|
||||
if (use_fallback && !IsSupportedCreateInstanceType(obj())) {
|
||||
return abs_scalar;
|
||||
}
|
||||
AbstractBasePtrList args_spec_list = {abs_scalar};
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""setup for pytest"""
|
||||
import mindspore.context as context
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def setup_module(module):
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test numpy ops """
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
# `add_func` is defined in current file.
|
||||
def add_func(x, y):
|
||||
return x + y
|
||||
|
||||
@ms_function
|
||||
def do_increment(i):
|
||||
add_1 = F.partial(add_func, 1)
|
||||
return add_1(i)
|
||||
|
||||
def test_increment():
|
||||
a = do_increment(9)
|
||||
assert a == 10
|
||||
|
||||
|
||||
@ms_function
|
||||
def np_fallback_func():
|
||||
array_x = [2, 3, 4, 5]
|
||||
np_x = np.array(array_x).astype(np.float32)
|
||||
me_x = Tensor(np_x)
|
||||
me_x = me_x + me_x
|
||||
return me_x
|
||||
|
||||
@pytest.mark.skip(reason='Graph fallback feature is not supported yet')
|
||||
def test_np_fallback_func():
|
||||
print(np_fallback_func())
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.x = Tensor([2, 3, 4])
|
||||
|
||||
def construct(self):
|
||||
x_len = len(self.x)
|
||||
for i in range(x_len):
|
||||
print(i)
|
||||
return x_len
|
||||
|
||||
def test_builtins_len():
|
||||
net = Net()
|
||||
net()
|
Loading…
Reference in New Issue