From 6edea0ed3c578554674c9ef7b70f919bd3c5b829 Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Mon, 6 Sep 2021 19:41:56 +0800 Subject: [PATCH] Add supported namespace for fallback. --- mindspore/_extends/parse/parser.py | 53 ++++++++++----- .../pipeline/jit/parse/function_block.cc | 2 +- mindspore/ccsrc/pipeline/jit/parse/parse.cc | 4 +- mindspore/ccsrc/pipeline/jit/parse/resolve.cc | 7 +- tests/ut/python/fallback/__init__.py | 21 ++++++ .../ut/python/fallback/test_graph_fallback.py | 66 +++++++++++++++++++ 6 files changed, 134 insertions(+), 19 deletions(-) create mode 100644 tests/ut/python/fallback/__init__.py create mode 100644 tests/ut/python/fallback/test_graph_fallback.py diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py index 7ee822c5348..19726286e75 100644 --- a/mindspore/_extends/parse/parser.py +++ b/mindspore/_extends/parse/parser.py @@ -529,6 +529,9 @@ class Parser: # Used to resolve mindspore builtin ops namespace. self.ms_common_ns = CellNamespace('mindspore.common') self.ms_ops_ns = CellNamespace('mindspore.ops') + self.ms_ops_c = CellNamespace('mindspore.ops.composite') + self.ms_ops_c_multitype = CellNamespace('mindspore.ops.composite.multitype_ops') + self.ms_ops_p = CellNamespace('mindspore.ops.operations') # Used to resolve the function's globals namespace. self.global_namespace = CellNamespace(fn.__module__) self.function_module = fn.__module__ @@ -549,7 +552,7 @@ class Parser: src = dedent(original_src) self.col_offset = \ len(original_src.split('\n')[0]) - len(src.split('\n')[0]) - logger.debug("get source = %s", src) + logger.debug("Get source = %s", src) try: ast_tokens = asttokens.ASTTokens(src, parse=True) except IndentationError as idt_err: @@ -567,10 +570,10 @@ class Parser: def get_namespace_symbol(self, var: str): """Get symbol type and namespace and symbol.""" if var in self.closure_namespace: - logger.debug(f"found {var} in closure_namespace {self.closure_namespace.__str__()}") + logger.debug(f"Found `{var}` in closure_namespace {self.closure_namespace.__str__()}") return self.closure_namespace, var if var in self.global_namespace: - logger.debug(f"found {var} in global_namespace {self.global_namespace.__str__()}") + logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}") value = self.global_namespace[var] if isinstance(value, type(abs)) and self.global_namespace[var] not in convert_object_map: error_info = f"The builtin function '{var}' is not supported in graph mode." @@ -581,43 +584,63 @@ class Parser: def is_unsupported_builtin_type(self, value_type): """To check if not supported builtin type""" - logger.debug(f'value_type: {value_type}, {type([])}, {type(())}') + logger.debug(f'value_type: {value_type}, {type([])}, {type(())}.') return value_type == type([]) or value_type == type(()) def is_supported_namespace_module(self, value): """To check if the module is allowed to support.""" + # Check `mindspore` namespace. if not hasattr(value, '__name__'): + logger.debug(f'`{str(value)}` has no `__name__` attribute.') return True - name = value.__name__ if name == 'mindspore': - logger.debug(f'...found {name} in mindspore root namespace') + logger.debug(f'Found `{name}` in mindspore root namespace.') return True + # Check `builtins` namespace. + if hasattr(value, '__module__'): # Not types.ModuleType + mod = value.__module__ + if mod == 'builtins': + logger.debug(f'Found `{name}` in `builtins` namespace.') + return True + + # We suppose it's supported if not a Module. if not isinstance(value, types.ModuleType): - return False + return True + + # Check supported Module namespace. rightmost_name = name.split('.')[-1] - # if rightmost_name in self.ms_common_ns: - # logger.error(f'...found {module_name} in common namespace: {self.ms_common_ns.__str__()}') - # return True + # By now, we don't check `self.ms_common_ns`. if rightmost_name in self.ms_ops_ns: - logger.debug(f'...found {name}({rightmost_name}) in ops namespace: {self.ms_ops_ns.__str__()}') + logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_ns.__str__()}.') + return True + if rightmost_name in self.ms_ops_c: + logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_c.__str__()}.') + return True + if rightmost_name in self.ms_ops_c_multitype: + logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_c_multitype.__str__()}.') + return True + if rightmost_name in self.ms_ops_p: + logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_p.__str__()}.') return True if rightmost_name in trope_ns: - logger.debug(f'...found {name}({rightmost_name}) in trope namespace: {self.trope_ns.__str__()}') + logger.debug(f'Found `{name}`({rightmost_name}) in trope namespace: {self.trope_ns.__str__()}.') return True + + logger.error(f'Not found `{name}` in mindspore supported namespace.') return False def get_builtin_namespace_symbol(self, var: str): """Get mindspore builtin namespace and symbol.""" if var in self.closure_namespace: - logger.debug(f"found {var} in closure_namespace {self.closure_namespace.__str__()}") + logger.debug(f"Found `{var}` in closure_namespace {self.closure_namespace.__str__()}.") return self.closure_namespace, var if var in self.global_namespace: - logger.debug(f"found {var} in global_namespace {self.global_namespace.__str__()}") + logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}.") value = self.global_namespace[var] value_str = value.__name__ if hasattr(value, '__name__') else str(value) - logger.debug(f"value: {type(value)}, : {value_str}, hasattr(__name__): {hasattr(value, '__name__')}") + logger.debug(f"value: {type(value)}, `{value_str}`, hasattr(__name__): {hasattr(value, '__name__')}.") # To check if allowed to support. if self.is_unsupported_builtin_type(value): return self.global_namespace, var, value diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc index 63ad8dca7c4..cc4251ca207 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc @@ -216,7 +216,7 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) { // The fallback feature is enabled in default. // Not support change the flag during the process is alive. - static const auto use_fallback = (parser_.support_fallback() == "0" ? false : true); + static const auto use_fallback = (parser_.support_fallback() != "1" ? false : true); if (!use_fallback) { py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value); return HandleNamespaceInfo(namespace_info); diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 4a2d82478ea..f7da6b6c455 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -92,7 +92,7 @@ FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr(); Parser::Parser(const std::shared_ptr &ast) : ast_(ast) { max_for_loop_count_str_ = common::GetEnv("ENV_FOR_TO_WHILE_LOOP"); - support_fallback_ = "0"; // We will open it later by call common::GetEnv("ENV_SUPPORT_FALLBACK") + support_fallback_ = common::GetEnv("ENV_SUPPORT_FALLBACK"); errcode_ = PARSE_SUCCESS; BuildMethodMap(); } @@ -1792,7 +1792,7 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP const py::object &value_object) { // The fallback feature is enabled in default. // Not support change the flag during the process is alive. - static const auto use_fallback = (support_fallback_ == "0" ? false : true); + static const auto use_fallback = (support_fallback() != "1" ? false : true); if (!use_fallback) { return value_node; } diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 0e7e8342913..c3df379b053 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -56,7 +56,12 @@ static inline bool IsSupportedCreateInstanceType(const py::object &obj) { abstract::AbstractBasePtr ClassType::ToAbstract() { auto abs_scalar = std::make_shared(shared_from_base(), std::make_shared()); - if (!IsSupportedCreateInstanceType(obj())) { + + // The fallback feature is enabled in default. + // Not support change the flag during the process is alive. + static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK"); + static const auto use_fallback = (support_fallback != "1" ? false : true); + if (use_fallback && !IsSupportedCreateInstanceType(obj())) { return abs_scalar; } AbstractBasePtrList args_spec_list = {abs_scalar}; diff --git a/tests/ut/python/fallback/__init__.py b/tests/ut/python/fallback/__init__.py new file mode 100644 index 00000000000..4dc35b714d1 --- /dev/null +++ b/tests/ut/python/fallback/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""setup for pytest""" +import mindspore.context as context + + +# pylint: disable=unused-argument +def setup_module(module): + context.set_context(mode=context.GRAPH_MODE) diff --git a/tests/ut/python/fallback/test_graph_fallback.py b/tests/ut/python/fallback/test_graph_fallback.py new file mode 100644 index 00000000000..79301933da9 --- /dev/null +++ b/tests/ut/python/fallback/test_graph_fallback.py @@ -0,0 +1,66 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test numpy ops """ +import pytest +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor, ms_function, context +from mindspore.ops import functional as F + + +context.set_context(mode=context.GRAPH_MODE) + +# `add_func` is defined in current file. +def add_func(x, y): + return x + y + +@ms_function +def do_increment(i): + add_1 = F.partial(add_func, 1) + return add_1(i) + +def test_increment(): + a = do_increment(9) + assert a == 10 + + +@ms_function +def np_fallback_func(): + array_x = [2, 3, 4, 5] + np_x = np.array(array_x).astype(np.float32) + me_x = Tensor(np_x) + me_x = me_x + me_x + return me_x + +@pytest.mark.skip(reason='Graph fallback feature is not supported yet') +def test_np_fallback_func(): + print(np_fallback_func()) + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.x = Tensor([2, 3, 4]) + + def construct(self): + x_len = len(self.x) + for i in range(x_len): + print(i) + return x_len + +def test_builtins_len(): + net = Net() + net()