!22843 Support fallback feature in Graph mode.

Merge pull request !22843 from 张清华/opt_fallback
This commit is contained in:
i-robot 2021-09-06 11:35:45 +00:00 committed by Gitee
commit 05a0898352
33 changed files with 792 additions and 288 deletions

View File

@ -13,9 +13,12 @@
"mindspore/mindspore/core/mindrt/src/actor/actorpolicy.h" "runtime/references"
"mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/" "readability/casting"
"mindspore/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc" "build/include_what_you_use"
"mindspore/mindspore/ccsrc/utils/convert_utils_py.cc" "whitespace/indent"
# Modelzoo
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/PostProcess/Yolov4TinyMindsporePost.h" "runtime/references"
# MindData
"mindspore/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h" "runtime/string"
"mindspore/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h" "runtime/references"

View File

@ -27,6 +27,7 @@
"mindspore/mindspore/nn/cell.py" "assignment-from-none"
"mindspore/mindspore/_extends/parse/resources.py" "bad-whitespace"
"mindspore/mindspore/_extends/parse/parser.py" "broad-except"
"mindspore/mindspore/_extends/parse/parser.py" "eval-used"
"mindspore/mindspore/nn/cell.py" "protected-access"
"mindspore/mindspore/nn/optim/ftrl.py" "unused-import"
"mindspore/mindspore/train/amp.py" "protected-access"

View File

@ -16,20 +16,18 @@
Interfaces for parser module in c++.
"""
from .parser import (Parser, create_obj_instance, generate_scope,
get_bprop_method_of_class, get_class_instance_type,
get_class_member_namespace_symbol, create_slice_obj,
get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key,
get_ast_type, get_node_type, get_args, get_args_default_values,
get_ast_namespace_symbol, get_operation_namespace_symbol,
get_parse_method_of_class, get_scope_name, expand_expr_statement,
from .parser import (Parser, create_instance, is_supported_create_instance_type, generate_scope,
get_bprop_method_of_class, get_class_instance_type, get_class_member_namespace_symbol,
create_slice_obj, get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key, get_ast_type, get_node_type,
get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_namespace_symbol,
get_parse_method_of_class, get_scope_name, eval_script, expand_expr_statement,
is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description)
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type',
'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol',
'get_args', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name',
'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement']
'get_args', 'get_obj_type', 'get_obj_id', 'create_instance', 'is_supported_create_instance_type',
'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser',
'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name',
'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement']

View File

@ -16,6 +16,7 @@
# ============================================================================
"""The module of parser python object, called by c++."""
import os
import ast
import hashlib
import inspect
@ -25,7 +26,7 @@ from textwrap import dedent
import asttokens
from mindspore import Tensor as MsTensor
from mindspore import Tensor
from mindspore import context
from mindspore import log as logger
from mindspore import nn
@ -132,6 +133,9 @@ def get_bprop_method_of_class(obj, parse_method=None):
method = getattr(obj, method_name)
return method
# The fallback feature is enabled in default.
# Not support change the flag during the process is alive.
support_fallback_ = os.getenv('ENV_SUPPORT_FALLBACK')
def resolve_symbol(namespace, symbol):
"""
@ -159,10 +163,13 @@ def resolve_symbol(namespace, symbol):
if getattr(resolve_, "__hash__") is None:
return resolve_
# Raise NotImplementedError when parsing the numpy methods, but not the numpy constant.
if namespace.name == "numpy" and isinstance(resolve_, (types.FunctionType, types.MethodType, types.ModuleType)):
raise NotImplementedError(
f"MindSpore does not support to use the numpy methods in the function construct with the graph mode.")
# Raise a proper error if not using Fallback feature.
if support_fallback_ != '1':
# Raise NotImplementedError when parsing the numpy methods, but not the numpy constant.
if namespace.name == "numpy" and \
isinstance(resolve_, (types.FunctionType, types.MethodType, types.ModuleType)):
raise NotImplementedError("Mindspore does not support to use the numpy methods " \
"within the construct() or @ms_function decorated function in graph mode.")
# If need trope the obj
if resolve_ in convert_object_map:
@ -177,9 +184,11 @@ def resolve_symbol(namespace, symbol):
logger.debug("resolve exception occurred, value = %r", e)
logger.debug("resolve type is invalid, namespace = %s, symbol = %s",
namespace.__str__(), symbol)
if isinstance(resolve_, _MindsporeFunctionExecutor):
logger.debug("resolve class _MindsporeFunctionExecutor, resolve fn instead.")
resolve_ = resolve_.fn
logger.debug(f'found: {symbol} in {namespace.__str__()}, resolve: {resolve_} / {type(resolve_)}')
return resolve_
@ -292,22 +301,24 @@ def _convert_tuple_to_args_kwargs(params):
args += (param,)
return (args, kwargs)
def is_supported_create_instance_type(cls_type):
return issubclass(cls_type, (nn.Cell, ops.Primitive))
def create_obj_instance(cls_type, params=None):
def create_instance(cls_type, params=None):
"""Create python instance."""
if not isinstance(cls_type, type):
logger.warning(f"create_obj_instance(), cls_type is not a type, cls_type: {cls_type}")
logger.warning(f"create_instance(), cls_type is not a type, cls_type: {cls_type}")
return None
# Check the type, now only support nn.Cell and Primitive.
obj = None
if issubclass(cls_type, (nn.Cell, ops.Primitive)):
if is_supported_create_instance_type(cls_type):
# Check arguments, only support *args or **kwargs.
if params is None:
obj = cls_type()
elif isinstance(params, tuple):
args, kwargs = _convert_tuple_to_args_kwargs(params)
logger.debug(f"create_obj_instance(), args: {args}, kwargs: {kwargs}")
logger.debug(f"create_instance(), args: {args}, kwargs: {kwargs}")
if args and kwargs:
obj = cls_type(*args, **kwargs)
elif args:
@ -358,7 +369,7 @@ def get_dataclass_methods(cls):
def convert_to_ms_tensor(data):
"""Convert C++ tensor to mindspore tensor."""
return MsTensor(data)
return Tensor(data)
def get_object_description(obj, fname, fline):
@ -415,7 +426,6 @@ def get_operation_namespace_symbol(var: str):
logger.debug("get operation ops info = %r", ops_info)
return ops_info
def get_ast_type(node):
"""Get the ast type."""
ast_type = AST_SUB_TYPE_UNKNOWN
@ -483,6 +493,21 @@ def get_args(node):
args.append(node.args.kwarg)
return args
def eval_script(exp_str, params):
"""Evaluate a python expression."""
if not isinstance(params, tuple):
raise ValueError(f"eval_script(), params is not a tuple, params: {params}")
if len(params) != 2:
raise ValueError(f"eval_script(), params tuple length is wrong, params: {params}")
logger.debug(f'exp_str: {exp_str}, params: {params}')
global_params = params[0]
local_params = params[1]
obj = eval(exp_str, global_params, local_params)
if obj is None:
raise ValueError(f"When call 'eval', the result is none. exp_str: '{exp_str}'")
return obj
class Parser:
"""
@ -501,7 +526,10 @@ class Parser:
self.line_offset = 0
self.filename: str = inspect.getfile(self.fn)
# Used to resolve the function's globals Namespace.
# Used to resolve mindspore builtin ops namespace.
self.ms_common_ns = CellNamespace('mindspore.common')
self.ms_ops_ns = CellNamespace('mindspore.ops')
# Used to resolve the function's globals namespace.
self.global_namespace = CellNamespace(fn.__module__)
self.function_module = fn.__module__
# Used to resolve the function's nonlocals.
@ -512,38 +540,37 @@ class Parser:
def parse(self):
"""Parse the function or method."""
logger.debug("fn = %r", self.fn)
tree = None
if isinstance(self.fn, (types.FunctionType, types.MethodType)):
lines, self.line_offset = inspect.getsourcelines(self.fn)
original_src = ''.join(lines)
hexstr = hashlib.sha256(original_src.encode()).hexdigest()
tree = Parser.ast_cache.get(hexstr)
if not tree:
ast_tokens = Parser.ast_cache.get(hexstr)
if not ast_tokens:
src = dedent(original_src)
self.col_offset = \
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
logger.debug("get source = %s", src)
try:
tree = asttokens.ASTTokens(src, parse=True).tree
ast_tokens = asttokens.ASTTokens(src, parse=True)
except IndentationError as idt_err:
idt_err.filename = self.filename
idt_err.lineno = self.line_offset
idt_err.msg = f"There are incorrect indentations in definition or comment of function: " \
f"'{self.fn.__qualname__}'."
raise idt_err
Parser.ast_cache[hexstr] = tree
else:
logger.error("Fn type is invalid")
return tree
Parser.ast_cache[hexstr] = ast_tokens
return ast_tokens, ast_tokens.tree
logger.error("Fn type is invalid")
return None, None
def get_namespace_symbol(self, var: str):
"""Get symbol type and namespace and symbol."""
if var in self.closure_namespace:
logger.debug("in closure_namespace")
logger.debug(f"found {var} in closure_namespace {self.closure_namespace.__str__()}")
return self.closure_namespace, var
if var in self.global_namespace:
logger.debug("in global_namespace")
logger.debug(f"found {var} in global_namespace {self.global_namespace.__str__()}")
value = self.global_namespace[var]
if isinstance(value, type(abs)) and self.global_namespace[var] not in convert_object_map:
error_info = f"The builtin function '{var}' is not supported in graph mode."
@ -552,6 +579,56 @@ class Parser:
error_info = f"The name '{var}' is not defined."
return None, var, error_info
def is_unsupported_builtin_type(self, value_type):
"""To check if not supported builtin type"""
logger.debug(f'value_type: {value_type}, {type([])}, {type(())}')
return value_type == type([]) or value_type == type(())
def is_supported_namespace_module(self, value):
"""To check if the module is allowed to support."""
if not hasattr(value, '__name__'):
return True
name = value.__name__
if name == 'mindspore':
logger.debug(f'...found {name} in mindspore root namespace')
return True
if not isinstance(value, types.ModuleType):
return False
rightmost_name = name.split('.')[-1]
# if rightmost_name in self.ms_common_ns:
# logger.error(f'...found {module_name} in common namespace: {self.ms_common_ns.__str__()}')
# return True
if rightmost_name in self.ms_ops_ns:
logger.debug(f'...found {name}({rightmost_name}) in ops namespace: {self.ms_ops_ns.__str__()}')
return True
if rightmost_name in trope_ns:
logger.debug(f'...found {name}({rightmost_name}) in trope namespace: {self.trope_ns.__str__()}')
return True
return False
def get_builtin_namespace_symbol(self, var: str):
"""Get mindspore builtin namespace and symbol."""
if var in self.closure_namespace:
logger.debug(f"found {var} in closure_namespace {self.closure_namespace.__str__()}")
return self.closure_namespace, var
if var in self.global_namespace:
logger.debug(f"found {var} in global_namespace {self.global_namespace.__str__()}")
value = self.global_namespace[var]
value_str = value.__name__ if hasattr(value, '__name__') else str(value)
logger.debug(f"value: {type(value)}, : {value_str}, hasattr(__name__): {hasattr(value, '__name__')}")
# To check if allowed to support.
if self.is_unsupported_builtin_type(value):
return self.global_namespace, var, value
if not self.is_supported_namespace_module(value): # Check if support including instance of types.ModuleType
return self.global_namespace, var, value
return self.global_namespace, var
error_info = f"The symbol '{var}' is not supported in graph mode."
logger.debug(error_info)
return None, var, error_info
def analyze_super(self, class_type_node, subclass_instance):
"""Analyze super and return a class instance."""
sub_class = type(subclass_instance)

View File

@ -444,7 +444,7 @@ AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitiveP
CheckArgsSize(op_name, args_spec_list, 1);
AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
py::tuple data_tuple = ValueToPyData(input->BuildValue());
py::array data = py::array(data_tuple);
auto tensor = tensor::TensorPy::MakeTensor(data);
auto ret = tensor->ToAbstract();

View File

@ -268,7 +268,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceB
}
}
if (!fn || py::isinstance<py::none>(fn)) {
MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
MS_LOG(ERROR) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn);
return nullptr;
}
func_graph = parse::ParsePythonCode(fn);

View File

@ -59,7 +59,7 @@ ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name
}
std::vector<py::object> arg_list;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
[](const Attr &attr) { return ValuePtrToPyData(attr.second); });
[](const Attr &attr) { return ValueToPyData(attr.second); });
py::object obj =
parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list);
ValuePtr op_instance = nullptr;

View File

@ -215,11 +215,13 @@ ValuePtr ConvertDict(const py::object &obj, bool use_signature) {
return std::make_shared<ValueDictionary>(key_values);
}
ValuePtr ConvertNameSpace(const py::object &obj) {
ValuePtr ConvertModuleNameSpace(const py::object &obj) {
MS_LOG(DEBUG) << "Converting python module";
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
auto converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace));
auto converted =
std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace), obj);
MS_LOG(DEBUG) << "name_space: " << converted->ToString();
return converted;
}
@ -355,7 +357,9 @@ ValuePtr ConvertOtherObj(const py::object &obj) {
// When the obj is Cell, default parse the 'construct'
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
return std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
auto res = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
MS_LOG(DEBUG) << "name_space: " << res->ToString();
return res;
}
MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
return nullptr;
@ -456,7 +460,7 @@ std::vector<DataConverterPtr> GetDataConverters() {
std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),
std::make_shared<ByTypeDataConverter<py::str>>(PyCast<StringImm, string>),
std::make_shared<ByTypeDataConverter<py::ellipsis>>(kEllipsis),
std::make_shared<ByTypeDataConverter<py::module>>(ConvertNameSpace),
std::make_shared<ByTypeDataConverter<py::module>>(ConvertModuleNameSpace),
std::make_shared<ByAttrDataConverter>(PYTHON_DATACLASS_FIELDS, ConvertDataClass),
std::make_shared<ByTypeDataConverter<Type>>(ObjCast<TypePtr>),
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
@ -466,8 +470,10 @@ std::vector<DataConverterPtr> GetDataConverters() {
std::make_shared<ByTypeDataConverter<EnvInstance>>(ObjCast<std::shared_ptr<EnvInstance>>),
std::make_shared<ByAttrDataConverter>(PYTHON_CLASS_MEMBER_NAMESPACE,
[](const py::object &obj) -> ValuePtr {
return std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER,
obj);
auto res =
std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
MS_LOG(DEBUG) << "name_space: " << res->ToString();
return res;
}),
std::make_shared<ByTypeDataConverter<py::int_>>(ConvertIntegerWithType),
std::make_shared<ByTypeDataConverter<py::float_>>(ConvertFloatWithType),
@ -598,8 +604,16 @@ bool IsCellInstance(const py::object &obj) {
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
// `args_kwargs` maybe a tuple(*args), tuple(**kwargs), or tuple(*args, **kwargs).
return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type)
: python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, args_kwargs);
return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type)
: python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type, args_kwargs);
}
// Call the python script string.
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
// `args_kwargs` is a tuple(dict(global), dict(local)).
return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_EVAL_PY_SCRIPT, script)
: python_adapter::CallPyModFn(mod, PYTHON_MOD_EVAL_PY_SCRIPT, script, args_kwargs);
}
// Generate an appropriate name and set to graph debuginfo,

View File

@ -45,6 +45,7 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
bool IsCellInstance(const py::object &obj);
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs);
void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name);
ValuePtr PyDataToValue(const py::object &obj);
void ClearObjectCache();

View File

@ -25,6 +25,7 @@
#include "pybind11/pybind11.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/parse/data_converter.h"
#include "frontend/operator/ops.h"
#include "utils/info.h"
#include "debug/trace.h"
@ -71,7 +72,7 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var " << var_name << " with node "
<< node->DebugString();
auto [iter, is_new_name] = vars_.emplace(var_name, std::make_pair(node, false));
auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
if (!is_new_name) {
// If a cnode variable with same name already existed but not used,
// add it as an isolate node. for example:
@ -97,8 +98,8 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
// Read variable from predecessors
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
// Get var node if it is found
auto found = vars_.find(var);
if (found != vars_.end()) {
auto found = assigned_vars_.find(var);
if (found != assigned_vars_.end()) {
auto &node = found->second.first;
MS_EXCEPTION_IF_NULL(node);
// Mark the variable as used.
@ -109,7 +110,7 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
}
return node;
}
// Get var from predecessor block ,if can't get then make a resolve node to it
// Get var from predecessor block, if can't get then make a resolve node to it
if (matured_) {
// If only one predecessor block, read the definition of var from it.
if (prev_blocks_.size() == 1) {
@ -122,6 +123,7 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
if (it != var_to_resolve_.end()) {
return it->second;
}
MS_LOG(DEBUG) << "var: " << var;
auto tmp_node = MakeResolveSymbol(var);
var_to_resolve_[var] = tmp_node;
return tmp_node;
@ -154,6 +156,7 @@ AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
}
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_AST, namespace_var[0]);
SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
return MakeResolve(name_space, symbol);
}
@ -164,9 +167,39 @@ AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
SymbolPtr symbol = std::make_shared<Symbol>(attr);
MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
return MakeResolve(name_space, symbol);
}
AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &namespace_info) {
const size_t namespace_info_size = 2;
const size_t namespace_more_info_size = 3;
if (namespace_info.size() != namespace_info_size && namespace_info.size() != namespace_more_info_size) {
MS_EXCEPTION(NameError) << "namespace info size should be 2 or 3, but got " << namespace_info.size();
}
bool unsupported = false;
py::object py_obj;
if (namespace_info.size() == namespace_more_info_size) {
if (namespace_info[0].is_none()) { // If namespace is None, the symbol is an undefined name.
MS_EXCEPTION(NameError) << namespace_info[namespace_more_info_size - 1].cast<std::string>();
} else { // Or, the symbol is an unsupported builtin symbol in Graph mode.
unsupported = true;
py_obj = namespace_info[namespace_more_info_size - 1];
}
}
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString()
<< ", unsupported: " << unsupported;
auto resolved_node = MakeResolve(name_space, symbol);
if (unsupported) {
resolved_node->set_interpret(true);
AddGlobalPyParam(symbol->name(), py_obj);
MS_LOG(INFO) << "Added global python symblol: {" << symbol->name() << " : " << py::str(py_obj) << "}";
}
return resolved_node;
}
// Make a resolve node for symbol string
AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
if (value.compare(0, strlen("self"), "self") == 0) {
@ -180,23 +213,17 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
}
auto ast = parser_.ast();
MS_EXCEPTION_IF_NULL(ast);
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
const size_t namespace_info_size = 2;
if (namespace_info.size() < namespace_info_size) {
MS_EXCEPTION(NameError) << "namespace_info is less than 2";
// The fallback feature is enabled in default.
// Not support change the flag during the process is alive.
static const auto use_fallback = (parser_.support_fallback() == "0" ? false : true);
if (!use_fallback) {
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
return HandleNamespaceInfo(namespace_info);
} else {
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL, value);
return HandleNamespaceInfo(namespace_info);
}
// If namespace is None, the symbol is an undefined name or an unsupported builtin function.
if (namespace_info[0].is_none()) {
// If the size of namespace_var is greater than or equal to 3, the error information is stored in namespace_var[2].
if (namespace_info.size() > namespace_info_size) {
MS_EXCEPTION(NameError) << namespace_info[namespace_info_size].cast<std::string>();
}
// If the size of namespace_var is less than 3, the default error information is used.
MS_EXCEPTION(NameError) << "The name \'" << value << "\' is not defined.";
}
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
return MakeResolve(name_space, symbol);
}
AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
@ -209,6 +236,7 @@ AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
}
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]);
SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
return MakeResolve(name_space, symbol);
}
@ -221,6 +249,17 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const Symb
return node;
}
AnfNodePtr FunctionBlock::MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node) {
MS_LOG(DEBUG) << "MakeInterpret for " << script_text;
ScriptPtr script = std::make_shared<Script>(script_text);
auto script_node = NewValueNode(script);
auto node = func_graph_->NewCNodeInOrder(
{NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
node->set_interpreted_node(orig_node);
return node;
}
// Add input for the block's phi parameter
void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
MS_EXCEPTION_IF_NULL(phi);
@ -418,7 +457,7 @@ void FunctionBlock::FindIsolatedNodes() {
//
std::set<AnfNodePtr> used;
// Find used variables.
for (const auto &var : vars_) {
for (const auto &var : assigned_vars_) {
auto &node = var.second.first;
if (node == nullptr) {
continue;
@ -429,7 +468,7 @@ void FunctionBlock::FindIsolatedNodes() {
}
}
// Add isolated nodes which is unused var but not found in used set.
for (const auto &var : vars_) {
for (const auto &var : assigned_vars_) {
auto &node = var.second.first;
bool is_used = var.second.second;
if (node == nullptr || is_used) {

View File

@ -26,16 +26,17 @@
#include <unordered_map>
#include <memory>
#include <utility>
#include <tuple>
#include "pipeline/jit/parse/parse_base.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"
namespace mindspore {
namespace parse {
class Parser;
class NameSpace;
class Symbol;
class Script;
class FunctionBlock;
using FunctionBlockPtr = std::shared_ptr<FunctionBlock>;
@ -70,11 +71,26 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
AnfNodePtr MakeResolveSymbol(const std::string &value);
AnfNodePtr MakeResolveOperation(const std::string &value);
AnfNodePtr MakeResolve(const std::shared_ptr<NameSpace> &name_space, const std::shared_ptr<Symbol> &resolve_symbol);
AnfNodePtr HandleNamespaceInfo(const py::tuple &namespace_info);
AnfNodePtr MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node);
const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis() const { return removable_phis_; }
void FindIsolatedNodes();
void AddIsolatedNode(const AnfNodePtr &target);
void AttachIsolatedNodesBeforeReturn();
py::dict &global_py_params() { return global_py_params_; }
void set_global_py_params(const py::dict &symbols) { global_py_params_ = symbols; }
void AddGlobalPyParam(const std::string &name, const py::object &obj) { global_py_params_[py::str(name)] = obj; }
std::tuple<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> local_py_params() {
return {local_py_params_keys_, local_py_params_values_};
}
void AddLocalPyParam(const std::string &name, const AnfNodePtr &node) {
local_py_params_keys_.emplace_back(NewValueNode(name));
local_py_params_values_.emplace_back(node);
}
private:
// Block graph
FuncGraphPtr func_graph_;
@ -90,7 +106,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
std::vector<FunctionBlock *> prev_blocks_;
// Store args and variable's node, use a bool flag to indicate if the variable is used.
std::map<std::string, std::pair<AnfNodePtr, bool>> vars_;
std::map<std::string, std::pair<AnfNodePtr, bool>> assigned_vars_;
// Map the parameter node to variable, it can be resolved if the block's predecessors are processed
std::map<ParameterPtr, std::string> phi_nodes_;
@ -114,10 +130,25 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
// Keep new made resolve symbol for the variable not found in vars_.
std::unordered_map<std::string, AnfNodePtr> var_to_resolve_;
// Collect all python symbols in the block.
// We treat both global symbols and local symbols declared previously as global symbols.
py::dict global_py_params_;
std::vector<AnfNodePtr> local_py_params_keys_;
std::vector<AnfNodePtr> local_py_params_values_;
// Isolated nodes.
OrderedSet<AnfNodePtr> isolated_nodes_;
};
class ScriptInfo {
public:
explicit ScriptInfo(const py::object &obj) : py_obj_(obj) {}
// Key for user data.
constexpr static char key[] = "ScriptInfo";
py::object py_obj_;
};
} // namespace parse
} // namespace mindspore

View File

@ -24,9 +24,10 @@
#include <unordered_map>
#include <sstream>
#include <algorithm>
#include "pybind_api/pybind_patch.h"
#include "pipeline/jit/parse/resolve.h"
#include "frontend/operator/ops.h"
#include "pipeline/jit/parse/data_converter.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/composite.h"
#include "utils/ms_context.h"
#include "debug/trace.h"
@ -42,7 +43,7 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
return nullptr;
}
auto ast = std::make_shared<ParseAst>(obj);
auto ast = std::make_shared<ParseFunctionAst>(obj);
bool success = ast->InitParseAstInfo(python_mod_get_parse_method);
if (!success) {
MS_LOG(ERROR) << "Parse code to ast tree failed.";
@ -89,8 +90,9 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo
FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
Parser::Parser(const std::shared_ptr<ParseAst> &ast) : ast_(ast) {
Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast) : ast_(ast) {
max_for_loop_count_str_ = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
support_fallback_ = "0"; // We will open it later by call common::GetEnv("ENV_SUPPORT_FALLBACK")
errcode_ = PARSE_SUCCESS;
BuildMethodMap();
}
@ -147,7 +149,7 @@ void Parser::CleanParserResource() {
ScopeManager::GetInstance().ClearScope();
}
void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) {
void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunctionAst> &ast) {
// Check whether the functions referred by this function and itself are missing 'return' statement
auto mng = Manage(fn, false);
MS_EXCEPTION_IF_NULL(ast);
@ -501,6 +503,7 @@ AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &no
MS_LOG(DEBUG) << "The Name id is " << name_id;
MS_EXCEPTION_IF_NULL(block);
if (block->IsGlobalVar(name_id)) {
MS_LOG(DEBUG) << "name_id: " << name_id;
return block->MakeResolveSymbol(name_id);
}
return block->ReadVariable(name_id);
@ -612,6 +615,7 @@ AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &arg
py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
SymbolPtr symbol = std::make_shared<Symbol>("namespace");
MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
return block->MakeResolve(name_space, symbol);
}
@ -631,7 +635,7 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
}
}
AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node);
AnfNodePtr call_function_node = ParseExprNode(block, function_ast_node);
// Function call arguments should be passed in as groups and unpacked later using unpack call
std::vector<AnfNodePtr> packed_arguments;
std::vector<AnfNodePtr> group_arguments;
@ -641,33 +645,37 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
// If there is stared or keyword argument, unpack may be needed
bool need_unpack = need_unpack_args || need_unpack_keywords;
return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack);
auto call_cnode = GenerateAnfNodeForCall(block, call_function_node, packed_arguments, group_arguments, need_unpack);
if (call_function_node->interpret()) {
call_cnode->set_interpret(true);
}
return call_cnode;
}
CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_anf_node,
CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_node,
const std::vector<AnfNodePtr> &packed_arguments) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> unpack_call_nodes;
auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
unpack_call_nodes.push_back(unpack_call_op);
unpack_call_nodes.push_back(call_function_anf_node);
unpack_call_nodes.push_back(call_function_node);
(void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
[](AnfNodePtr node) -> AnfNodePtr { return node; });
CNodePtr unpack_call = func_graph->NewCNodeInOrder(unpack_call_nodes);
return unpack_call;
}
AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_node,
const std::vector<AnfNodePtr> &packed_arguments,
const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
// If there is keyword arguments or starred, using an unpack_call op to unpack the argument
MS_EXCEPTION_IF_NULL(block);
if (need_unpack) {
return MakeUnpackCall(block->func_graph(), call_function_anf_node, packed_arguments);
return MakeUnpackCall(block->func_graph(), call_function_node, packed_arguments);
}
// else there is no keyword arguments and starred, parsed as normal arguments without unpack
std::vector<AnfNodePtr> func_call_nodes;
func_call_nodes.push_back(call_function_anf_node);
func_call_nodes.push_back(call_function_node);
(void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes),
[](AnfNodePtr node) -> AnfNodePtr { return node; });
CNodePtr call_anf_node = block->func_graph()->NewCNodeInOrder(func_call_nodes);
@ -689,7 +697,9 @@ bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args
group_arguments->clear();
need_unpack = true;
} else {
group_arguments->push_back(ParseExprNode(block, args[i]));
auto node = ParseExprNode(block, args[i]);
node = HandleInterpret(block, node, args[i]);
group_arguments->push_back(node);
}
}
if (!group_arguments->empty()) {
@ -734,7 +744,7 @@ bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
// Process call attributes of class type define, eg: x.y()
AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Attribute";
// Process class value,eg: self.xx
// Process class value, eg: self.xx
if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
if (ast_->IsClassMember(node)) {
std::string var_name = "self.";
@ -746,6 +756,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
(py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
py::isinstance<py::str>(attr_obj) || data_converter::IsCellInstance(attr_obj))) {
MS_LOG(DEBUG) << "var_name: " << var_name;
return block->MakeResolveSymbol(var_name);
} else {
return block->ReadVariable(var_name);
@ -775,7 +786,11 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
}
// Create the apply node
return block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
if (value_node->interpret()) {
attr_cnode->set_interpret(true);
}
return attr_cnode;
}
// Process comparison expression : a == b. a > b etc.
@ -1021,6 +1036,15 @@ AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object
}
// Process a dict ast node expression
AnfNodePtr Parser::ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
const std::vector<AnfNodePtr> &value_nodes) {
auto keys_tuple = GenerateMakeTuple(block, key_nodes);
auto values_tuple = GenerateMakeTuple(block, value_nodes);
MS_EXCEPTION_IF_NULL(block);
auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple});
}
AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Dict";
py::list keys = node.attr("keys");
@ -1031,14 +1055,10 @@ AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &no
key_nodes.push_back(ParseExprNode(block, keys[i]));
value_nodes.push_back(ParseExprNode(block, values[i]));
}
auto keys_tuple = GenerateMakeTuple(block, key_nodes);
auto values_tuple = GenerateMakeTuple(block, value_nodes);
MS_EXCEPTION_IF_NULL(block);
auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple});
return ParseDictByKeysAndValues(block, key_nodes, value_nodes);
}
// Process a augment assign such as a += b or mat[stride_slice] += b.
// Process a augment assign such as a += b or mat[stride_slice] += b.
FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast AugAssign";
MS_EXCEPTION_IF_NULL(block);
@ -1631,17 +1651,18 @@ AnfNodePtr Parser::ParseListComp(const FunctionBlockPtr &block, const py::object
auto top_block = ParseListCompIter(block, node, generator_node);
// Call the top graph and return the list.
auto call_function_anf_node = NewValueNode(top_block->func_graph());
auto call_function_node = NewValueNode(top_block->func_graph());
std::vector<AnfNodePtr> func_call_nodes;
func_call_nodes.push_back(call_function_anf_node);
func_call_nodes.push_back(call_function_node);
AnfNodePtr output = block->func_graph()->NewCNodeInOrder(func_call_nodes);
return output;
}
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_obj,
const AnfNodePtr &assigned_node) {
MS_EXCEPTION_IF_NULL(block);
MS_EXCEPTION_IF_NULL(assigned_node);
py::str name = python_adapter::GetPyObjAttr(targ, "id");
py::str name = python_adapter::GetPyObjAttr(target_obj, "id");
std::string name_id = name;
assigned_node->debug_info()->set_name(name_id);
// Set the debug name of the constant graph
@ -1652,13 +1673,16 @@ void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &t
fg->debug_info()->set_name(name_id);
}
}
MS_LOG(DEBUG) << "Assign name: " << name_id << " to node: " << assigned_node;
block->AddLocalPyParam(name_id, assigned_node);
block->WriteVariable(name_id, assigned_node);
}
void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_obj,
const AnfNodePtr &assigned_node) {
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
py::list items = python_adapter::GetPyObjAttr(targ, "elts");
py::list items = python_adapter::GetPyObjAttr(target_obj, "elts");
for (size_t i = 0; i < items.size(); i++) {
// Use the Primitive replace the operation resolve node (getitem),
// because the getitem will eventually be converted to Primitive node
@ -1670,13 +1694,13 @@ void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &
}
}
void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ,
void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_obj,
const AnfNodePtr &assigned_node) {
// Now only support the self.xx = xxxxx, can't support x.y = xxxx
AnfNodePtr target_node = ParseExprNode(block, targ);
AnfNodePtr target_node = ParseExprNode(block, target_obj);
MS_EXCEPTION_IF_NULL(target_node);
auto attr_name = targ.attr("attr").cast<std::string>();
auto attr_name = target_obj.attr("attr").cast<std::string>();
std::string var_name = "self." + attr_name;
// Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
@ -1699,12 +1723,12 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob
block->SetStateAssign(target_node, assigned_node);
}
void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ,
void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_obj,
const AnfNodePtr &assigned_node) {
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
py::object value_obj = python_adapter::GetPyObjAttr(targ, "value");
py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice");
py::object value_obj = python_adapter::GetPyObjAttr(target_obj, "value");
py::object slice_obj = python_adapter::GetPyObjAttr(target_obj, "slice");
AnfNodePtr value_node = ParseExprNode(block, value_obj);
AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node});
@ -1742,18 +1766,19 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje
block->WriteVariable(var_name, setitem_app);
}
void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node) {
void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_obj,
const AnfNodePtr &value_node) {
MS_EXCEPTION_IF_NULL(value_node);
MS_LOG(DEBUG) << "Process WriteAssignVars";
auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, targ)));
auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_obj)));
if (ast_type == AST_SUB_TYPE_NAME) {
HandleAssignName(block, targ, value_node);
HandleAssignName(block, target_obj, value_node);
} else if (ast_type == AST_SUB_TYPE_TUPLE) {
HandleAssignTuple(block, targ, value_node);
HandleAssignTuple(block, target_obj, value_node);
} else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
HandleAssignSubscript(block, targ, value_node);
} else if (ast_->IsClassMember(targ)) {
HandleAssignClassMember(block, targ, value_node);
HandleAssignSubscript(block, target_obj, value_node);
} else if (ast_->IsClassMember(target_obj)) {
HandleAssignClassMember(block, target_obj, value_node);
} else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
MS_LOG(EXCEPTION) << "The subnet attributes cannot be changed in the network. \n\n"
<< trace::GetDebugInfo(value_node->debug_info());
@ -1763,11 +1788,47 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
}
}
// Process a assign statement, such as a =b, a,b = tup
AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
const py::object &value_object) {
// The fallback feature is enabled in default.
// Not support change the flag during the process is alive.
static const auto use_fallback = (support_fallback_ == "0" ? false : true);
if (!use_fallback) {
return value_node;
}
AnfNodePtr interpreted_node = value_node;
if (value_node->interpret()) {
const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
MS_LOG(INFO) << "script_text: " << script_text << ", value_node: " << value_node->DebugString(2);
// Prepare global parameters.
py::dict global_dict = block->global_py_params();
ValuePtr globals_converted_value = nullptr;
if (!ConvertData(global_dict, &globals_converted_value)) {
MS_LOG(EXCEPTION) << "Convert data failed";
}
auto global_dict_node = NewValueNode(globals_converted_value);
// Prepare local parameters.
auto [keys, values] = block->local_py_params();
auto local_dict_node = ParseDictByKeysAndValues(block, keys, values);
// Update the valued node if it need interpreting.
interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node);
// Print a hint for user.
MS_LOG(ERROR) << "Found unsupported syntax in Graph mode, those codes would be fell back to Python interpreter:"
<< "\n\n"
<< trace::GetDebugInfo(value_node->debug_info());
}
return interpreted_node;
}
// Process a assign statement, such as a = b, a, b = tuple(xx, xx)
FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast assign";
py::object value_object = python_adapter::GetPyObjAttr(node, "value");
AnfNodePtr value_node = ParseExprNode(block, value_object);
value_node = HandleInterpret(block, value_node, value_object);
py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
size_t count = LongToSize(pcount);
@ -1870,8 +1931,8 @@ void Parser::RemoveUnnecessaryPhis() {
}
}
// ParseAst class code
bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
// ParseFunctionAst class code
bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
// Init the type
target_type_ = PARSE_TARGET_UNKNOW;
@ -1916,7 +1977,13 @@ bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method)
// Call python parse get ast tree
parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method);
ast_tree_ = python_adapter::CallPyObjMethod(parser_, "parse");
py::tuple ast_info = python_adapter::CallPyObjMethod(parser_, "parse");
const size_t ast_info_size = 2;
if (ast_info.size() != ast_info_size) {
MS_EXCEPTION(NameError) << "ast info size is not equal to 2.";
}
ast_tokens_ = ast_info[0];
ast_tree_ = ast_info[1];
// Get fn name and module
function_module_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_module"));
@ -1928,23 +1995,28 @@ bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method)
}
// Get ast tree node : is the tree bode list[0]
py::object ParseAst::GetAstNode() {
py::object ParseFunctionAst::GetAstNode() {
py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body");
py::object ast_node = tree_body[0];
return ast_node;
}
py::list ParseAst::GetArgs(const py::object &func_node) {
// Get ast tokens node text.
py::str ParseFunctionAst::GetAstNodeText(const py::object &node_obj) {
return python_adapter::CallPyObjMethod(ast_tokens_, "get_text", node_obj);
}
py::list ParseFunctionAst::GetArgs(const py::object &func_node) {
py::list ret = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS, func_node);
return ret;
}
py::list ParseAst::GetArgsDefaultValues(const py::object &func_node) {
py::list ParseFunctionAst::GetArgsDefaultValues(const py::object &func_node) {
py::list ret = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node);
return ret;
}
AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) {
AstNodeTypePtr ParseFunctionAst::GetNodeType(const py::object &node) {
py::list list_value = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_NODE_TYPE, node);
const size_t list_value_size = 2;
if (list_value.size() < list_value_size) {
@ -1955,12 +2027,12 @@ AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) {
return std::make_shared<AstNodeType>(node, node_name, type);
}
AstSubType ParseAst::GetOpType(const py::object &node) {
AstSubType ParseFunctionAst::GetOpType(const py::object &node) {
auto op_type = AstSubType(python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_AST_TYPE, node).cast<int32_t>());
return op_type;
}
bool ParseAst::IsClassMember(const py::object &node) {
bool ParseFunctionAst::IsClassMember(const py::object &node) {
py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node);
if (!py::isinstance<py::bool_>(ret)) {
MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";

View File

@ -34,7 +34,6 @@
namespace mindspore {
namespace parse {
// Parse status define
enum ParseStatusCode : int64_t {
PARSE_SUCCESS = 0,
@ -58,7 +57,7 @@ enum ParseStatusCode : int64_t {
const int64_t MAX_FOR_LOOP_COUNT = std::numeric_limits<int64_t>::max();
class AstNodeType;
class ParseAst;
class ParseFunctionAst;
// Save loop info for 'continue' and 'break' statements.
struct Loop {
@ -90,13 +89,14 @@ class LoopContext {
// Parser to parse python function
class Parser {
public:
explicit Parser(const std::shared_ptr<ParseAst> &ast);
explicit Parser(const std::shared_ptr<ParseFunctionAst> &ast);
~Parser() {}
FuncGraphPtr ParseFuncGraph();
FuncGraphPtr func_graph() const { return func_graph_; }
ParseStatusCode errcode() const { return errcode_; }
std::shared_ptr<ParseAst> ast() const { return ast_; }
std::shared_ptr<ParseFunctionAst> ast() const { return ast_; }
const std::string &support_fallback() const { return support_fallback_; }
// Get location info from the ast node
LocationPtr GetLocation(const py::object &node) const;
static void InitParserEnvironment(const py::object &obj);
@ -177,6 +177,8 @@ class Parser {
// Process a unaryop
AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node);
// Process a dict ast node expression
AnfNodePtr ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
const std::vector<AnfNodePtr> &value_nodes);
AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node);
// Process ListComp expression
AnfNodePtr ParseListComp(const FunctionBlockPtr &block, const py::object &node);
@ -185,6 +187,10 @@ class Parser {
AnfNodePtr ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
const py::object &node, const py::object &generator_node);
// Check if the node need interpreting.
AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
const py::object &value_object);
// Generate argument nodes for ast function node
void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
// Generate argument default value for ast function node
@ -260,7 +266,7 @@ class Parser {
// The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
static FuncGraphWeakPtr top_func_graph_;
// Python function id, used to indicate whether two CNodes come from the same Python function
const std::shared_ptr<ParseAst> &ast_;
const std::shared_ptr<ParseFunctionAst> &ast_;
FuncGraphPtr func_graph_;
// Error code setwhen parsing ast tree
ParseStatusCode errcode_;
@ -278,6 +284,7 @@ class Parser {
// Save current loops to support 'continue', 'break' statement.
std::stack<Loop> loops_;
string max_for_loop_count_str_;
string support_fallback_;
};
// AST node type define code to ast
@ -303,16 +310,19 @@ class AstNodeType {
using AstNodeTypePtr = std::shared_ptr<AstNodeType>;
// A helper class to parse python function
class ParseAst {
class ParseFunctionAst {
public:
explicit ParseAst(const py::object &obj) : obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {}
explicit ParseFunctionAst(const py::object &obj)
: obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {}
~ParseAst() = default;
~ParseFunctionAst() = default;
bool InitParseAstInfo(const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
py::object GetAstNode();
py::str GetAstNodeText(const py::object &node);
py::list GetArgs(const py::object &func_node);
py::list GetArgsDefaultValues(const py::object &func_node);
@ -360,6 +370,7 @@ class ParseAst {
// Function or class method.
py::function function_;
py::object ast_tokens_;
py::object ast_tree_;
py::object parser_;
py::module module_;

View File

@ -63,7 +63,8 @@ const char PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER[] = "is_class_member";
const char PYTHON_MOD_RESOLVE_GET_OBJ_TYPE[] = "get_obj_type";
const char PYTHON_MOD_GET_OBJ_ID[] = "get_obj_id";
const char PYTHON_MOD_GET_CLASS_INSTANCE_TYPE[] = "get_class_instance_type";
const char PYTHON_MOD_CREATE_OBJ_INSTANCE[] = "create_obj_instance";
const char PYTHON_MOD_CREATE_INSTANCE[] = "create_instance";
const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type";
const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods";
const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
@ -72,6 +73,7 @@ const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
const char PYTHON_PARSE_GET_ARGS[] = "get_args";
const char PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES[] = "get_args_default_values";
@ -80,6 +82,7 @@ const char PYTHON_PARSE_GET_AST_TYPE[] = "get_ast_type";
const char PYTHON_PARSE_GET_NAMESPACE_SYMBOL[] = "get_namespace_symbol";
const char PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL[] = "get_ast_namespace_symbol";
const char PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL[] = "get_operation_namespace_symbol";
const char PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL[] = "get_builtin_namespace_symbol";
const char PYTHON_PARSE_GET_LOCATION[] = "get_location";
const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement";
const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";

View File

@ -33,7 +33,7 @@ static const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRI
parse::NAMED_PRIMITIVE_NAMECONSTANT,
parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};
std::string DynamicParser::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
std::string DynamicParser::ParseNodeName(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node,
parse::AstMainType type) {
MS_EXCEPTION_IF_NULL(ast);
if (py::isinstance<py::none>(node)) {
@ -53,7 +53,7 @@ std::string DynamicParser::ParseNodeName(const std::shared_ptr<parse::ParseAst>
return node_name;
}
void DynamicParser::ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
void DynamicParser::ParseInputArgs(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node) {
MS_EXCEPTION_IF_NULL(ast);
py::list args = ast->GetArgs(fn_node);
for (size_t i = 1; i < args.size(); i++) {
@ -63,7 +63,7 @@ void DynamicParser::ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast,
}
}
bool DynamicParser::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
bool DynamicParser::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
MS_LOG(DEBUG) << "Parse if/while expr";
py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR);
@ -112,7 +112,7 @@ bool DynamicParser::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst>
return false;
}
bool DynamicParser::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
bool DynamicParser::ParseAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
MS_LOG(DEBUG) << "Parse assign expr";
py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE);
const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR);
@ -140,7 +140,7 @@ bool DynamicParser::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &
return false;
}
bool DynamicParser::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
bool DynamicParser::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node,
const std::vector<std::string> &compare_prim) {
MS_LOG(DEBUG) << "Parse augassign expr";
bool ret = false;
@ -168,7 +168,7 @@ bool DynamicParser::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst
return ret;
}
bool DynamicParser::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
bool DynamicParser::ParseForExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
MS_LOG(DEBUG) << "Parse for expr";
py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
if (py::isinstance<py::none>(body_node)) {
@ -188,7 +188,7 @@ bool DynamicParser::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast
return false;
}
bool DynamicParser::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
bool DynamicParser::ParseBodyContext(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node,
const std::vector<std::string> &compare_prim) {
MS_EXCEPTION_IF_NULL(ast);
py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
@ -236,7 +236,7 @@ bool DynamicParser::IsDynamicCell(const py::object &cell) {
return false;
}
// Using ast parse to check whether the construct of cell will be changed
auto ast = std::make_shared<parse::ParseAst>(cell);
auto ast = std::make_shared<parse::ParseFunctionAst>(cell);
bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
if (!success) {
MS_LOG(ERROR) << "Parse code to ast tree failed";

View File

@ -36,15 +36,15 @@ class DynamicParser {
private:
static std::string GetCellInfo(const py::object &cell);
static void ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
static bool ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
static void ParseInputArgs(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node);
static bool ParseBodyContext(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node,
const std::vector<std::string> &compare_prim = {});
static bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
static bool ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
static bool ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
static bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node);
static bool ParseAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node);
static bool ParseAugAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node,
const std::vector<std::string> &compare_prim = {});
static bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
static std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
static bool ParseForExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node);
static std::string ParseNodeName(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node,
parse::AstMainType type);
};
} // namespace mindspore::parse

View File

@ -43,9 +43,22 @@ abstract::AbstractBasePtr ClassObject::ToAbstract() {
return std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
}
static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj);
if (!py::isinstance<py::bool_>(res)) {
MS_LOG(ERROR) << "Expect a bool type, but got " << py::str(res);
return false;
}
return res.cast<bool>();
}
abstract::AbstractBasePtr ClassType::ToAbstract() {
auto abs_scalar =
std::make_shared<abstract::AbstractScalar>(shared_from_base<ClassType>(), std::make_shared<TypeType>());
if (!IsSupportedCreateInstanceType(obj())) {
return abs_scalar;
}
AbstractBasePtrList args_spec_list = {abs_scalar};
auto func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
@ -333,6 +346,7 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
std::string attr_as_string = GetValueNode<StringImmPtr>(attr)->value();
auto new_symbol = std::make_shared<Symbol>(attr_as_string);
MS_LOG(DEBUG) << "name_space: " << new_namespace->ToString() << ", symbol: " << new_symbol->ToString();
AnfNodePtrList inputs = {NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)};
AnfNodePtr resolved_node = node->func_graph()->NewCNode(inputs);

View File

@ -36,15 +36,16 @@ using ResourceBasePtr = std::shared_ptr<ResourceBase>;
namespace mindspore {
namespace parse {
// NameSpace class for resolving python code.
class NameSpace : public Named {
public:
NameSpace(const std::string &module, const py::object &obj) : Named(module), module_(module), obj_(obj) {}
NameSpace(const std::string &module, const py::object &obj, const py::object &module_obj = py::object())
: Named(module), module_(module), obj_(obj), module_obj_(module_obj) {}
~NameSpace() override = default;
MS_DECLARE_PARENT(NameSpace, Named);
py::object obj() { return obj_; }
py::object module_obj() { return module_obj_; }
std::string module() { return module_; }
abstract::AbstractBasePtr ToAbstract() override {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<NameSpace>(), std::make_shared<External>());
@ -55,6 +56,8 @@ class NameSpace : public Named {
std::string module_;
// namespace object
py::object obj_;
// module object
py::object module_obj_;
};
using NameSpacePtr = std::shared_ptr<NameSpace>;
@ -62,7 +65,7 @@ using NameSpacePtr = std::shared_ptr<NameSpace>;
class Symbol : public Named {
public:
explicit Symbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {}
explicit Symbol(const std::string &symbol, const std::string &name) : Named(name), symbol_(symbol) {}
Symbol(const std::string &symbol, const std::string &name) : Named(name), symbol_(symbol) {}
~Symbol() override = default;
MS_DECLARE_PARENT(Symbol, Named);
@ -77,6 +80,25 @@ class Symbol : public Named {
};
using SymbolPtr = std::shared_ptr<Symbol>;
class Script : public Named {
public:
explicit Script(const std::string &script) : Named(script), script_(script) {}
Script(const std::string &script, const std::string &name) : Named(name), script_(script) {}
~Script() override = default;
MS_DECLARE_PARENT(Script, Named);
std::string script() { return script_; }
abstract::AbstractBasePtr ToAbstract() override {
return std::make_shared<abstract::AbstractScript>(shared_from_base<Script>());
}
std::string ToString() const override { return "`" + name() + "`"; }
private:
std::string script_;
};
using ScriptPtr = std::shared_ptr<Script>;
// PyObjectWrapper class wrappers resolved python object for further processing.
class PyObjectWrapper : public Named {
public:

View File

@ -359,7 +359,7 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
MS_EXCEPTION_IF_NULL(output_node);
if (output_node->isa<ValueNode>()) {
return ValuePtrToPyData(GetValueNode(output_node));
return ValueToPyData(GetValueNode(output_node));
}
if (output_node->isa<Parameter>()) {

View File

@ -402,15 +402,15 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
if (args_conf_list.empty()) {
MS_LOG(EXCEPTION) << "Size should greater than 0";
}
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->ObtainEvalResult()->abstract();
});
if (args_conf_list.empty()) {
MS_LOG(EXCEPTION) << "Size should greater than 0";
}
EvalResultPtr res = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
// No need to cache.
return res;

View File

@ -298,7 +298,7 @@ py::object BuildValue(const ValuePtr &value_ptr) {
if (value_ptr == nullptr) {
return py::none();
} else {
return ValuePtrToPyData(value_ptr);
return ValueToPyData(value_ptr);
}
}
@ -786,23 +786,6 @@ EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveIm
return uniform_primitive_evaluator;
}
const int64_t kResolveCaseUserDefineClass = 1;
const int64_t kResolveCaseBuiltInType = 2;
const int64_t kResolveCaseFunction = 3;
int64_t GetResolveCase(const TypePtr &data_type) {
MS_EXCEPTION_IF_NULL(data_type);
if (data_type->type_id() == kObjectTypeClass) {
return kResolveCaseUserDefineClass;
}
// try method map, if not in method map, the data_type should be External type.
if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
return kResolveCaseBuiltInType;
}
return kResolveCaseFunction;
}
FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) {
MS_EXCEPTION_IF_NULL(engine);
MS_EXCEPTION_IF_NULL(method);
@ -883,18 +866,18 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con
MS_LOG(EXCEPTION) << "Data is not NameSpace : " << data_v->ToString();
}
auto item_v = args_spec_list[1]->BuildValue();
MS_EXCEPTION_IF_NULL(item_v);
if (item_v->isa<StringImm>()) {
item_v = std::make_shared<parse::Symbol>(item_v->cast<StringImmPtr>()->value());
auto item_value = args_spec_list[1]->BuildValue();
MS_EXCEPTION_IF_NULL(item_value);
if (item_value->isa<StringImm>()) {
item_value = std::make_shared<parse::Symbol>(item_value->cast<StringImmPtr>()->value());
}
if (!item_v->isa<parse::Symbol>()) {
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_v->ToString();
if (!item_value->isa<parse::Symbol>()) {
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
}
// item_name to func addr from obj_map
parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>();
parse::SymbolPtr symbol = item_value->cast<parse::SymbolPtr>();
parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
MS_EXCEPTION_IF_NULL(out_conf);
auto out_node = out_conf->node();
@ -915,19 +898,20 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con
}
EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
const AbstractBasePtrList &args_spec_list,
const ValuePtr &item_value, const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) {
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "args_spec_list is empty";
}
AbstractClassPtr cls = CheckArg<AbstractClass>("__FUNC__", args_spec_list, 0);
// If item_v is an attribute, get abstract value from AbstractClass
MS_EXCEPTION_IF_NULL(item_v);
if (!item_v->isa<StringImm>()) {
// If item_value is an attribute, get abstract value from AbstractClass
MS_EXCEPTION_IF_NULL(item_value);
if (!item_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Attribute type error";
}
std::string item_name = item_v->cast<StringImmPtr>()->value();
std::string item_name = item_value->cast<StringImmPtr>()->value();
MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name();
MS_LOG(DEBUG) << "Resolve item: " << item_name;
MS_EXCEPTION_IF_NULL(cls);
@ -941,25 +925,25 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
MS_EXCEPTION_IF_NULL(args_spec_list[0]->BuildType());
MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
<< ", item value: " << item_v->ToString();
<< ", item value: " << item_value->ToString();
}
// Infer class method
ValuePtr converted_v = PyObjToGraph(engine, method);
return StaticGetterInferred(converted_v, data_conf, out_conf);
ValuePtr converted_value = PyObjToGraph(engine, method);
return StaticGetterInferred(converted_value, data_conf, out_conf);
}
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
const TypePtr &data_type, const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(item_v);
MS_EXCEPTION_IF_NULL(item_value);
MS_EXCEPTION_IF_NULL(data_type);
// The method maybe a Primitive or Composite
if (!item_v->isa<StringImm>()) {
if (!item_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Error item is not string";
}
std::string item_name = item_v->cast<StringImmPtr>()->value();
std::string item_name = item_value->cast<StringImmPtr>()->value();
REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
if (require.empty()) {
@ -971,20 +955,38 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
require_type = REQUIRE_TYPE::ATTR;
}
ValuePtr converted_v = nullptr;
ValuePtr converted_value = nullptr;
if (require.is<std::string>()) {
// composite registered in standard_method_map go to this branch
converted_v = prim::GetPythonOps(require.cast<std::string>());
MS_EXCEPTION_IF_NULL(converted_v);
if (!converted_v->isa<Primitive>()) {
AddToManager(engine, converted_v->cast<FuncGraphPtr>());
converted_value = prim::GetPythonOps(require.cast<std::string>());
MS_EXCEPTION_IF_NULL(converted_value);
if (!converted_value->isa<Primitive>()) {
AddToManager(engine, converted_value->cast<FuncGraphPtr>());
}
} else if (require.is<PrimitivePtr>()) {
converted_v = require.cast<PrimitivePtr>();
converted_value = require.cast<PrimitivePtr>();
} else {
MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString();
}
return StaticGetterInferred(converted_v, data_conf, out_conf, require_type);
return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
}
enum ResolveType : int64_t {
kResolveTypeUserDefineClass = 1,
kResolveTypeBuiltInType,
kResolveTypeFunction,
};
int64_t GetResolveType(const TypePtr &data_type) {
MS_EXCEPTION_IF_NULL(data_type);
if (data_type->type_id() == kObjectTypeClass) {
return kResolveTypeUserDefineClass;
}
// Try to search method map, if not found, the data_type should be External type.
if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
return kResolveTypeBuiltInType;
}
return kResolveTypeFunction;
}
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
@ -1006,10 +1008,10 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
}
int64_t case_v = GetResolveCase(data_type);
if (case_v == kResolveCaseUserDefineClass) {
int64_t resolve_type = GetResolveType(data_type);
if (resolve_type == kResolveTypeUserDefineClass) {
return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
} else if (case_v == kResolveCaseBuiltInType) {
} else if (resolve_type == kResolveTypeBuiltInType) {
return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf);
} else {
return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf);
@ -1234,12 +1236,12 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
return infer_result;
}
pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
py::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
// Exclude class type by minus 1;
std::size_t params_size = args_spec_list.size() - 1;
auto params = py::tuple(params_size);
if (params_size > params.size()) {
MS_LOG(EXCEPTION) << "Unexpected params_size:" << params_size << ",params.size():" << params.size();
MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size << ", params.size():" << params.size();
}
if (params_size > 0) {
for (size_t i = 0; i < params_size; i++) {
@ -1248,7 +1250,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
MS_EXCEPTION_IF_NULL(arg);
// Because the Tensor's AbstractTensor can't get value from GetValueTrack.
ValuePtr param_value = arg->BuildValue();
py::object param = ValuePtrToPyData(param_value);
py::object param = ValueToPyData(param_value);
params[i] = param;
}
}
@ -1256,6 +1258,80 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
}
};
class PyInterpretEvaluator : public TransitionPrimEvaluator {
public:
PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
~PyInterpretEvaluator() override = default;
MS_DECLARE_PARENT(PyInterpretEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
if (args_spec_list.empty()) {
MS_LOG(ERROR) << "'args_spec_list' should not be empty";
}
// Get the type parameter.
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
std::shared_ptr<parse::Script> script_obj = dyn_cast<parse::Script>(value_track);
if (script_obj == nullptr) {
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
}
// Make global and local parameters.
py::tuple params = MakeParameters(args_spec_list);
// Call python script string.
MS_LOG(DEBUG) << "Call script: " << script_obj->script() << ", params: " << py::str(params);
auto obj = parse::data_converter::CallPythonScript(py::str(script_obj->script()), params);
if (py::isinstance<py::none>(obj)) {
MS_LOG(EXCEPTION) << "Failed to call python script: `" << script_obj->script() << "`";
}
ValuePtr converted_val = nullptr;
bool converted = parse::ConvertData(obj, &converted_val, true);
if (!converted) {
MS_LOG(EXCEPTION) << "Convert the python object failed";
}
MS_EXCEPTION_IF_NULL(converted_val);
AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
return infer_result;
}
py::tuple MakeParameters(const AbstractBasePtrList &args_spec_list) const {
constexpr int params_size = 3;
if (params_size != args_spec_list.size()) {
MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size
<< ", not equal to arguments.size:" << args_spec_list.size();
}
// The first argument is script string, ignore it.
auto params = py::tuple(params_size - 1);
// Make the global parameters.
auto global_dict = dyn_cast<AbstractDictionary>(args_spec_list[1]); // Global parameters dict.
MS_EXCEPTION_IF_NULL(global_dict);
MS_LOG(DEBUG) << "arg_1, global_dict: " << global_dict->ToString() << ", [" << global_dict->type_name() << "]";
ValuePtr global_dict_value = global_dict->BuildValue();
py::object global_params_dict = ValueToPyData(global_dict_value);
MS_LOG(DEBUG) << "arg_1, python global_params_dict: " << py::str(global_params_dict);
params[0] = global_params_dict;
// Make the local parameters.
auto local_dict = dyn_cast<AbstractDictionary>(args_spec_list[2]); // Local parameters dict.
MS_EXCEPTION_IF_NULL(local_dict);
MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString() << ", [" << local_dict->type_name() << "]";
ValuePtr local_dict_value = local_dict->BuildValue();
py::object local_params_dict = ValueToPyData(local_dict_value);
MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << py::str(local_params_dict);
params[1] = local_params_dict;
return params;
}
};
class PartialEvaluator : public Evaluator {
public:
PartialEvaluator() : Evaluator("PartialEvaluator") {}
@ -1399,6 +1475,7 @@ void InitPrimEvaluatorConstructors() {
constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
}
} // namespace

View File

@ -286,7 +286,8 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(possible_func);
if (func == nullptr) {
MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << possible_func->ToString() << ".";
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << possible_func->ToString() << ".";
MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
MS_EXCEPTION(ValueError) << "This may be not defined, and it can't be a operator. Please check code.";
}

View File

@ -291,7 +291,7 @@ py::function PrimitivePy::GetComputeFunction() const {
py::dict PrimitivePy::GetAttrDict() {
py::dict attr_dict;
for (auto &attr : attrs_) {
attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
}
return attr_dict;
}
@ -430,7 +430,7 @@ py::dict PrimitivePyAdapter::GetAttrDict() {
py::dict attr_dict;
for (auto &attr : attrs_) {
attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
}
return attr_dict;
}

View File

@ -28,6 +28,7 @@
#include "abstract/utils.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/parse/parse_base.h"
#include "pipeline/jit/parse/resolve.h"
#include "ir/value.h"
#include "ir/tensor.h"
#include "ir/param_info.h"
@ -106,80 +107,125 @@ py::object ScalarPtrToPyData(const ScalarPtr &value) {
}
}
py::object ValuePtrToPyData(const ValuePtr &value) {
if (value == nullptr) {
MS_LOG(EXCEPTION) << "value is null";
}
py::object ret;
if (value->isa<Scalar>()) {
ret = ScalarPtrToPyData(value->cast<ScalarPtr>());
} else if (value->isa<StringImm>()) {
MS_LOG(DEBUG) << "String";
py::str v = value->cast<StringImmPtr>()->value();
ret = v;
} else if (value->isa<tensor::Tensor>()) {
MS_LOG(DEBUG) << "tensor";
auto tensor_ptr = value->cast<tensor::TensorPtr>();
ret = TensorToPyData(tensor_ptr);
} else if (value->isa<tensor::MetaTensor>()) {
MS_LOG(DEBUG) << "MetaTensor";
py::tuple v(1);
v[0] = value->cast<tensor::MetaTensorPtr>();
ret = v[0];
} else if (value->isa<RefKey>()) {
MS_LOG(DEBUG) << "RefKey";
py::tuple v(1);
v[0] = value->cast<RefKeyPtr>();
ret = v[0];
} else if (value->isa<ValueSequeue>()) {
MS_LOG(DEBUG) << "tuple or list";
auto value_sequeue = value->cast<ValueSequeuePtr>()->value();
py::tuple ret_sequeue(value_sequeue.size());
using ConverterFunction = std::function<py::object(const ValuePtr &value)>;
using ValueNameToConverterVector = std::vector<std::pair<const char *, ConverterFunction>>;
for (size_t i = 0; i < value_sequeue.size(); i++) {
ret_sequeue[i] = ValuePtrToPyData(value_sequeue[i]);
}
if (value->isa<ValueTuple>()) {
ret = ret_sequeue;
} else {
ret = ret_sequeue.cast<py::list>();
}
} else if (value->isa<ValueDictionary>()) {
MS_LOG(DEBUG) << "dict";
auto value_list = value->cast<ValueDictionaryPtr>()->value();
py::dict ret_dict;
for (const auto &v : value_list) {
ret_dict[py::str(v.first)] = ValuePtrToPyData(v.second);
}
ret = ret_dict;
} else if (value->isa<Ellipsis>()) {
ret = py::ellipsis();
} else if (value->isa<ValueSlice>()) {
auto slice = value->cast<ValueSlicePtr>();
auto start = ValuePtrToPyData(slice->start());
auto end = ValuePtrToPyData(slice->stop());
auto step = ValuePtrToPyData(slice->step());
ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
step);
} else if (value->isa<Type>()) {
py::tuple v(1);
v[0] = value->cast<TypePtr>();
ret = v[0];
} else if (value->isa<AnyValue>() || value->isa<None>() || value->isa<Monad>() || value->isa<FuncGraph>()) {
// FuncGraph is not used in the backend, return None
ret = py::none();
} else if (value->isa<KeywordArg>()) {
auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
auto key = abs_keyword_arg->get_key();
auto val = abs_keyword_arg->get_arg()->BuildValue();
auto py_value = ValuePtrToPyData(val);
auto kwargs = py::kwargs();
kwargs[key.c_str()] = py_value;
ret = kwargs;
} else {
MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
// (Value Type Name) -> (Converter Function)
// The converter function is used to convert Value object to Python data object.
static ValueNameToConverterVector value_name_to_converter = {
// Scalar
{typeid(Scalar).name(),
[](const ValuePtr &value) -> py::object { return ScalarPtrToPyData(value->cast<ScalarPtr>()); }},
// Tensor
{typeid(tensor::Tensor).name(),
[](const ValuePtr &value) -> py::object {
auto tensor_ptr = value->cast<tensor::TensorPtr>();
return TensorToPyData(tensor_ptr);
}},
// MetaTenser
{typeid(tensor::MetaTensor).name(),
[](const ValuePtr &value) -> py::object {
py::tuple tuple_container(1);
tuple_container[0] = value->cast<tensor::MetaTensorPtr>();
return tuple_container[0];
}},
// RefKey
{typeid(RefKey).name(),
[](const ValuePtr &value) -> py::object {
py::tuple tuple_container(1);
tuple_container[0] = value->cast<RefKeyPtr>();
return tuple_container[0];
}},
// Type
{typeid(Type).name(),
[](const ValuePtr &value) -> py::object {
py::tuple tuple_container(1);
tuple_container[0] = value->cast<TypePtr>();
return tuple_container[0];
}},
// StringImm
{typeid(StringImm).name(),
[](const ValuePtr &value) -> py::object {
py::str res = value->cast<StringImmPtr>()->value();
return res;
}},
// ValueSequeue
{typeid(ValueSequeue).name(),
[](const ValuePtr &value) -> py::object {
auto value_sequeue = value->cast<ValueSequeuePtr>()->value();
py::tuple res_sequeue(value_sequeue.size());
for (size_t i = 0; i < value_sequeue.size(); i++) {
res_sequeue[i] = ValueToPyData(value_sequeue[i]);
}
if (value->isa<ValueTuple>()) {
return res_sequeue;
}
return res_sequeue.cast<py::list>();
}},
// ValueDictionary
{typeid(ValueDictionary).name(),
[](const ValuePtr &value) -> py::object {
auto value_list = value->cast<ValueDictionaryPtr>()->value();
py::dict res_dict;
for (const auto &value : value_list) {
res_dict[py::str(value.first)] = ValueToPyData(value.second);
}
return res_dict;
}},
// ValueSlice
{typeid(ValueSlice).name(),
[](const ValuePtr &value) -> py::object {
auto slice = value->cast<ValueSlicePtr>();
auto start = ValueToPyData(slice->start());
auto end = ValueToPyData(slice->stop());
auto step = ValueToPyData(slice->step());
return parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
step);
}},
// KeywordArg
{typeid(KeywordArg).name(),
[](const ValuePtr &value) -> py::object {
auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
auto key = abs_keyword_arg->get_key();
auto val = abs_keyword_arg->get_arg()->BuildValue();
auto py_value = ValueToPyData(val);
auto kwargs = py::kwargs();
kwargs[key.c_str()] = py_value;
return kwargs;
}},
// parse::NameSpace
{typeid(parse::NameSpace).name(),
[](const ValuePtr &value) -> py::object {
auto ns = value->cast<parse::NameSpacePtr>();
return ns->module_obj();
}},
// parse::ClassType
{typeid(parse::ClassType).name(),
[](const ValuePtr &value) -> py::object {
auto class_type = value->cast<parse::ClassTypePtr>();
return class_type->obj();
}},
// None
{typeid(None).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
// AnyValue
{typeid(AnyValue).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
// FuncGraph
{typeid(FuncGraph).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
// Monad
{typeid(Monad).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
// Ellipsis
{typeid(Ellipsis).name(), [](const ValuePtr &value) -> py::object { return py::ellipsis(); }}};
py::object ValueToPyData(const ValuePtr &value) {
if (value == nullptr) {
MS_LOG(EXCEPTION) << "The `value` should not be null";
}
return ret;
for (auto &iter : value_name_to_converter) {
if (value->IsFromTypeId(Base::GetTypeId(iter.first))) {
return iter.second(value);
}
}
MS_LOG(EXCEPTION) << "Unsupported to convert " << value->ToString() << "[" << value->type_name() << "] to a PyData";
}
py::object AnyToPyData(const Any &value) {
@ -190,7 +236,7 @@ py::object AnyToPyData(const Any &value) {
} else if (value.is<ValuePtr>()) {
MS_LOG(DEBUG) << "ValuePtr";
ValuePtr v = value.cast<ValuePtr>();
ret = ValuePtrToPyData(v);
ret = ValueToPyData(v);
} else if (value.is<tensor::TensorPtr>()) {
MS_LOG(DEBUG) << "tensor";
auto tensor_ptr = value.cast<tensor::TensorPtr>();
@ -233,7 +279,7 @@ py::object BaseRefToPyData(const BaseRef &value) {
} else if (utils::isa<ValuePtr>(value)) {
MS_LOG(DEBUG) << "ValuePtr";
ValuePtr v = utils::cast<ValuePtr>(value);
ret = ValuePtrToPyData(v);
ret = ValueToPyData(v);
} else if (utils::isa<tensor::TensorPtr>(value)) {
MS_LOG(DEBUG) << "tensor";
auto tensor_ptr = utils::cast<tensor::TensorPtr>(value);
@ -459,7 +505,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
if (output->isa<ValueNode>()) {
MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
ValuePtr value = GetValueNode(output);
*ret_val = ValuePtrToPyData(value);
*ret_val = ValueToPyData(value);
return true;
}

View File

@ -31,7 +31,7 @@ namespace py = pybind11;
namespace mindspore {
py::object AnyToPyData(const Any &value);
py::object BaseRefToPyData(const BaseRef &value);
py::object ValuePtrToPyData(const ValuePtr &value);
py::object ValueToPyData(const ValuePtr &value);
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
const std::shared_ptr<py::object> &ret_val);

View File

@ -320,6 +320,38 @@ std::string TypedPrimitiveAbstractClosure::ToString() const {
return buffer.str();
}
bool PyInterpretAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<PyInterpretAbstractClosure>()) {
return false;
}
auto other_partial = static_cast<const PyInterpretAbstractClosure *>(&other);
if (fn_ != other_partial->fn_) {
return false;
}
if (args_spec_list_.size() != other_partial->args_spec_list_.size()) {
return false;
}
return args_spec_list_ == other_partial->args_spec_list_;
}
std::size_t PyInterpretAbstractClosure::hash() const {
MS_EXCEPTION_IF_NULL(fn_);
auto hash_value = hash_combine(tid(), fn_->hash());
hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
return hash_value;
}
std::string PyInterpretAbstractClosure::ToString() const {
std::ostringstream buffer;
buffer << "PyInterpretAbstractClosure(" << fn_->ToString() << "(";
for (const auto &arg : args_spec_list_) {
MS_EXCEPTION_IF_NULL(arg);
buffer << arg->ToString() << ", ";
}
buffer << "))";
return buffer.str();
}
bool DummyAbstractClosure::operator==(const AbstractFunction &other) const {
return !other.isa<DummyAbstractClosure>();
}

View File

@ -192,7 +192,7 @@ class MS_CORE_API PartialAbstractClosure : public AbstractFuncAtom {
MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom)
AbstractFunctionPtr fn() { return fn_; }
AbstractBasePtrList args() { return args_spec_list_; }
AbstractBasePtrList &args() { return args_spec_list_; }
ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
AnfNodePtr node() { return node_.lock(); }
void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
@ -287,6 +287,34 @@ class MS_CORE_API TypedPrimitiveAbstractClosure : public AbstractFuncAtom {
AbstractBasePtr output_;
};
class PyInterpretAbstractClosure : public AbstractFuncAtom {
public:
PyInterpretAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list,
const AnfNodePtr &node = nullptr)
: fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {}
~PyInterpretAbstractClosure() override = default;
MS_DECLARE_PARENT(PyInterpretAbstractClosure, AbstractFuncAtom)
AbstractFunctionPtr fn() { return fn_; }
AbstractBasePtrList args() { return args_spec_list_; }
ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
AnfNodePtr node() { return node_.lock(); }
void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
AbstractFunctionPtr Copy() const override {
return std::make_shared<PyInterpretAbstractClosure>(fn_, args_spec_list_, node_.lock());
}
bool operator==(const AbstractFunction &other) const override;
std::size_t hash() const override;
std::string ToString() const override;
private:
AbstractFuncAtomPtr fn_;
AbstractBasePtrList args_spec_list_;
AnfNodeWeakPtr node_;
};
using PyInterpretAbstractClosurePtr = std::shared_ptr<PyInterpretAbstractClosure>;
// Represents a function that can't be called.
class MS_CORE_API DummyAbstractClosure : public AbstractFuncAtom {
public:

View File

@ -104,7 +104,7 @@ class MS_CORE_API AbstractBase : public Base {
class MS_CORE_API AbstractScalar : public AbstractBase {
public:
AbstractScalar() : AbstractBase(kAnyValue, kAnyType) {}
explicit AbstractScalar(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
AbstractScalar(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
explicit AbstractScalar(const ValuePtr &value) : AbstractBase(value, value->type()) {}
explicit AbstractScalar(int value) : AbstractBase(MakeValue(value), kInt32) {}
explicit AbstractScalar(int64_t value) : AbstractBase(MakeValue(value), kInt64) {}
@ -148,7 +148,7 @@ using AbstractTypePtr = std::shared_ptr<AbstractType>;
class MS_CORE_API AbstractError : public AbstractBase {
public:
explicit AbstractError(const StringImmPtr &err, const AnfNodePtr &node) : AbstractBase(err), node_(node) {
AbstractError(const StringImmPtr &err, const AnfNodePtr &node) : AbstractBase(err), node_(node) {
if (err == nullptr || node == nullptr) {
MS_LOG(EXCEPTION) << "err or node is nullptr";
}
@ -170,6 +170,25 @@ class MS_CORE_API AbstractError : public AbstractBase {
const AnfNodePtr node_;
};
class MS_CORE_API AbstractScript : public AbstractBase {
public:
AbstractScript() : AbstractBase(kAnyValue, kAnyType) {}
AbstractScript(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
explicit AbstractScript(const ValuePtr &value) : AbstractBase(value, kString) {}
// explicit AbstractScript(const std::string &value) : AbstractBase(MakeValue(value), kString) {}
~AbstractScript() override = default;
MS_DECLARE_PARENT(AbstractScript, AbstractBase)
std::size_t hash() const override { return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()}); }
TypePtr BuildType() const override { return GetTypeTrack(); }
AbstractBasePtr Clone() const override {
return std::make_shared<AbstractScript>(GetValueTrack(), GetTypeTrack()->Clone());
}
AbstractBasePtr Broaden() const override { return Clone(); }
};
using AbstractScriptPtr = std::shared_ptr<AbstractScript>;
class Evaluator;
using EvaluatorPtr = std::shared_ptr<Evaluator>;
class AnalysisEngine;

View File

@ -604,6 +604,9 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
// Python interpreter runner
inline const PrimitivePtr kPrimPyInterpret = std::make_shared<Primitive>("PyInterpret");
// Other primitive not used by backend but used in core;
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J", kSideEffectPropagate);

View File

@ -107,7 +107,9 @@ class MS_CORE_API AnfNode : public Base {
hash_(std::hash<const AnfNode *>()),
kernel_info_(nullptr),
stage_(-1),
need_grad_(false) {
need_grad_(false),
interpret_(false),
interpreted_node_(nullptr) {
scope_ = ScopeManager::GetInstance().GetCurrentScope();
}
@ -204,6 +206,12 @@ class MS_CORE_API AnfNode : public Base {
bool grad() { return need_grad_; }
void set_grad(const bool &need_grad) { need_grad_ = need_grad; }
bool interpret() { return interpret_; }
void set_interpret(const bool interpret) { interpret_ = interpret; }
AnfNodePtr interpreted_node() { return interpreted_node_; }
void set_interpreted_node(const AnfNodePtr &node) { interpreted_node_ = node; }
protected:
// Hold a weak ref to Graph as Graph also hold ref to AnfNode.
// Otherwise, func_graph_ and AnfNode will make a reference cycle.
@ -220,6 +228,8 @@ class MS_CORE_API AnfNode : public Base {
UserData user_data_;
int64_t stage_;
bool need_grad_;
bool interpret_;
AnfNodePtr interpreted_node_;
};
// CNode represents the complex node with a set of arguments.

View File

@ -291,7 +291,7 @@ TEST_F(TestStepParallel, CreatOpInstance) {
std::vector<py::object> arglist;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arglist),
[](Attr attr) { return ValuePtrToPyData(attr.second); });
[](Attr attr) { return ValueToPyData(attr.second); });
py::object allreduce_pyobj = parse::python_adapter::CallPyFn(
"mindspore.parallel._utils", "_get_python_op", "AllReduce", "mindspore.ops.operations", "test", arglist);
py::dict opAttr = py::getattr(allreduce_pyobj, "attrs");

View File

@ -65,7 +65,7 @@ TEST_F(TestParser, TestParseApi) {
TEST_F(TestParser, TestParseAst) {
GetPythonFunction("test_f");
ParseAst ast = ParseAst(fn);
ParseFunctionAst ast = ParseFunctionAst(fn);
bool succ = ast.InitParseAstInfo();
ASSERT_TRUE(succ = true);

View File

@ -47,7 +47,8 @@ def test_use_numpy_method():
net = Net()
with pytest.raises(NotImplementedError) as err:
net()
assert "MindSpore does not support to use the numpy methods in the function construct with the graph mode." \
assert "Mindspore does not support to use the numpy methods " \
"within the construct() or @ms_function decorated function in graph mode." \
in str(err.value)
@ -63,5 +64,6 @@ def test_use_numpy_module():
net = Net()
with pytest.raises(NotImplementedError) as err:
net()
assert "MindSpore does not support to use the numpy methods in the function construct with the graph mode." \
assert "Mindspore does not support to use the numpy methods " \
"within the construct() or @ms_function decorated function in graph mode." \
in str(err.value)