!22986 Add supported namespace for fallback.

Merge pull request !22986 from 张清华/opt_fallback
This commit is contained in:
i-robot 2021-09-07 12:50:32 +00:00 committed by Gitee
commit 5cf19e8aa6
6 changed files with 134 additions and 19 deletions

View File

@ -529,6 +529,9 @@ class Parser:
# Used to resolve mindspore builtin ops namespace.
self.ms_common_ns = CellNamespace('mindspore.common')
self.ms_ops_ns = CellNamespace('mindspore.ops')
self.ms_ops_c = CellNamespace('mindspore.ops.composite')
self.ms_ops_c_multitype = CellNamespace('mindspore.ops.composite.multitype_ops')
self.ms_ops_p = CellNamespace('mindspore.ops.operations')
# Used to resolve the function's globals namespace.
self.global_namespace = CellNamespace(fn.__module__)
self.function_module = fn.__module__
@ -549,7 +552,7 @@ class Parser:
src = dedent(original_src)
self.col_offset = \
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
logger.debug("get source = %s", src)
logger.debug("Get source = %s", src)
try:
ast_tokens = asttokens.ASTTokens(src, parse=True)
except IndentationError as idt_err:
@ -567,10 +570,10 @@ class Parser:
def get_namespace_symbol(self, var: str):
"""Get symbol type and namespace and symbol."""
if var in self.closure_namespace:
logger.debug(f"found {var} in closure_namespace {self.closure_namespace.__str__()}")
logger.debug(f"Found `{var}` in closure_namespace {self.closure_namespace.__str__()}")
return self.closure_namespace, var
if var in self.global_namespace:
logger.debug(f"found {var} in global_namespace {self.global_namespace.__str__()}")
logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}")
value = self.global_namespace[var]
if isinstance(value, type(abs)) and self.global_namespace[var] not in convert_object_map:
error_info = f"The builtin function '{var}' is not supported in graph mode."
@ -587,43 +590,63 @@ class Parser:
def is_unsupported_builtin_type(self, value_type):
"""To check if not supported builtin type"""
logger.debug(f'value_type: {value_type}, {type([])}, {type(())}')
logger.debug(f'value_type: {value_type}, {type([])}, {type(())}.')
return value_type == type([]) or value_type == type(())
def is_supported_namespace_module(self, value):
"""To check if the module is allowed to support."""
# Check `mindspore` namespace.
if not hasattr(value, '__name__'):
logger.debug(f'`{str(value)}` has no `__name__` attribute.')
return True
name = value.__name__
if name == 'mindspore':
logger.debug(f'...found {name} in mindspore root namespace')
logger.debug(f'Found `{name}` in mindspore root namespace.')
return True
# Check `builtins` namespace.
if hasattr(value, '__module__'): # Not types.ModuleType
mod = value.__module__
if mod == 'builtins':
logger.debug(f'Found `{name}` in `builtins` namespace.')
return True
# We suppose it's supported if not a Module.
if not isinstance(value, types.ModuleType):
return False
return True
# Check supported Module namespace.
rightmost_name = name.split('.')[-1]
# if rightmost_name in self.ms_common_ns:
# logger.error(f'...found {module_name} in common namespace: {self.ms_common_ns.__str__()}')
# return True
# By now, we don't check `self.ms_common_ns`.
if rightmost_name in self.ms_ops_ns:
logger.debug(f'...found {name}({rightmost_name}) in ops namespace: {self.ms_ops_ns.__str__()}')
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_ns.__str__()}.')
return True
if rightmost_name in self.ms_ops_c:
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_c.__str__()}.')
return True
if rightmost_name in self.ms_ops_c_multitype:
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_c_multitype.__str__()}.')
return True
if rightmost_name in self.ms_ops_p:
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_p.__str__()}.')
return True
if rightmost_name in trope_ns:
logger.debug(f'...found {name}({rightmost_name}) in trope namespace: {self.trope_ns.__str__()}')
logger.debug(f'Found `{name}`({rightmost_name}) in trope namespace: {self.trope_ns.__str__()}.')
return True
logger.error(f'Not found `{name}` in mindspore supported namespace.')
return False
def get_builtin_namespace_symbol(self, var: str):
"""Get mindspore builtin namespace and symbol."""
if var in self.closure_namespace:
logger.debug(f"found {var} in closure_namespace {self.closure_namespace.__str__()}")
logger.debug(f"Found `{var}` in closure_namespace {self.closure_namespace.__str__()}.")
return self.closure_namespace, var
if var in self.global_namespace:
logger.debug(f"found {var} in global_namespace {self.global_namespace.__str__()}")
logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}.")
value = self.global_namespace[var]
value_str = value.__name__ if hasattr(value, '__name__') else str(value)
logger.debug(f"value: {type(value)}, : {value_str}, hasattr(__name__): {hasattr(value, '__name__')}")
logger.debug(f"value: {type(value)}, `{value_str}`, hasattr(__name__): {hasattr(value, '__name__')}.")
# To check if allowed to support.
if self.is_unsupported_builtin_type(value):
return self.global_namespace, var, value

View File

@ -216,7 +216,7 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
// The fallback feature is enabled in default.
// Not support change the flag during the process is alive.
static const auto use_fallback = (parser_.support_fallback() == "0" ? false : true);
static const auto use_fallback = (parser_.support_fallback() != "1" ? false : true);
if (!use_fallback) {
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
return HandleNamespaceInfo(namespace_info);

View File

@ -92,7 +92,7 @@ FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast) : ast_(ast) {
max_for_loop_count_str_ = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
support_fallback_ = "0"; // We will open it later by call common::GetEnv("ENV_SUPPORT_FALLBACK")
support_fallback_ = common::GetEnv("ENV_SUPPORT_FALLBACK");
errcode_ = PARSE_SUCCESS;
BuildMethodMap();
}
@ -1792,7 +1792,7 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
const py::object &value_object) {
// The fallback feature is enabled in default.
// Not support change the flag during the process is alive.
static const auto use_fallback = (support_fallback_ == "0" ? false : true);
static const auto use_fallback = (support_fallback() != "1" ? false : true);
if (!use_fallback) {
return value_node;
}

View File

@ -56,7 +56,12 @@ static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
abstract::AbstractBasePtr ClassType::ToAbstract() {
auto abs_scalar =
std::make_shared<abstract::AbstractScalar>(shared_from_base<ClassType>(), std::make_shared<TypeType>());
if (!IsSupportedCreateInstanceType(obj())) {
// The fallback feature is enabled in default.
// Not support change the flag during the process is alive.
static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
static const auto use_fallback = (support_fallback != "1" ? false : true);
if (use_fallback && !IsSupportedCreateInstanceType(obj())) {
return abs_scalar;
}
AbstractBasePtrList args_spec_list = {abs_scalar};

View File

@ -0,0 +1,21 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""setup for pytest"""
import mindspore.context as context
# pylint: disable=unused-argument
def setup_module(module):
context.set_context(mode=context.GRAPH_MODE)

View File

@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test numpy ops """
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, ms_function, context
from mindspore.ops import functional as F
context.set_context(mode=context.GRAPH_MODE)
# `add_func` is defined in current file.
def add_func(x, y):
return x + y
@ms_function
def do_increment(i):
add_1 = F.partial(add_func, 1)
return add_1(i)
def test_increment():
a = do_increment(9)
assert a == 10
@ms_function
def np_fallback_func():
array_x = [2, 3, 4, 5]
np_x = np.array(array_x).astype(np.float32)
me_x = Tensor(np_x)
me_x = me_x + me_x
return me_x
@pytest.mark.skip(reason='Graph fallback feature is not supported yet')
def test_np_fallback_func():
print(np_fallback_func())
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.x = Tensor([2, 3, 4])
def construct(self):
x_len = len(self.x)
for i in range(x_len):
print(i)
return x_len
def test_builtins_len():
net = Net()
net()