forked from mindspore-Ecosystem/mindspore
!30695 Support third-party modules
Merge pull request !30695 from huangbingjian/import_module
This commit is contained in:
commit
4bb2fbdb5f
|
@ -21,6 +21,7 @@
|
|||
#include <mutex>
|
||||
#include <set>
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "utils/symbolic.h"
|
||||
|
@ -287,6 +288,20 @@ AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode,
|
|||
return possible_func;
|
||||
}
|
||||
|
||||
void CheckInterpretedObject(const AbstractBasePtr &abs) {
|
||||
static const auto support_fallback = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
|
||||
static const auto use_fallback = (support_fallback != "0");
|
||||
if (!use_fallback) {
|
||||
return;
|
||||
}
|
||||
auto value = abs->BuildValue();
|
||||
if (value->isa<parse::InterpretedObject>()) {
|
||||
MS_LOG(ERROR) << "Do not support " << value->ToString() << ". "
|
||||
<< "\nIf you are using third-party modules, you can try setting: "
|
||||
<< "'export MS_DEV_SUPPORT_MODULES=module1,module2,...'.";
|
||||
}
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
@ -298,6 +313,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
|
||||
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(possible_func);
|
||||
if (func == nullptr) {
|
||||
CheckInterpretedObject(possible_func);
|
||||
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << possible_func->ToString() << ".";
|
||||
MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
|
||||
MS_EXCEPTION(ValueError) << "This may be not defined, or it can't be a operator. Please check code.";
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Third-party modules."""
|
||||
|
||||
# List common third-party modules for JIT Fallback.
|
||||
# Refer to "https://github.com/eplt/deep-learning-coursera-complete/blob/master/awesome-python-machine-learning.md"
|
||||
jit_fallback_third_party_modules_whitelist = (
|
||||
# Python built-in modules.
|
||||
"datetime", "re", "difflib", "math", "cmath", "random",
|
||||
# Machine Learning.
|
||||
"ml_metrics", "nupic", "sklearn", "pyspark", "vowpal_porpoise", "xgboost",
|
||||
# Natual Language Processing.
|
||||
"gensim", "jieba", "langid", "nltk", "pattern", "polyglot", "snownlp", "spacy", "textblob", "quepy", "yalign",
|
||||
"spammy", "genius", "konlpy", "nut", "rosetta", "bllipparser", "pynlpl", "ucto", "frog", "zpar", "colibricore",
|
||||
"StanfordDependencies", "distance", "thefuzz", "jellyfish", "editdistance", "textacy", "pycorenlp", "cltk",
|
||||
"rasa_nlu", "drqa", "dedupe",
|
||||
# Text Processing.
|
||||
"chardet", "ftfy", "Levenshtein", "pangu", "pyfiglet", "pypinyin", "shortuuid", "unidecode", "uniout",
|
||||
"xpinyin", "slugify", "phonenumbers", "ply", "pygments", "pyparsing", "nameparser", "user_agents", "sqlparse",
|
||||
# Web Content Extracting.
|
||||
"haul", "html2text", "lassie", "micawber", "newspaper", "goose", "readability", "requests_html", "sanitize",
|
||||
"sumy", "textract", "toapi",
|
||||
# Web Crawling.
|
||||
"cola", "demiurge", "feedparser", "grab", "mechanicalsoup", "pyspider", "robobrowser", "scrapy",
|
||||
# Algorithms and Design Patterns.
|
||||
"algorithms", "pypattyrn", "sortedcontainers", "scoop",
|
||||
# Cryptography.
|
||||
"cryptography", "hashids", "paramiko", "passlib", "nacl",
|
||||
# Data Analysis (without plotting).
|
||||
"numpy", "scipy", "blaze", "pandas", "numba", "pymc", "zipline", "pydy", "sympy", "statsmodels", "astropy",
|
||||
"vincent", "pygal", "pycascading", "emcee", "windml", "vispy", "Dora", "ruffus", "sompy", "somoclu", "hdbscan",
|
||||
# Computer Vision.
|
||||
"PIL", "skimage", "SimpleCV", "PCV", "face_recognition",
|
||||
# General-Purpose Machine Learning.
|
||||
"cntk", "auto_ml", "xgboost", "featureforge", "scikits", "metric_learn", "simpleai", "bigml", "pylearn2", "keras",
|
||||
"lasagne", "hebel", "topik", "pybrain", "surprise", "recsys", "thinkbayes", "nilearn", "neuropredict", "pyevolve",
|
||||
"pyhsmm", "mrjob", "neurolab", "pebl", "yahmm", "timbl", "deap", "deeppy", "mlxtend", "neon", "optunity", "topt",
|
||||
"pgmpy", "milk", "rep", "rgf", "FukuML", "stacked_generalization", "modAL", "cogitare", "gym",
|
||||
)
|
|
@ -17,11 +17,12 @@
|
|||
"""The module of parser python object, called by c++."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import ast
|
||||
import hashlib
|
||||
import inspect
|
||||
import types
|
||||
import platform
|
||||
import importlib
|
||||
from dataclasses import is_dataclass
|
||||
from textwrap import dedent
|
||||
|
||||
|
@ -35,6 +36,7 @@ from mindspore.common.api import _MindsporeFunctionExecutor, _convert_data
|
|||
from mindspore.common.dtype import pytype_to_dtype
|
||||
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
|
||||
from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
|
||||
from .jit_fallback_modules import jit_fallback_third_party_modules_whitelist
|
||||
|
||||
# define return value
|
||||
RET_SUCCESS = 0
|
||||
|
@ -152,9 +154,33 @@ def get_bprop_method_of_class(obj, parse_method=None):
|
|||
method = getattr(obj, method_name)
|
||||
return method
|
||||
|
||||
|
||||
def get_env_support_modules():
|
||||
"""Get support modules from environment variable."""
|
||||
support_modules = os.getenv('MS_DEV_SUPPORT_MODULES')
|
||||
if support_modules is None:
|
||||
return []
|
||||
env_support_modules = []
|
||||
modules = support_modules.split(',')
|
||||
for module in modules:
|
||||
try:
|
||||
module_spec = importlib.util.find_spec(module)
|
||||
except (ModuleNotFoundError, ValueError):
|
||||
module = module[0:module.rfind('.')]
|
||||
module_spec = importlib.util.find_spec(module)
|
||||
if module_spec is None:
|
||||
raise ModuleNotFoundError(f"Cannot find module: {module}. " \
|
||||
f"Please check if {module} is installed, or if MS_DEV_SUPPORT_MODULES is set correctly.")
|
||||
# Add the outermost module.
|
||||
env_support_modules.append(module.split('.')[0])
|
||||
logger.debug(f"Get support modules from env: {env_support_modules}")
|
||||
return env_support_modules
|
||||
|
||||
|
||||
# The fallback feature is enabled in default.
|
||||
# Not support change the flag during the process is alive.
|
||||
support_fallback_ = os.getenv('MS_DEV_ENABLE_FALLBACK')
|
||||
support_modules_ = get_env_support_modules()
|
||||
|
||||
|
||||
def resolve_symbol(namespace, symbol):
|
||||
|
@ -571,6 +597,41 @@ def get_args(node):
|
|||
return args
|
||||
|
||||
|
||||
def _in_sys_path(file_path):
|
||||
for path in list(sys.path):
|
||||
if file_path.startswith(path):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_third_party_module(value):
|
||||
"""To check if value is a third-party module."""
|
||||
# Check if value is a module or package.
|
||||
if not inspect.ismodule(value) or not hasattr(value, '__file__'):
|
||||
return False
|
||||
# Check if module file is under the sys path.
|
||||
module_file = value.__file__
|
||||
if not _in_sys_path(module_file):
|
||||
return False
|
||||
|
||||
# Get module leftmost name.
|
||||
if not hasattr(value, '__name__'):
|
||||
return False
|
||||
module_name = value.__name__
|
||||
module_leftmost_name = module_name.split('.')[0]
|
||||
# Ignore mindspore package.
|
||||
if module_leftmost_name == "mindspore":
|
||||
return False
|
||||
# Check if module is in whitelist.
|
||||
if module_leftmost_name in support_modules_:
|
||||
logger.debug(f"Found support modules from env: {module_name}")
|
||||
return True
|
||||
if module_leftmost_name in jit_fallback_third_party_modules_whitelist:
|
||||
logger.debug(f"Found third-party module: {module_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def eval_script(exp_str, params):
|
||||
"""Evaluate a python expression."""
|
||||
if not isinstance(params, tuple):
|
||||
|
@ -614,17 +675,6 @@ class Parser:
|
|||
self.line_offset = 0
|
||||
self.filename: str = inspect.getfile(inspect.unwrap(self.fn))
|
||||
|
||||
# Used to resolve mindspore builtin ops namespace.
|
||||
self.ms_common_ns = CellNamespace('mindspore.common')
|
||||
self.ms_nn_ns = CellNamespace('mindspore.nn')
|
||||
self.ms_ops_ns = CellNamespace('mindspore.ops')
|
||||
self.ms_ops_c_ns = CellNamespace('mindspore.ops.composite')
|
||||
self.ms_ops_c_multitype_ns = CellNamespace('mindspore.ops.composite.multitype_ops')
|
||||
self.ms_ops_p_ns = CellNamespace('mindspore.ops.operations')
|
||||
if platform.system().lower() != 'windows':
|
||||
self.ms_scipy_ns = CellNamespace('mindspore.scipy')
|
||||
else:
|
||||
self.ms_scipy_ns = {}
|
||||
# Used to resolve the function's globals namespace.
|
||||
self.global_namespace = CellNamespace(fn.__module__)
|
||||
self.function_module = fn.__module__
|
||||
|
@ -720,86 +770,6 @@ class Parser:
|
|||
error_info = f"The name '{var}' is not defined in function '{self.function_name}'."
|
||||
return None, error_info
|
||||
|
||||
def is_rightmost_name_in_namespace_module(self, name):
|
||||
"""Check supported Module namespace."""
|
||||
rightmost_name = name.split('.')[-1]
|
||||
if rightmost_name in self.ms_ops_ns:
|
||||
logger.debug(f"Found '{name}'({rightmost_name}) in ops namespace: {str(self.ms_ops_ns)}.")
|
||||
return True
|
||||
if rightmost_name in self.ms_ops_c_ns:
|
||||
logger.debug(f"Found '{name}'({rightmost_name}) in C namespace: {str(self.ms_ops_c_ns)}.")
|
||||
return True
|
||||
if rightmost_name in self.ms_ops_c_multitype_ns:
|
||||
logger.debug(
|
||||
f"Found '{name}'({rightmost_name}) in C.multitype namespace: {str(self.ms_ops_c_multitype_ns)}.")
|
||||
return True
|
||||
if rightmost_name in self.ms_ops_p_ns:
|
||||
logger.debug(f"Found '{name}'({rightmost_name}) in P namespace: {str(self.ms_ops_p_ns)}.")
|
||||
return True
|
||||
if rightmost_name in self.ms_common_ns:
|
||||
logger.debug(f"Found '{name}'({rightmost_name}) in common namespace: {str(self.ms_common_ns)}.")
|
||||
return True
|
||||
# Support nn.layer. To check if exclude other module.
|
||||
if rightmost_name in self.ms_nn_ns:
|
||||
logger.debug(f"Found '{name}'({rightmost_name}) in nn namespace: {str(self.ms_nn_ns)}.")
|
||||
return True
|
||||
if rightmost_name in self.ms_scipy_ns:
|
||||
logger.debug(f"Found '{name}'({rightmost_name}) in scipy namespace: {str(self.ms_scipy_ns)}.")
|
||||
return True
|
||||
if rightmost_name in trope_ns:
|
||||
logger.debug(f"Found '{name}'({rightmost_name}) in trope namespace: {str(trope_ns)}.")
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_supported_namespace_module(self, value):
|
||||
"""To check if the module is allowed to support."""
|
||||
# Check `mindspore` namespace.
|
||||
if not hasattr(value, '__name__'):
|
||||
logger.debug(f"'{str(value)}' has no '__name__' attribute, we suppose it's supported.")
|
||||
return True
|
||||
name = value.__name__
|
||||
if name == 'mindspore':
|
||||
logger.debug(f"Found 'mindspore' root namespace.")
|
||||
return True
|
||||
if name == 'mindspore.ops':
|
||||
logger.debug(f"Found 'mindspore.ops' namespace.")
|
||||
return True
|
||||
if name == 'mindspore.nn':
|
||||
logger.debug(f"Found 'mindspore.nn' namespace.")
|
||||
return True
|
||||
if name == 'mindspore.numpy':
|
||||
logger.debug(f"Found 'mindspore.numpy' namespace.")
|
||||
return True
|
||||
if platform.system().lower() != 'windows' and name == 'mindspore.scipy':
|
||||
logger.debug(f"Found 'mindspore.scipy' namespace.")
|
||||
return True
|
||||
if name == 'mindspore.context':
|
||||
logger.debug(f"Found 'mindspore.context' namespace.")
|
||||
return True
|
||||
|
||||
if name == 'functools':
|
||||
logger.debug(f"Found 'functools' namespace.")
|
||||
return True
|
||||
|
||||
# Check `builtins` namespace.
|
||||
if hasattr(value, '__module__'): # Not types.ModuleType
|
||||
mod = value.__module__
|
||||
if mod == 'builtins':
|
||||
logger.debug(f"Found '{name}' in 'builtins' namespace.")
|
||||
return True
|
||||
|
||||
# We suppose it's supported if not a Module.
|
||||
if not isinstance(value, types.ModuleType):
|
||||
logger.debug(f"Found '{name}', not a module.")
|
||||
return True
|
||||
|
||||
# Check supported Module namespace.
|
||||
if self.is_rightmost_name_in_namespace_module(name):
|
||||
return True
|
||||
|
||||
logger.info(f"Not found '{name}' in mindspore supported namespace.")
|
||||
return False
|
||||
|
||||
def get_builtin_namespace_symbol(self, var: str):
|
||||
"""Get mindspore builtin namespace and symbol."""
|
||||
if var in self.closure_namespace:
|
||||
|
@ -817,7 +787,7 @@ class Parser:
|
|||
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_EXTERNAL_TYPE
|
||||
elif self.is_unsupported_special_type(value):
|
||||
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_SPECIAL_TYPE
|
||||
elif self.is_unsupported_namespace(value) or not self.is_supported_namespace_module(value):
|
||||
elif self.is_unsupported_namespace(value) or is_third_party_module(value):
|
||||
support_info = self.global_namespace, var, value, SYNTAX_UNSUPPORTED_NAMESPACE
|
||||
else:
|
||||
support_info = self.global_namespace, var, value, SYNTAX_SUPPORTED
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import functools
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
|
@ -272,3 +273,22 @@ def test_probability_cauchy():
|
|||
|
||||
net = CauchyProb(loc, scale)
|
||||
net(Tensor(value), Tensor(loc_a), Tensor(scale_a))
|
||||
|
||||
|
||||
def test_third_party_module_functools():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: functools is a python built-in module and does not perform JIT Fallback.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ModuleNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
func = functools.partial(add_func, x)
|
||||
out = func(y)
|
||||
return out
|
||||
|
||||
x = Tensor([1, 2, 3], mstype.int32)
|
||||
y = Tensor([4, 5, 6], mstype.int32)
|
||||
net = ModuleNet()
|
||||
out = net(x, y)
|
||||
print(out)
|
||||
|
|
|
@ -18,7 +18,8 @@ import numpy as np
|
|||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, context, ms_class
|
||||
from mindspore import Tensor, context, ms_class, ms_function
|
||||
from . import test_graph_fallback
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -137,6 +138,23 @@ def test_fallback_self_method_tensor():
|
|||
print(out)
|
||||
|
||||
|
||||
def test_fallback_import_modules():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: add_func is defined in test_graph_fallback.py
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def use_imported_module(x, y):
|
||||
out = test_graph_fallback.add_func(x, y)
|
||||
return out
|
||||
|
||||
x = Tensor(2, dtype=mstype.int32)
|
||||
y = Tensor(3, dtype=mstype.int32)
|
||||
out = use_imported_module(x, y)
|
||||
print(out)
|
||||
|
||||
|
||||
def test_fallback_class_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
""" test graph fallback """
|
||||
import pytest
|
||||
import numpy as np
|
||||
import numpy.random as rand
|
||||
from mindspore import ms_function, context, Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -547,3 +548,18 @@ def test_np_slice():
|
|||
return Tensor(b)
|
||||
res = np_slice()
|
||||
assert np.all(res.asnumpy() == np.array([1, 2, 3, 4]))
|
||||
|
||||
|
||||
def test_np_random():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test numpy.random module in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def np_random():
|
||||
a = rand.randint(100, size=(5))
|
||||
b = a[1:5]
|
||||
return Tensor(b)
|
||||
res = np_random()
|
||||
print(res)
|
||||
|
|
Loading…
Reference in New Issue