!30695 Support third-party modules

Merge pull request !30695 from huangbingjian/import_module
This commit is contained in:
i-robot 2022-03-21 08:36:39 +00:00 committed by Gitee
commit 4bb2fbdb5f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 185 additions and 94 deletions

View File

@ -21,6 +21,7 @@
#include <mutex>
#include <set>
#include "abstract/abstract_value.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "utils/symbolic.h"
@ -287,6 +288,20 @@ AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode,
return possible_func;
}
void CheckInterpretedObject(const AbstractBasePtr &abs) {
static const auto support_fallback = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
static const auto use_fallback = (support_fallback != "0");
if (!use_fallback) {
return;
}
auto value = abs->BuildValue();
if (value->isa<parse::InterpretedObject>()) {
MS_LOG(ERROR) << "Do not support " << value->ToString() << ". "
<< "\nIf you are using third-party modules, you can try setting: "
<< "'export MS_DEV_SUPPORT_MODULES=module1,module2,...'.";
}
}
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(cnode);
@ -298,6 +313,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(possible_func);
if (func == nullptr) {
CheckInterpretedObject(possible_func);
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << possible_func->ToString() << ".";
MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
MS_EXCEPTION(ValueError) << "This may be not defined, or it can't be a operator. Please check code.";

View File

@ -0,0 +1,51 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Third-party modules."""
# List common third-party modules for JIT Fallback.
# Refer to "https://github.com/eplt/deep-learning-coursera-complete/blob/master/awesome-python-machine-learning.md"
jit_fallback_third_party_modules_whitelist = (
# Python built-in modules.
"datetime", "re", "difflib", "math", "cmath", "random",
# Machine Learning.
"ml_metrics", "nupic", "sklearn", "pyspark", "vowpal_porpoise", "xgboost",
# Natual Language Processing.
"gensim", "jieba", "langid", "nltk", "pattern", "polyglot", "snownlp", "spacy", "textblob", "quepy", "yalign",
"spammy", "genius", "konlpy", "nut", "rosetta", "bllipparser", "pynlpl", "ucto", "frog", "zpar", "colibricore",
"StanfordDependencies", "distance", "thefuzz", "jellyfish", "editdistance", "textacy", "pycorenlp", "cltk",
"rasa_nlu", "drqa", "dedupe",
# Text Processing.
"chardet", "ftfy", "Levenshtein", "pangu", "pyfiglet", "pypinyin", "shortuuid", "unidecode", "uniout",
"xpinyin", "slugify", "phonenumbers", "ply", "pygments", "pyparsing", "nameparser", "user_agents", "sqlparse",
# Web Content Extracting.
"haul", "html2text", "lassie", "micawber", "newspaper", "goose", "readability", "requests_html", "sanitize",
"sumy", "textract", "toapi",
# Web Crawling.
"cola", "demiurge", "feedparser", "grab", "mechanicalsoup", "pyspider", "robobrowser", "scrapy",
# Algorithms and Design Patterns.
"algorithms", "pypattyrn", "sortedcontainers", "scoop",
# Cryptography.
"cryptography", "hashids", "paramiko", "passlib", "nacl",
# Data Analysis (without plotting).
"numpy", "scipy", "blaze", "pandas", "numba", "pymc", "zipline", "pydy", "sympy", "statsmodels", "astropy",
"vincent", "pygal", "pycascading", "emcee", "windml", "vispy", "Dora", "ruffus", "sompy", "somoclu", "hdbscan",
# Computer Vision.
"PIL", "skimage", "SimpleCV", "PCV", "face_recognition",
# General-Purpose Machine Learning.
"cntk", "auto_ml", "xgboost", "featureforge", "scikits", "metric_learn", "simpleai", "bigml", "pylearn2", "keras",
"lasagne", "hebel", "topik", "pybrain", "surprise", "recsys", "thinkbayes", "nilearn", "neuropredict", "pyevolve",
"pyhsmm", "mrjob", "neurolab", "pebl", "yahmm", "timbl", "deap", "deeppy", "mlxtend", "neon", "optunity", "topt",
"pgmpy", "milk", "rep", "rgf", "FukuML", "stacked_generalization", "modAL", "cogitare", "gym",
)

View File

@ -17,11 +17,12 @@
"""The module of parser python object, called by c++."""
import os
import sys
import ast
import hashlib
import inspect
import types
import platform
import importlib
from dataclasses import is_dataclass
from textwrap import dedent
@ -35,6 +36,7 @@ from mindspore.common.api import _MindsporeFunctionExecutor, _convert_data
from mindspore.common.dtype import pytype_to_dtype
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
from .jit_fallback_modules import jit_fallback_third_party_modules_whitelist
# define return value
RET_SUCCESS = 0
@ -152,9 +154,33 @@ def get_bprop_method_of_class(obj, parse_method=None):
method = getattr(obj, method_name)
return method
def get_env_support_modules():
"""Get support modules from environment variable."""
support_modules = os.getenv('MS_DEV_SUPPORT_MODULES')
if support_modules is None:
return []
env_support_modules = []
modules = support_modules.split(',')
for module in modules:
try:
module_spec = importlib.util.find_spec(module)
except (ModuleNotFoundError, ValueError):
module = module[0:module.rfind('.')]
module_spec = importlib.util.find_spec(module)
if module_spec is None:
raise ModuleNotFoundError(f"Cannot find module: {module}. " \
f"Please check if {module} is installed, or if MS_DEV_SUPPORT_MODULES is set correctly.")
# Add the outermost module.
env_support_modules.append(module.split('.')[0])
logger.debug(f"Get support modules from env: {env_support_modules}")
return env_support_modules
# The fallback feature is enabled in default.
# Not support change the flag during the process is alive.
support_fallback_ = os.getenv('MS_DEV_ENABLE_FALLBACK')
support_modules_ = get_env_support_modules()
def resolve_symbol(namespace, symbol):
@ -571,6 +597,41 @@ def get_args(node):
return args
def _in_sys_path(file_path):
for path in list(sys.path):
if file_path.startswith(path):
return True
return False
def is_third_party_module(value):
"""To check if value is a third-party module."""
# Check if value is a module or package.
if not inspect.ismodule(value) or not hasattr(value, '__file__'):
return False
# Check if module file is under the sys path.
module_file = value.__file__
if not _in_sys_path(module_file):
return False
# Get module leftmost name.
if not hasattr(value, '__name__'):
return False
module_name = value.__name__
module_leftmost_name = module_name.split('.')[0]
# Ignore mindspore package.
if module_leftmost_name == "mindspore":
return False
# Check if module is in whitelist.
if module_leftmost_name in support_modules_:
logger.debug(f"Found support modules from env: {module_name}")
return True
if module_leftmost_name in jit_fallback_third_party_modules_whitelist:
logger.debug(f"Found third-party module: {module_name}")
return True
return False
def eval_script(exp_str, params):
"""Evaluate a python expression."""
if not isinstance(params, tuple):
@ -614,17 +675,6 @@ class Parser:
self.line_offset = 0
self.filename: str = inspect.getfile(inspect.unwrap(self.fn))
# Used to resolve mindspore builtin ops namespace.
self.ms_common_ns = CellNamespace('mindspore.common')
self.ms_nn_ns = CellNamespace('mindspore.nn')
self.ms_ops_ns = CellNamespace('mindspore.ops')
self.ms_ops_c_ns = CellNamespace('mindspore.ops.composite')
self.ms_ops_c_multitype_ns = CellNamespace('mindspore.ops.composite.multitype_ops')
self.ms_ops_p_ns = CellNamespace('mindspore.ops.operations')
if platform.system().lower() != 'windows':
self.ms_scipy_ns = CellNamespace('mindspore.scipy')
else:
self.ms_scipy_ns = {}
# Used to resolve the function's globals namespace.
self.global_namespace = CellNamespace(fn.__module__)
self.function_module = fn.__module__
@ -720,86 +770,6 @@ class Parser:
error_info = f"The name '{var}' is not defined in function '{self.function_name}'."
return None, error_info
def is_rightmost_name_in_namespace_module(self, name):
"""Check supported Module namespace."""
rightmost_name = name.split('.')[-1]
if rightmost_name in self.ms_ops_ns:
logger.debug(f"Found '{name}'({rightmost_name}) in ops namespace: {str(self.ms_ops_ns)}.")
return True
if rightmost_name in self.ms_ops_c_ns:
logger.debug(f"Found '{name}'({rightmost_name}) in C namespace: {str(self.ms_ops_c_ns)}.")
return True
if rightmost_name in self.ms_ops_c_multitype_ns:
logger.debug(
f"Found '{name}'({rightmost_name}) in C.multitype namespace: {str(self.ms_ops_c_multitype_ns)}.")
return True
if rightmost_name in self.ms_ops_p_ns:
logger.debug(f"Found '{name}'({rightmost_name}) in P namespace: {str(self.ms_ops_p_ns)}.")
return True
if rightmost_name in self.ms_common_ns:
logger.debug(f"Found '{name}'({rightmost_name}) in common namespace: {str(self.ms_common_ns)}.")
return True
# Support nn.layer. To check if exclude other module.
if rightmost_name in self.ms_nn_ns:
logger.debug(f"Found '{name}'({rightmost_name}) in nn namespace: {str(self.ms_nn_ns)}.")
return True
if rightmost_name in self.ms_scipy_ns:
logger.debug(f"Found '{name}'({rightmost_name}) in scipy namespace: {str(self.ms_scipy_ns)}.")
return True
if rightmost_name in trope_ns:
logger.debug(f"Found '{name}'({rightmost_name}) in trope namespace: {str(trope_ns)}.")
return True
return False
def is_supported_namespace_module(self, value):
"""To check if the module is allowed to support."""
# Check `mindspore` namespace.
if not hasattr(value, '__name__'):
logger.debug(f"'{str(value)}' has no '__name__' attribute, we suppose it's supported.")
return True
name = value.__name__
if name == 'mindspore':
logger.debug(f"Found 'mindspore' root namespace.")
return True
if name == 'mindspore.ops':
logger.debug(f"Found 'mindspore.ops' namespace.")
return True
if name == 'mindspore.nn':
logger.debug(f"Found 'mindspore.nn' namespace.")
return True
if name == 'mindspore.numpy':
logger.debug(f"Found 'mindspore.numpy' namespace.")
return True
if platform.system().lower() != 'windows' and name == 'mindspore.scipy':
logger.debug(f"Found 'mindspore.scipy' namespace.")
return True
if name == 'mindspore.context':
logger.debug(f"Found 'mindspore.context' namespace.")
return True
if name == 'functools':
logger.debug(f"Found 'functools' namespace.")
return True
# Check `builtins` namespace.
if hasattr(value, '__module__'): # Not types.ModuleType
mod = value.__module__
if mod == 'builtins':
logger.debug(f"Found '{name}' in 'builtins' namespace.")
return True
# We suppose it's supported if not a Module.
if not isinstance(value, types.ModuleType):
logger.debug(f"Found '{name}', not a module.")
return True
# Check supported Module namespace.
if self.is_rightmost_name_in_namespace_module(name):
return True
logger.info(f"Not found '{name}' in mindspore supported namespace.")
return False
def get_builtin_namespace_symbol(self, var: str):
"""Get mindspore builtin namespace and symbol."""
if var in self.closure_namespace:
@ -817,7 +787,7 @@ class Parser:
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_EXTERNAL_TYPE
elif self.is_unsupported_special_type(value):
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_SPECIAL_TYPE
elif self.is_unsupported_namespace(value) or not self.is_supported_namespace_module(value):
elif self.is_unsupported_namespace(value) or is_third_party_module(value):
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_NAMESPACE
else:
support_info = self.global_namespace, var, value, SYNTAX_SUPPORTED

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
""" test graph fallback """
import functools
import pytest
import numpy as np
@ -272,3 +273,22 @@ def test_probability_cauchy():
net = CauchyProb(loc, scale)
net(Tensor(value), Tensor(loc_a), Tensor(scale_a))
def test_third_party_module_functools():
"""
Feature: JIT Fallback
Description: functools is a python built-in module and does not perform JIT Fallback.
Expectation: No exception.
"""
class ModuleNet(nn.Cell):
def construct(self, x, y):
func = functools.partial(add_func, x)
out = func(y)
return out
x = Tensor([1, 2, 3], mstype.int32)
y = Tensor([4, 5, 6], mstype.int32)
net = ModuleNet()
out = net(x, y)
print(out)

View File

@ -18,7 +18,8 @@ import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context, ms_class
from mindspore import Tensor, context, ms_class, ms_function
from . import test_graph_fallback
context.set_context(mode=context.GRAPH_MODE)
@ -137,6 +138,23 @@ def test_fallback_self_method_tensor():
print(out)
def test_fallback_import_modules():
"""
Feature: JIT Fallback
Description: add_func is defined in test_graph_fallback.py
Expectation: No exception.
"""
@ms_function
def use_imported_module(x, y):
out = test_graph_fallback.add_func(x, y)
return out
x = Tensor(2, dtype=mstype.int32)
y = Tensor(3, dtype=mstype.int32)
out = use_imported_module(x, y)
print(out)
def test_fallback_class_attr():
"""
Feature: JIT Fallback

View File

@ -15,6 +15,7 @@
""" test graph fallback """
import pytest
import numpy as np
import numpy.random as rand
from mindspore import ms_function, context, Tensor
context.set_context(mode=context.GRAPH_MODE)
@ -547,3 +548,18 @@ def test_np_slice():
return Tensor(b)
res = np_slice()
assert np.all(res.asnumpy() == np.array([1, 2, 3, 4]))
def test_np_random():
"""
Feature: JIT Fallback
Description: Test numpy.random module in graph mode.
Expectation: No exception.
"""
@ms_function
def np_random():
a = rand.randint(100, size=(5))
b = a[1:5]
return Tensor(b)
res = np_random()
print(res)