forked from mindspore-Ecosystem/mindspore
!22843 Support fallback feature in Graph mode.
Merge pull request !22843 from 张清华/opt_fallback
This commit is contained in:
commit
05a0898352
|
@ -13,9 +13,12 @@
|
|||
"mindspore/mindspore/core/mindrt/src/actor/actorpolicy.h" "runtime/references"
|
||||
"mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/" "readability/casting"
|
||||
"mindspore/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc" "build/include_what_you_use"
|
||||
"mindspore/mindspore/ccsrc/utils/convert_utils_py.cc" "whitespace/indent"
|
||||
|
||||
# Modelzoo
|
||||
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"
|
||||
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/PostProcess/Yolov4TinyMindsporePost.h" "runtime/references"
|
||||
|
||||
# MindData
|
||||
"mindspore/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h" "runtime/string"
|
||||
"mindspore/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h" "runtime/references"
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
"mindspore/mindspore/nn/cell.py" "assignment-from-none"
|
||||
"mindspore/mindspore/_extends/parse/resources.py" "bad-whitespace"
|
||||
"mindspore/mindspore/_extends/parse/parser.py" "broad-except"
|
||||
"mindspore/mindspore/_extends/parse/parser.py" "eval-used"
|
||||
"mindspore/mindspore/nn/cell.py" "protected-access"
|
||||
"mindspore/mindspore/nn/optim/ftrl.py" "unused-import"
|
||||
"mindspore/mindspore/train/amp.py" "protected-access"
|
||||
|
|
|
@ -16,20 +16,18 @@
|
|||
Interfaces for parser module in c++.
|
||||
"""
|
||||
|
||||
from .parser import (Parser, create_obj_instance, generate_scope,
|
||||
get_bprop_method_of_class, get_class_instance_type,
|
||||
get_class_member_namespace_symbol, create_slice_obj,
|
||||
get_dataclass_attributes, get_dataclass_methods, get_obj_id,
|
||||
get_module_namespace, get_obj_type, get_object_key,
|
||||
get_ast_type, get_node_type, get_args, get_args_default_values,
|
||||
get_ast_namespace_symbol, get_operation_namespace_symbol,
|
||||
get_parse_method_of_class, get_scope_name, expand_expr_statement,
|
||||
from .parser import (Parser, create_instance, is_supported_create_instance_type, generate_scope,
|
||||
get_bprop_method_of_class, get_class_instance_type, get_class_member_namespace_symbol,
|
||||
create_slice_obj, get_dataclass_attributes, get_dataclass_methods, get_obj_id,
|
||||
get_module_namespace, get_obj_type, get_object_key, get_ast_type, get_node_type,
|
||||
get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_namespace_symbol,
|
||||
get_parse_method_of_class, get_scope_name, eval_script, expand_expr_statement,
|
||||
is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description)
|
||||
|
||||
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
|
||||
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type',
|
||||
'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol',
|
||||
'get_args', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
|
||||
'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
|
||||
'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name',
|
||||
'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement']
|
||||
'get_args', 'get_obj_type', 'get_obj_id', 'create_instance', 'is_supported_create_instance_type',
|
||||
'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser',
|
||||
'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name',
|
||||
'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement']
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
# ============================================================================
|
||||
"""The module of parser python object, called by c++."""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import hashlib
|
||||
import inspect
|
||||
|
@ -25,7 +26,7 @@ from textwrap import dedent
|
|||
|
||||
import asttokens
|
||||
|
||||
from mindspore import Tensor as MsTensor
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore import nn
|
||||
|
@ -132,6 +133,9 @@ def get_bprop_method_of_class(obj, parse_method=None):
|
|||
method = getattr(obj, method_name)
|
||||
return method
|
||||
|
||||
# The fallback feature is enabled in default.
|
||||
# Not support change the flag during the process is alive.
|
||||
support_fallback_ = os.getenv('ENV_SUPPORT_FALLBACK')
|
||||
|
||||
def resolve_symbol(namespace, symbol):
|
||||
"""
|
||||
|
@ -159,10 +163,13 @@ def resolve_symbol(namespace, symbol):
|
|||
if getattr(resolve_, "__hash__") is None:
|
||||
return resolve_
|
||||
|
||||
# Raise a proper error if not using Fallback feature.
|
||||
if support_fallback_ != '1':
|
||||
# Raise NotImplementedError when parsing the numpy methods, but not the numpy constant.
|
||||
if namespace.name == "numpy" and isinstance(resolve_, (types.FunctionType, types.MethodType, types.ModuleType)):
|
||||
raise NotImplementedError(
|
||||
f"MindSpore does not support to use the numpy methods in the function construct with the graph mode.")
|
||||
if namespace.name == "numpy" and \
|
||||
isinstance(resolve_, (types.FunctionType, types.MethodType, types.ModuleType)):
|
||||
raise NotImplementedError("Mindspore does not support to use the numpy methods " \
|
||||
"within the construct() or @ms_function decorated function in graph mode.")
|
||||
|
||||
# If need trope the obj
|
||||
if resolve_ in convert_object_map:
|
||||
|
@ -177,9 +184,11 @@ def resolve_symbol(namespace, symbol):
|
|||
logger.debug("resolve exception occurred, value = %r", e)
|
||||
logger.debug("resolve type is invalid, namespace = %s, symbol = %s",
|
||||
namespace.__str__(), symbol)
|
||||
|
||||
if isinstance(resolve_, _MindsporeFunctionExecutor):
|
||||
logger.debug("resolve class _MindsporeFunctionExecutor, resolve fn instead.")
|
||||
resolve_ = resolve_.fn
|
||||
logger.debug(f'found: {symbol} in {namespace.__str__()}, resolve: {resolve_} / {type(resolve_)}')
|
||||
return resolve_
|
||||
|
||||
|
||||
|
@ -292,22 +301,24 @@ def _convert_tuple_to_args_kwargs(params):
|
|||
args += (param,)
|
||||
return (args, kwargs)
|
||||
|
||||
def is_supported_create_instance_type(cls_type):
|
||||
return issubclass(cls_type, (nn.Cell, ops.Primitive))
|
||||
|
||||
def create_obj_instance(cls_type, params=None):
|
||||
def create_instance(cls_type, params=None):
|
||||
"""Create python instance."""
|
||||
if not isinstance(cls_type, type):
|
||||
logger.warning(f"create_obj_instance(), cls_type is not a type, cls_type: {cls_type}")
|
||||
logger.warning(f"create_instance(), cls_type is not a type, cls_type: {cls_type}")
|
||||
return None
|
||||
|
||||
# Check the type, now only support nn.Cell and Primitive.
|
||||
obj = None
|
||||
if issubclass(cls_type, (nn.Cell, ops.Primitive)):
|
||||
if is_supported_create_instance_type(cls_type):
|
||||
# Check arguments, only support *args or **kwargs.
|
||||
if params is None:
|
||||
obj = cls_type()
|
||||
elif isinstance(params, tuple):
|
||||
args, kwargs = _convert_tuple_to_args_kwargs(params)
|
||||
logger.debug(f"create_obj_instance(), args: {args}, kwargs: {kwargs}")
|
||||
logger.debug(f"create_instance(), args: {args}, kwargs: {kwargs}")
|
||||
if args and kwargs:
|
||||
obj = cls_type(*args, **kwargs)
|
||||
elif args:
|
||||
|
@ -358,7 +369,7 @@ def get_dataclass_methods(cls):
|
|||
|
||||
def convert_to_ms_tensor(data):
|
||||
"""Convert C++ tensor to mindspore tensor."""
|
||||
return MsTensor(data)
|
||||
return Tensor(data)
|
||||
|
||||
|
||||
def get_object_description(obj, fname, fline):
|
||||
|
@ -415,7 +426,6 @@ def get_operation_namespace_symbol(var: str):
|
|||
logger.debug("get operation ops info = %r", ops_info)
|
||||
return ops_info
|
||||
|
||||
|
||||
def get_ast_type(node):
|
||||
"""Get the ast type."""
|
||||
ast_type = AST_SUB_TYPE_UNKNOWN
|
||||
|
@ -483,6 +493,21 @@ def get_args(node):
|
|||
args.append(node.args.kwarg)
|
||||
return args
|
||||
|
||||
def eval_script(exp_str, params):
|
||||
"""Evaluate a python expression."""
|
||||
if not isinstance(params, tuple):
|
||||
raise ValueError(f"eval_script(), params is not a tuple, params: {params}")
|
||||
if len(params) != 2:
|
||||
raise ValueError(f"eval_script(), params tuple length is wrong, params: {params}")
|
||||
|
||||
logger.debug(f'exp_str: {exp_str}, params: {params}')
|
||||
global_params = params[0]
|
||||
local_params = params[1]
|
||||
obj = eval(exp_str, global_params, local_params)
|
||||
if obj is None:
|
||||
raise ValueError(f"When call 'eval', the result is none. exp_str: '{exp_str}'")
|
||||
return obj
|
||||
|
||||
|
||||
class Parser:
|
||||
"""
|
||||
|
@ -501,7 +526,10 @@ class Parser:
|
|||
self.line_offset = 0
|
||||
self.filename: str = inspect.getfile(self.fn)
|
||||
|
||||
# Used to resolve the function's globals Namespace.
|
||||
# Used to resolve mindspore builtin ops namespace.
|
||||
self.ms_common_ns = CellNamespace('mindspore.common')
|
||||
self.ms_ops_ns = CellNamespace('mindspore.ops')
|
||||
# Used to resolve the function's globals namespace.
|
||||
self.global_namespace = CellNamespace(fn.__module__)
|
||||
self.function_module = fn.__module__
|
||||
# Used to resolve the function's nonlocals.
|
||||
|
@ -512,38 +540,37 @@ class Parser:
|
|||
def parse(self):
|
||||
"""Parse the function or method."""
|
||||
logger.debug("fn = %r", self.fn)
|
||||
tree = None
|
||||
if isinstance(self.fn, (types.FunctionType, types.MethodType)):
|
||||
lines, self.line_offset = inspect.getsourcelines(self.fn)
|
||||
original_src = ''.join(lines)
|
||||
hexstr = hashlib.sha256(original_src.encode()).hexdigest()
|
||||
tree = Parser.ast_cache.get(hexstr)
|
||||
if not tree:
|
||||
ast_tokens = Parser.ast_cache.get(hexstr)
|
||||
if not ast_tokens:
|
||||
src = dedent(original_src)
|
||||
self.col_offset = \
|
||||
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
|
||||
logger.debug("get source = %s", src)
|
||||
try:
|
||||
tree = asttokens.ASTTokens(src, parse=True).tree
|
||||
ast_tokens = asttokens.ASTTokens(src, parse=True)
|
||||
except IndentationError as idt_err:
|
||||
idt_err.filename = self.filename
|
||||
idt_err.lineno = self.line_offset
|
||||
idt_err.msg = f"There are incorrect indentations in definition or comment of function: " \
|
||||
f"'{self.fn.__qualname__}'."
|
||||
raise idt_err
|
||||
Parser.ast_cache[hexstr] = tree
|
||||
else:
|
||||
Parser.ast_cache[hexstr] = ast_tokens
|
||||
return ast_tokens, ast_tokens.tree
|
||||
|
||||
logger.error("Fn type is invalid")
|
||||
return tree
|
||||
return None, None
|
||||
|
||||
def get_namespace_symbol(self, var: str):
|
||||
|
||||
"""Get symbol type and namespace and symbol."""
|
||||
if var in self.closure_namespace:
|
||||
logger.debug("in closure_namespace")
|
||||
logger.debug(f"found {var} in closure_namespace {self.closure_namespace.__str__()}")
|
||||
return self.closure_namespace, var
|
||||
if var in self.global_namespace:
|
||||
logger.debug("in global_namespace")
|
||||
logger.debug(f"found {var} in global_namespace {self.global_namespace.__str__()}")
|
||||
value = self.global_namespace[var]
|
||||
if isinstance(value, type(abs)) and self.global_namespace[var] not in convert_object_map:
|
||||
error_info = f"The builtin function '{var}' is not supported in graph mode."
|
||||
|
@ -552,6 +579,56 @@ class Parser:
|
|||
error_info = f"The name '{var}' is not defined."
|
||||
return None, var, error_info
|
||||
|
||||
def is_unsupported_builtin_type(self, value_type):
|
||||
"""To check if not supported builtin type"""
|
||||
logger.debug(f'value_type: {value_type}, {type([])}, {type(())}')
|
||||
return value_type == type([]) or value_type == type(())
|
||||
|
||||
def is_supported_namespace_module(self, value):
|
||||
"""To check if the module is allowed to support."""
|
||||
if not hasattr(value, '__name__'):
|
||||
return True
|
||||
|
||||
name = value.__name__
|
||||
if name == 'mindspore':
|
||||
logger.debug(f'...found {name} in mindspore root namespace')
|
||||
return True
|
||||
|
||||
if not isinstance(value, types.ModuleType):
|
||||
return False
|
||||
rightmost_name = name.split('.')[-1]
|
||||
# if rightmost_name in self.ms_common_ns:
|
||||
# logger.error(f'...found {module_name} in common namespace: {self.ms_common_ns.__str__()}')
|
||||
# return True
|
||||
if rightmost_name in self.ms_ops_ns:
|
||||
logger.debug(f'...found {name}({rightmost_name}) in ops namespace: {self.ms_ops_ns.__str__()}')
|
||||
return True
|
||||
if rightmost_name in trope_ns:
|
||||
logger.debug(f'...found {name}({rightmost_name}) in trope namespace: {self.trope_ns.__str__()}')
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_builtin_namespace_symbol(self, var: str):
|
||||
"""Get mindspore builtin namespace and symbol."""
|
||||
if var in self.closure_namespace:
|
||||
logger.debug(f"found {var} in closure_namespace {self.closure_namespace.__str__()}")
|
||||
return self.closure_namespace, var
|
||||
if var in self.global_namespace:
|
||||
logger.debug(f"found {var} in global_namespace {self.global_namespace.__str__()}")
|
||||
value = self.global_namespace[var]
|
||||
value_str = value.__name__ if hasattr(value, '__name__') else str(value)
|
||||
logger.debug(f"value: {type(value)}, : {value_str}, hasattr(__name__): {hasattr(value, '__name__')}")
|
||||
# To check if allowed to support.
|
||||
if self.is_unsupported_builtin_type(value):
|
||||
return self.global_namespace, var, value
|
||||
if not self.is_supported_namespace_module(value): # Check if support including instance of types.ModuleType
|
||||
return self.global_namespace, var, value
|
||||
return self.global_namespace, var
|
||||
|
||||
error_info = f"The symbol '{var}' is not supported in graph mode."
|
||||
logger.debug(error_info)
|
||||
return None, var, error_info
|
||||
|
||||
def analyze_super(self, class_type_node, subclass_instance):
|
||||
"""Analyze super and return a class instance."""
|
||||
sub_class = type(subclass_instance)
|
||||
|
|
|
@ -444,7 +444,7 @@ AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitiveP
|
|||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
|
||||
py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
|
||||
py::tuple data_tuple = ValueToPyData(input->BuildValue());
|
||||
py::array data = py::array(data_tuple);
|
||||
auto tensor = tensor::TensorPy::MakeTensor(data);
|
||||
auto ret = tensor->ToAbstract();
|
||||
|
|
|
@ -268,7 +268,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceB
|
|||
}
|
||||
}
|
||||
if (!fn || py::isinstance<py::none>(fn)) {
|
||||
MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
|
||||
MS_LOG(ERROR) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn);
|
||||
return nullptr;
|
||||
}
|
||||
func_graph = parse::ParsePythonCode(fn);
|
||||
|
|
|
@ -59,7 +59,7 @@ ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name
|
|||
}
|
||||
std::vector<py::object> arg_list;
|
||||
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
|
||||
[](const Attr &attr) { return ValuePtrToPyData(attr.second); });
|
||||
[](const Attr &attr) { return ValueToPyData(attr.second); });
|
||||
py::object obj =
|
||||
parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list);
|
||||
ValuePtr op_instance = nullptr;
|
||||
|
|
|
@ -215,11 +215,13 @@ ValuePtr ConvertDict(const py::object &obj, bool use_signature) {
|
|||
return std::make_shared<ValueDictionary>(key_values);
|
||||
}
|
||||
|
||||
ValuePtr ConvertNameSpace(const py::object &obj) {
|
||||
ValuePtr ConvertModuleNameSpace(const py::object &obj) {
|
||||
MS_LOG(DEBUG) << "Converting python module";
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
|
||||
auto converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace));
|
||||
auto converted =
|
||||
std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace), obj);
|
||||
MS_LOG(DEBUG) << "name_space: " << converted->ToString();
|
||||
return converted;
|
||||
}
|
||||
|
||||
|
@ -355,7 +357,9 @@ ValuePtr ConvertOtherObj(const py::object &obj) {
|
|||
// When the obj is Cell, default parse the 'construct'
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
|
||||
return std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
|
||||
auto res = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
|
||||
MS_LOG(DEBUG) << "name_space: " << res->ToString();
|
||||
return res;
|
||||
}
|
||||
MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
|
||||
return nullptr;
|
||||
|
@ -456,7 +460,7 @@ std::vector<DataConverterPtr> GetDataConverters() {
|
|||
std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),
|
||||
std::make_shared<ByTypeDataConverter<py::str>>(PyCast<StringImm, string>),
|
||||
std::make_shared<ByTypeDataConverter<py::ellipsis>>(kEllipsis),
|
||||
std::make_shared<ByTypeDataConverter<py::module>>(ConvertNameSpace),
|
||||
std::make_shared<ByTypeDataConverter<py::module>>(ConvertModuleNameSpace),
|
||||
std::make_shared<ByAttrDataConverter>(PYTHON_DATACLASS_FIELDS, ConvertDataClass),
|
||||
std::make_shared<ByTypeDataConverter<Type>>(ObjCast<TypePtr>),
|
||||
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
|
||||
|
@ -466,8 +470,10 @@ std::vector<DataConverterPtr> GetDataConverters() {
|
|||
std::make_shared<ByTypeDataConverter<EnvInstance>>(ObjCast<std::shared_ptr<EnvInstance>>),
|
||||
std::make_shared<ByAttrDataConverter>(PYTHON_CLASS_MEMBER_NAMESPACE,
|
||||
[](const py::object &obj) -> ValuePtr {
|
||||
return std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER,
|
||||
obj);
|
||||
auto res =
|
||||
std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
|
||||
MS_LOG(DEBUG) << "name_space: " << res->ToString();
|
||||
return res;
|
||||
}),
|
||||
std::make_shared<ByTypeDataConverter<py::int_>>(ConvertIntegerWithType),
|
||||
std::make_shared<ByTypeDataConverter<py::float_>>(ConvertFloatWithType),
|
||||
|
@ -598,8 +604,16 @@ bool IsCellInstance(const py::object &obj) {
|
|||
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
// `args_kwargs` maybe a tuple(*args), tuple(**kwargs), or tuple(*args, **kwargs).
|
||||
return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type)
|
||||
: python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, args_kwargs);
|
||||
return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type)
|
||||
: python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_INSTANCE, type, args_kwargs);
|
||||
}
|
||||
|
||||
// Call the python script string.
|
||||
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs) {
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
// `args_kwargs` is a tuple(dict(global), dict(local)).
|
||||
return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_EVAL_PY_SCRIPT, script)
|
||||
: python_adapter::CallPyModFn(mod, PYTHON_MOD_EVAL_PY_SCRIPT, script, args_kwargs);
|
||||
}
|
||||
|
||||
// Generate an appropriate name and set to graph debuginfo,
|
||||
|
|
|
@ -45,6 +45,7 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
|
|||
|
||||
bool IsCellInstance(const py::object &obj);
|
||||
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
|
||||
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs);
|
||||
void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name);
|
||||
ValuePtr PyDataToValue(const py::object &obj);
|
||||
void ClearObjectCache();
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "pybind11/pybind11.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "pipeline/jit/parse/parse.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "utils/info.h"
|
||||
#include "debug/trace.h"
|
||||
|
@ -71,7 +72,7 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var " << var_name << " with node "
|
||||
<< node->DebugString();
|
||||
auto [iter, is_new_name] = vars_.emplace(var_name, std::make_pair(node, false));
|
||||
auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
|
||||
if (!is_new_name) {
|
||||
// If a cnode variable with same name already existed but not used,
|
||||
// add it as an isolate node. for example:
|
||||
|
@ -97,8 +98,8 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
|
|||
// Read variable from predecessors
|
||||
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
|
||||
// Get var node if it is found
|
||||
auto found = vars_.find(var);
|
||||
if (found != vars_.end()) {
|
||||
auto found = assigned_vars_.find(var);
|
||||
if (found != assigned_vars_.end()) {
|
||||
auto &node = found->second.first;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// Mark the variable as used.
|
||||
|
@ -109,7 +110,7 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
|
|||
}
|
||||
return node;
|
||||
}
|
||||
// Get var from predecessor block ,if can't get then make a resolve node to it
|
||||
// Get var from predecessor block, if can't get then make a resolve node to it
|
||||
if (matured_) {
|
||||
// If only one predecessor block, read the definition of var from it.
|
||||
if (prev_blocks_.size() == 1) {
|
||||
|
@ -122,6 +123,7 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
|
|||
if (it != var_to_resolve_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
MS_LOG(DEBUG) << "var: " << var;
|
||||
auto tmp_node = MakeResolveSymbol(var);
|
||||
var_to_resolve_[var] = tmp_node;
|
||||
return tmp_node;
|
||||
|
@ -154,6 +156,7 @@ AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
|
|||
}
|
||||
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_AST, namespace_var[0]);
|
||||
SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
|
||||
MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
|
||||
return MakeResolve(name_space, symbol);
|
||||
}
|
||||
|
||||
|
@ -164,9 +167,39 @@ AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
|
|||
py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
|
||||
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
|
||||
SymbolPtr symbol = std::make_shared<Symbol>(attr);
|
||||
MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
|
||||
return MakeResolve(name_space, symbol);
|
||||
}
|
||||
|
||||
AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &namespace_info) {
|
||||
const size_t namespace_info_size = 2;
|
||||
const size_t namespace_more_info_size = 3;
|
||||
if (namespace_info.size() != namespace_info_size && namespace_info.size() != namespace_more_info_size) {
|
||||
MS_EXCEPTION(NameError) << "namespace info size should be 2 or 3, but got " << namespace_info.size();
|
||||
}
|
||||
bool unsupported = false;
|
||||
py::object py_obj;
|
||||
if (namespace_info.size() == namespace_more_info_size) {
|
||||
if (namespace_info[0].is_none()) { // If namespace is None, the symbol is an undefined name.
|
||||
MS_EXCEPTION(NameError) << namespace_info[namespace_more_info_size - 1].cast<std::string>();
|
||||
} else { // Or, the symbol is an unsupported builtin symbol in Graph mode.
|
||||
unsupported = true;
|
||||
py_obj = namespace_info[namespace_more_info_size - 1];
|
||||
}
|
||||
}
|
||||
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
|
||||
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
|
||||
MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString()
|
||||
<< ", unsupported: " << unsupported;
|
||||
auto resolved_node = MakeResolve(name_space, symbol);
|
||||
if (unsupported) {
|
||||
resolved_node->set_interpret(true);
|
||||
AddGlobalPyParam(symbol->name(), py_obj);
|
||||
MS_LOG(INFO) << "Added global python symblol: {" << symbol->name() << " : " << py::str(py_obj) << "}";
|
||||
}
|
||||
return resolved_node;
|
||||
}
|
||||
|
||||
// Make a resolve node for symbol string
|
||||
AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
|
||||
if (value.compare(0, strlen("self"), "self") == 0) {
|
||||
|
@ -180,23 +213,17 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
|
|||
}
|
||||
auto ast = parser_.ast();
|
||||
MS_EXCEPTION_IF_NULL(ast);
|
||||
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (parser_.support_fallback() == "0" ? false : true);
|
||||
if (!use_fallback) {
|
||||
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
|
||||
const size_t namespace_info_size = 2;
|
||||
if (namespace_info.size() < namespace_info_size) {
|
||||
MS_EXCEPTION(NameError) << "namespace_info is less than 2";
|
||||
return HandleNamespaceInfo(namespace_info);
|
||||
} else {
|
||||
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL, value);
|
||||
return HandleNamespaceInfo(namespace_info);
|
||||
}
|
||||
// If namespace is None, the symbol is an undefined name or an unsupported builtin function.
|
||||
if (namespace_info[0].is_none()) {
|
||||
// If the size of namespace_var is greater than or equal to 3, the error information is stored in namespace_var[2].
|
||||
if (namespace_info.size() > namespace_info_size) {
|
||||
MS_EXCEPTION(NameError) << namespace_info[namespace_info_size].cast<std::string>();
|
||||
}
|
||||
// If the size of namespace_var is less than 3, the default error information is used.
|
||||
MS_EXCEPTION(NameError) << "The name \'" << value << "\' is not defined.";
|
||||
}
|
||||
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
|
||||
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
|
||||
return MakeResolve(name_space, symbol);
|
||||
}
|
||||
|
||||
AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
|
||||
|
@ -209,6 +236,7 @@ AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
|
|||
}
|
||||
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]);
|
||||
SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
|
||||
MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
|
||||
return MakeResolve(name_space, symbol);
|
||||
}
|
||||
|
||||
|
@ -221,6 +249,17 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const Symb
|
|||
return node;
|
||||
}
|
||||
|
||||
AnfNodePtr FunctionBlock::MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
|
||||
const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node) {
|
||||
MS_LOG(DEBUG) << "MakeInterpret for " << script_text;
|
||||
ScriptPtr script = std::make_shared<Script>(script_text);
|
||||
auto script_node = NewValueNode(script);
|
||||
auto node = func_graph_->NewCNodeInOrder(
|
||||
{NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
|
||||
node->set_interpreted_node(orig_node);
|
||||
return node;
|
||||
}
|
||||
|
||||
// Add input for the block's phi parameter
|
||||
void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
|
||||
MS_EXCEPTION_IF_NULL(phi);
|
||||
|
@ -418,7 +457,7 @@ void FunctionBlock::FindIsolatedNodes() {
|
|||
//
|
||||
std::set<AnfNodePtr> used;
|
||||
// Find used variables.
|
||||
for (const auto &var : vars_) {
|
||||
for (const auto &var : assigned_vars_) {
|
||||
auto &node = var.second.first;
|
||||
if (node == nullptr) {
|
||||
continue;
|
||||
|
@ -429,7 +468,7 @@ void FunctionBlock::FindIsolatedNodes() {
|
|||
}
|
||||
}
|
||||
// Add isolated nodes which is unused var but not found in used set.
|
||||
for (const auto &var : vars_) {
|
||||
for (const auto &var : assigned_vars_) {
|
||||
auto &node = var.second.first;
|
||||
bool is_used = var.second.second;
|
||||
if (node == nullptr || is_used) {
|
||||
|
|
|
@ -26,16 +26,17 @@
|
|||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <tuple>
|
||||
#include "pipeline/jit/parse/parse_base.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ordered_set.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parse {
|
||||
|
||||
class Parser;
|
||||
class NameSpace;
|
||||
class Symbol;
|
||||
class Script;
|
||||
class FunctionBlock;
|
||||
using FunctionBlockPtr = std::shared_ptr<FunctionBlock>;
|
||||
|
||||
|
@ -70,11 +71,26 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
AnfNodePtr MakeResolveSymbol(const std::string &value);
|
||||
AnfNodePtr MakeResolveOperation(const std::string &value);
|
||||
AnfNodePtr MakeResolve(const std::shared_ptr<NameSpace> &name_space, const std::shared_ptr<Symbol> &resolve_symbol);
|
||||
AnfNodePtr HandleNamespaceInfo(const py::tuple &namespace_info);
|
||||
AnfNodePtr MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
|
||||
const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node);
|
||||
const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis() const { return removable_phis_; }
|
||||
void FindIsolatedNodes();
|
||||
void AddIsolatedNode(const AnfNodePtr &target);
|
||||
void AttachIsolatedNodesBeforeReturn();
|
||||
|
||||
py::dict &global_py_params() { return global_py_params_; }
|
||||
void set_global_py_params(const py::dict &symbols) { global_py_params_ = symbols; }
|
||||
void AddGlobalPyParam(const std::string &name, const py::object &obj) { global_py_params_[py::str(name)] = obj; }
|
||||
|
||||
std::tuple<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> local_py_params() {
|
||||
return {local_py_params_keys_, local_py_params_values_};
|
||||
}
|
||||
void AddLocalPyParam(const std::string &name, const AnfNodePtr &node) {
|
||||
local_py_params_keys_.emplace_back(NewValueNode(name));
|
||||
local_py_params_values_.emplace_back(node);
|
||||
}
|
||||
|
||||
private:
|
||||
// Block graph
|
||||
FuncGraphPtr func_graph_;
|
||||
|
@ -90,7 +106,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
std::vector<FunctionBlock *> prev_blocks_;
|
||||
|
||||
// Store args and variable's node, use a bool flag to indicate if the variable is used.
|
||||
std::map<std::string, std::pair<AnfNodePtr, bool>> vars_;
|
||||
std::map<std::string, std::pair<AnfNodePtr, bool>> assigned_vars_;
|
||||
|
||||
// Map the parameter node to variable, it can be resolved if the block's predecessors are processed
|
||||
std::map<ParameterPtr, std::string> phi_nodes_;
|
||||
|
@ -114,10 +130,25 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
// Keep new made resolve symbol for the variable not found in vars_.
|
||||
std::unordered_map<std::string, AnfNodePtr> var_to_resolve_;
|
||||
|
||||
// Collect all python symbols in the block.
|
||||
// We treat both global symbols and local symbols declared previously as global symbols.
|
||||
py::dict global_py_params_;
|
||||
std::vector<AnfNodePtr> local_py_params_keys_;
|
||||
std::vector<AnfNodePtr> local_py_params_values_;
|
||||
|
||||
// Isolated nodes.
|
||||
OrderedSet<AnfNodePtr> isolated_nodes_;
|
||||
};
|
||||
|
||||
class ScriptInfo {
|
||||
public:
|
||||
explicit ScriptInfo(const py::object &obj) : py_obj_(obj) {}
|
||||
|
||||
// Key for user data.
|
||||
constexpr static char key[] = "ScriptInfo";
|
||||
|
||||
py::object py_obj_;
|
||||
};
|
||||
} // namespace parse
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -24,9 +24,10 @@
|
|||
#include <unordered_map>
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include "pybind_api/pybind_patch.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/operator/composite/composite.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "debug/trace.h"
|
||||
|
@ -42,7 +43,7 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto ast = std::make_shared<ParseAst>(obj);
|
||||
auto ast = std::make_shared<ParseFunctionAst>(obj);
|
||||
bool success = ast->InitParseAstInfo(python_mod_get_parse_method);
|
||||
if (!success) {
|
||||
MS_LOG(ERROR) << "Parse code to ast tree failed.";
|
||||
|
@ -89,8 +90,9 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo
|
|||
|
||||
FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
|
||||
|
||||
Parser::Parser(const std::shared_ptr<ParseAst> &ast) : ast_(ast) {
|
||||
Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast) : ast_(ast) {
|
||||
max_for_loop_count_str_ = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
|
||||
support_fallback_ = "0"; // We will open it later by call common::GetEnv("ENV_SUPPORT_FALLBACK")
|
||||
errcode_ = PARSE_SUCCESS;
|
||||
BuildMethodMap();
|
||||
}
|
||||
|
@ -147,7 +149,7 @@ void Parser::CleanParserResource() {
|
|||
ScopeManager::GetInstance().ClearScope();
|
||||
}
|
||||
|
||||
void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) {
|
||||
void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunctionAst> &ast) {
|
||||
// Check whether the functions referred by this function and itself are missing 'return' statement
|
||||
auto mng = Manage(fn, false);
|
||||
MS_EXCEPTION_IF_NULL(ast);
|
||||
|
@ -501,6 +503,7 @@ AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &no
|
|||
MS_LOG(DEBUG) << "The Name id is " << name_id;
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
if (block->IsGlobalVar(name_id)) {
|
||||
MS_LOG(DEBUG) << "name_id: " << name_id;
|
||||
return block->MakeResolveSymbol(name_id);
|
||||
}
|
||||
return block->ReadVariable(name_id);
|
||||
|
@ -612,6 +615,7 @@ AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &arg
|
|||
py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
|
||||
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
|
||||
SymbolPtr symbol = std::make_shared<Symbol>("namespace");
|
||||
MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
|
||||
return block->MakeResolve(name_space, symbol);
|
||||
}
|
||||
|
||||
|
@ -631,7 +635,7 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
|
|||
}
|
||||
}
|
||||
|
||||
AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node);
|
||||
AnfNodePtr call_function_node = ParseExprNode(block, function_ast_node);
|
||||
// Function call arguments should be passed in as groups and unpacked later using unpack call
|
||||
std::vector<AnfNodePtr> packed_arguments;
|
||||
std::vector<AnfNodePtr> group_arguments;
|
||||
|
@ -641,33 +645,37 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
|
|||
// If there is stared or keyword argument, unpack may be needed
|
||||
bool need_unpack = need_unpack_args || need_unpack_keywords;
|
||||
|
||||
return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack);
|
||||
auto call_cnode = GenerateAnfNodeForCall(block, call_function_node, packed_arguments, group_arguments, need_unpack);
|
||||
if (call_function_node->interpret()) {
|
||||
call_cnode->set_interpret(true);
|
||||
}
|
||||
return call_cnode;
|
||||
}
|
||||
|
||||
CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_anf_node,
|
||||
CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_node,
|
||||
const std::vector<AnfNodePtr> &packed_arguments) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> unpack_call_nodes;
|
||||
auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
|
||||
unpack_call_nodes.push_back(unpack_call_op);
|
||||
unpack_call_nodes.push_back(call_function_anf_node);
|
||||
unpack_call_nodes.push_back(call_function_node);
|
||||
(void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
|
||||
[](AnfNodePtr node) -> AnfNodePtr { return node; });
|
||||
CNodePtr unpack_call = func_graph->NewCNodeInOrder(unpack_call_nodes);
|
||||
return unpack_call;
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
|
||||
AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_node,
|
||||
const std::vector<AnfNodePtr> &packed_arguments,
|
||||
const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
|
||||
// If there is keyword arguments or starred, using an unpack_call op to unpack the argument
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
if (need_unpack) {
|
||||
return MakeUnpackCall(block->func_graph(), call_function_anf_node, packed_arguments);
|
||||
return MakeUnpackCall(block->func_graph(), call_function_node, packed_arguments);
|
||||
}
|
||||
// else there is no keyword arguments and starred, parsed as normal arguments without unpack
|
||||
std::vector<AnfNodePtr> func_call_nodes;
|
||||
func_call_nodes.push_back(call_function_anf_node);
|
||||
func_call_nodes.push_back(call_function_node);
|
||||
(void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes),
|
||||
[](AnfNodePtr node) -> AnfNodePtr { return node; });
|
||||
CNodePtr call_anf_node = block->func_graph()->NewCNodeInOrder(func_call_nodes);
|
||||
|
@ -689,7 +697,9 @@ bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args
|
|||
group_arguments->clear();
|
||||
need_unpack = true;
|
||||
} else {
|
||||
group_arguments->push_back(ParseExprNode(block, args[i]));
|
||||
auto node = ParseExprNode(block, args[i]);
|
||||
node = HandleInterpret(block, node, args[i]);
|
||||
group_arguments->push_back(node);
|
||||
}
|
||||
}
|
||||
if (!group_arguments->empty()) {
|
||||
|
@ -734,7 +744,7 @@ bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
|
|||
// Process call attributes of class type define, eg: x.y()
|
||||
AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast Attribute";
|
||||
// Process class value,eg: self.xx
|
||||
// Process class value, eg: self.xx
|
||||
if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
|
||||
if (ast_->IsClassMember(node)) {
|
||||
std::string var_name = "self.";
|
||||
|
@ -746,6 +756,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
|
|||
(py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
|
||||
py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
|
||||
py::isinstance<py::str>(attr_obj) || data_converter::IsCellInstance(attr_obj))) {
|
||||
MS_LOG(DEBUG) << "var_name: " << var_name;
|
||||
return block->MakeResolveSymbol(var_name);
|
||||
} else {
|
||||
return block->ReadVariable(var_name);
|
||||
|
@ -775,7 +786,11 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
|
|||
}
|
||||
|
||||
// Create the apply node
|
||||
return block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
|
||||
auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
|
||||
if (value_node->interpret()) {
|
||||
attr_cnode->set_interpret(true);
|
||||
}
|
||||
return attr_cnode;
|
||||
}
|
||||
|
||||
// Process comparison expression : a == b. a > b etc.
|
||||
|
@ -1021,6 +1036,15 @@ AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object
|
|||
}
|
||||
|
||||
// Process a dict ast node expression
|
||||
AnfNodePtr Parser::ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
|
||||
const std::vector<AnfNodePtr> &value_nodes) {
|
||||
auto keys_tuple = GenerateMakeTuple(block, key_nodes);
|
||||
auto values_tuple = GenerateMakeTuple(block, value_nodes);
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
|
||||
return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple});
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast Dict";
|
||||
py::list keys = node.attr("keys");
|
||||
|
@ -1031,11 +1055,7 @@ AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &no
|
|||
key_nodes.push_back(ParseExprNode(block, keys[i]));
|
||||
value_nodes.push_back(ParseExprNode(block, values[i]));
|
||||
}
|
||||
auto keys_tuple = GenerateMakeTuple(block, key_nodes);
|
||||
auto values_tuple = GenerateMakeTuple(block, value_nodes);
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
|
||||
return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple});
|
||||
return ParseDictByKeysAndValues(block, key_nodes, value_nodes);
|
||||
}
|
||||
|
||||
// Process a augment assign such as a += b or mat[stride_slice] += b.
|
||||
|
@ -1631,17 +1651,18 @@ AnfNodePtr Parser::ParseListComp(const FunctionBlockPtr &block, const py::object
|
|||
auto top_block = ParseListCompIter(block, node, generator_node);
|
||||
|
||||
// Call the top graph and return the list.
|
||||
auto call_function_anf_node = NewValueNode(top_block->func_graph());
|
||||
auto call_function_node = NewValueNode(top_block->func_graph());
|
||||
std::vector<AnfNodePtr> func_call_nodes;
|
||||
func_call_nodes.push_back(call_function_anf_node);
|
||||
func_call_nodes.push_back(call_function_node);
|
||||
AnfNodePtr output = block->func_graph()->NewCNodeInOrder(func_call_nodes);
|
||||
return output;
|
||||
}
|
||||
|
||||
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
|
||||
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_obj,
|
||||
const AnfNodePtr &assigned_node) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
MS_EXCEPTION_IF_NULL(assigned_node);
|
||||
py::str name = python_adapter::GetPyObjAttr(targ, "id");
|
||||
py::str name = python_adapter::GetPyObjAttr(target_obj, "id");
|
||||
std::string name_id = name;
|
||||
assigned_node->debug_info()->set_name(name_id);
|
||||
// Set the debug name of the constant graph
|
||||
|
@ -1652,13 +1673,16 @@ void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &t
|
|||
fg->debug_info()->set_name(name_id);
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Assign name: " << name_id << " to node: " << assigned_node;
|
||||
block->AddLocalPyParam(name_id, assigned_node);
|
||||
block->WriteVariable(name_id, assigned_node);
|
||||
}
|
||||
|
||||
void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
|
||||
void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_obj,
|
||||
const AnfNodePtr &assigned_node) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
|
||||
py::list items = python_adapter::GetPyObjAttr(targ, "elts");
|
||||
py::list items = python_adapter::GetPyObjAttr(target_obj, "elts");
|
||||
for (size_t i = 0; i < items.size(); i++) {
|
||||
// Use the Primitive replace the operation resolve node (getitem),
|
||||
// because the getitem will eventually be converted to Primitive node
|
||||
|
@ -1670,13 +1694,13 @@ void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &
|
|||
}
|
||||
}
|
||||
|
||||
void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ,
|
||||
void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_obj,
|
||||
const AnfNodePtr &assigned_node) {
|
||||
// Now only support the self.xx = xxxxx, can't support x.y = xxxx
|
||||
AnfNodePtr target_node = ParseExprNode(block, targ);
|
||||
AnfNodePtr target_node = ParseExprNode(block, target_obj);
|
||||
MS_EXCEPTION_IF_NULL(target_node);
|
||||
|
||||
auto attr_name = targ.attr("attr").cast<std::string>();
|
||||
auto attr_name = target_obj.attr("attr").cast<std::string>();
|
||||
std::string var_name = "self." + attr_name;
|
||||
|
||||
// Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
|
||||
|
@ -1699,12 +1723,12 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob
|
|||
block->SetStateAssign(target_node, assigned_node);
|
||||
}
|
||||
|
||||
void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ,
|
||||
void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_obj,
|
||||
const AnfNodePtr &assigned_node) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
|
||||
py::object value_obj = python_adapter::GetPyObjAttr(targ, "value");
|
||||
py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice");
|
||||
py::object value_obj = python_adapter::GetPyObjAttr(target_obj, "value");
|
||||
py::object slice_obj = python_adapter::GetPyObjAttr(target_obj, "slice");
|
||||
AnfNodePtr value_node = ParseExprNode(block, value_obj);
|
||||
AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
|
||||
CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node});
|
||||
|
@ -1742,18 +1766,19 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje
|
|||
block->WriteVariable(var_name, setitem_app);
|
||||
}
|
||||
|
||||
void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node) {
|
||||
void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_obj,
|
||||
const AnfNodePtr &value_node) {
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
MS_LOG(DEBUG) << "Process WriteAssignVars";
|
||||
auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, targ)));
|
||||
auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_obj)));
|
||||
if (ast_type == AST_SUB_TYPE_NAME) {
|
||||
HandleAssignName(block, targ, value_node);
|
||||
HandleAssignName(block, target_obj, value_node);
|
||||
} else if (ast_type == AST_SUB_TYPE_TUPLE) {
|
||||
HandleAssignTuple(block, targ, value_node);
|
||||
HandleAssignTuple(block, target_obj, value_node);
|
||||
} else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
|
||||
HandleAssignSubscript(block, targ, value_node);
|
||||
} else if (ast_->IsClassMember(targ)) {
|
||||
HandleAssignClassMember(block, targ, value_node);
|
||||
HandleAssignSubscript(block, target_obj, value_node);
|
||||
} else if (ast_->IsClassMember(target_obj)) {
|
||||
HandleAssignClassMember(block, target_obj, value_node);
|
||||
} else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
|
||||
MS_LOG(EXCEPTION) << "The subnet attributes cannot be changed in the network. \n\n"
|
||||
<< trace::GetDebugInfo(value_node->debug_info());
|
||||
|
@ -1763,11 +1788,47 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
|
|||
}
|
||||
}
|
||||
|
||||
// Process a assign statement, such as a =b, a,b = tup
|
||||
AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
|
||||
const py::object &value_object) {
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (support_fallback_ == "0" ? false : true);
|
||||
if (!use_fallback) {
|
||||
return value_node;
|
||||
}
|
||||
|
||||
AnfNodePtr interpreted_node = value_node;
|
||||
if (value_node->interpret()) {
|
||||
const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
|
||||
MS_LOG(INFO) << "script_text: " << script_text << ", value_node: " << value_node->DebugString(2);
|
||||
// Prepare global parameters.
|
||||
py::dict global_dict = block->global_py_params();
|
||||
ValuePtr globals_converted_value = nullptr;
|
||||
if (!ConvertData(global_dict, &globals_converted_value)) {
|
||||
MS_LOG(EXCEPTION) << "Convert data failed";
|
||||
}
|
||||
auto global_dict_node = NewValueNode(globals_converted_value);
|
||||
// Prepare local parameters.
|
||||
auto [keys, values] = block->local_py_params();
|
||||
auto local_dict_node = ParseDictByKeysAndValues(block, keys, values);
|
||||
// Update the valued node if it need interpreting.
|
||||
interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node);
|
||||
|
||||
// Print a hint for user.
|
||||
MS_LOG(ERROR) << "Found unsupported syntax in Graph mode, those codes would be fell back to Python interpreter:"
|
||||
<< "\n\n"
|
||||
<< trace::GetDebugInfo(value_node->debug_info());
|
||||
}
|
||||
return interpreted_node;
|
||||
}
|
||||
|
||||
// Process a assign statement, such as a = b, a, b = tuple(xx, xx)
|
||||
FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast assign";
|
||||
py::object value_object = python_adapter::GetPyObjAttr(node, "value");
|
||||
AnfNodePtr value_node = ParseExprNode(block, value_object);
|
||||
value_node = HandleInterpret(block, value_node, value_object);
|
||||
|
||||
py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
|
||||
py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
|
||||
size_t count = LongToSize(pcount);
|
||||
|
@ -1870,8 +1931,8 @@ void Parser::RemoveUnnecessaryPhis() {
|
|||
}
|
||||
}
|
||||
|
||||
// ParseAst class code
|
||||
bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
|
||||
// ParseFunctionAst class code
|
||||
bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
|
||||
// Init the type
|
||||
target_type_ = PARSE_TARGET_UNKNOW;
|
||||
|
||||
|
@ -1916,7 +1977,13 @@ bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method)
|
|||
|
||||
// Call python parse get ast tree
|
||||
parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method);
|
||||
ast_tree_ = python_adapter::CallPyObjMethod(parser_, "parse");
|
||||
py::tuple ast_info = python_adapter::CallPyObjMethod(parser_, "parse");
|
||||
const size_t ast_info_size = 2;
|
||||
if (ast_info.size() != ast_info_size) {
|
||||
MS_EXCEPTION(NameError) << "ast info size is not equal to 2.";
|
||||
}
|
||||
ast_tokens_ = ast_info[0];
|
||||
ast_tree_ = ast_info[1];
|
||||
|
||||
// Get fn name and module
|
||||
function_module_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_module"));
|
||||
|
@ -1928,23 +1995,28 @@ bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method)
|
|||
}
|
||||
|
||||
// Get ast tree node : is the tree bode list[0]
|
||||
py::object ParseAst::GetAstNode() {
|
||||
py::object ParseFunctionAst::GetAstNode() {
|
||||
py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body");
|
||||
py::object ast_node = tree_body[0];
|
||||
return ast_node;
|
||||
}
|
||||
|
||||
py::list ParseAst::GetArgs(const py::object &func_node) {
|
||||
// Get ast tokens node text.
|
||||
py::str ParseFunctionAst::GetAstNodeText(const py::object &node_obj) {
|
||||
return python_adapter::CallPyObjMethod(ast_tokens_, "get_text", node_obj);
|
||||
}
|
||||
|
||||
py::list ParseFunctionAst::GetArgs(const py::object &func_node) {
|
||||
py::list ret = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS, func_node);
|
||||
return ret;
|
||||
}
|
||||
|
||||
py::list ParseAst::GetArgsDefaultValues(const py::object &func_node) {
|
||||
py::list ParseFunctionAst::GetArgsDefaultValues(const py::object &func_node) {
|
||||
py::list ret = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node);
|
||||
return ret;
|
||||
}
|
||||
|
||||
AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) {
|
||||
AstNodeTypePtr ParseFunctionAst::GetNodeType(const py::object &node) {
|
||||
py::list list_value = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_NODE_TYPE, node);
|
||||
const size_t list_value_size = 2;
|
||||
if (list_value.size() < list_value_size) {
|
||||
|
@ -1955,12 +2027,12 @@ AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) {
|
|||
return std::make_shared<AstNodeType>(node, node_name, type);
|
||||
}
|
||||
|
||||
AstSubType ParseAst::GetOpType(const py::object &node) {
|
||||
AstSubType ParseFunctionAst::GetOpType(const py::object &node) {
|
||||
auto op_type = AstSubType(python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_AST_TYPE, node).cast<int32_t>());
|
||||
return op_type;
|
||||
}
|
||||
|
||||
bool ParseAst::IsClassMember(const py::object &node) {
|
||||
bool ParseFunctionAst::IsClassMember(const py::object &node) {
|
||||
py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node);
|
||||
if (!py::isinstance<py::bool_>(ret)) {
|
||||
MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";
|
||||
|
|
|
@ -34,7 +34,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parse {
|
||||
|
||||
// Parse status define
|
||||
enum ParseStatusCode : int64_t {
|
||||
PARSE_SUCCESS = 0,
|
||||
|
@ -58,7 +57,7 @@ enum ParseStatusCode : int64_t {
|
|||
const int64_t MAX_FOR_LOOP_COUNT = std::numeric_limits<int64_t>::max();
|
||||
|
||||
class AstNodeType;
|
||||
class ParseAst;
|
||||
class ParseFunctionAst;
|
||||
|
||||
// Save loop info for 'continue' and 'break' statements.
|
||||
struct Loop {
|
||||
|
@ -90,13 +89,14 @@ class LoopContext {
|
|||
// Parser to parse python function
|
||||
class Parser {
|
||||
public:
|
||||
explicit Parser(const std::shared_ptr<ParseAst> &ast);
|
||||
explicit Parser(const std::shared_ptr<ParseFunctionAst> &ast);
|
||||
|
||||
~Parser() {}
|
||||
FuncGraphPtr ParseFuncGraph();
|
||||
FuncGraphPtr func_graph() const { return func_graph_; }
|
||||
ParseStatusCode errcode() const { return errcode_; }
|
||||
std::shared_ptr<ParseAst> ast() const { return ast_; }
|
||||
std::shared_ptr<ParseFunctionAst> ast() const { return ast_; }
|
||||
const std::string &support_fallback() const { return support_fallback_; }
|
||||
// Get location info from the ast node
|
||||
LocationPtr GetLocation(const py::object &node) const;
|
||||
static void InitParserEnvironment(const py::object &obj);
|
||||
|
@ -177,6 +177,8 @@ class Parser {
|
|||
// Process a unaryop
|
||||
AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a dict ast node expression
|
||||
AnfNodePtr ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
|
||||
const std::vector<AnfNodePtr> &value_nodes);
|
||||
AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process ListComp expression
|
||||
AnfNodePtr ParseListComp(const FunctionBlockPtr &block, const py::object &node);
|
||||
|
@ -185,6 +187,10 @@ class Parser {
|
|||
AnfNodePtr ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
|
||||
const py::object &node, const py::object &generator_node);
|
||||
|
||||
// Check if the node need interpreting.
|
||||
AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
|
||||
const py::object &value_object);
|
||||
|
||||
// Generate argument nodes for ast function node
|
||||
void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
|
||||
// Generate argument default value for ast function node
|
||||
|
@ -260,7 +266,7 @@ class Parser {
|
|||
// The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
|
||||
static FuncGraphWeakPtr top_func_graph_;
|
||||
// Python function id, used to indicate whether two CNodes come from the same Python function
|
||||
const std::shared_ptr<ParseAst> &ast_;
|
||||
const std::shared_ptr<ParseFunctionAst> &ast_;
|
||||
FuncGraphPtr func_graph_;
|
||||
// Error code setwhen parsing ast tree
|
||||
ParseStatusCode errcode_;
|
||||
|
@ -278,6 +284,7 @@ class Parser {
|
|||
// Save current loops to support 'continue', 'break' statement.
|
||||
std::stack<Loop> loops_;
|
||||
string max_for_loop_count_str_;
|
||||
string support_fallback_;
|
||||
};
|
||||
|
||||
// AST node type define code to ast
|
||||
|
@ -303,16 +310,19 @@ class AstNodeType {
|
|||
using AstNodeTypePtr = std::shared_ptr<AstNodeType>;
|
||||
|
||||
// A helper class to parse python function
|
||||
class ParseAst {
|
||||
class ParseFunctionAst {
|
||||
public:
|
||||
explicit ParseAst(const py::object &obj) : obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {}
|
||||
explicit ParseFunctionAst(const py::object &obj)
|
||||
: obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {}
|
||||
|
||||
~ParseAst() = default;
|
||||
~ParseFunctionAst() = default;
|
||||
|
||||
bool InitParseAstInfo(const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
|
||||
|
||||
py::object GetAstNode();
|
||||
|
||||
py::str GetAstNodeText(const py::object &node);
|
||||
|
||||
py::list GetArgs(const py::object &func_node);
|
||||
|
||||
py::list GetArgsDefaultValues(const py::object &func_node);
|
||||
|
@ -360,6 +370,7 @@ class ParseAst {
|
|||
// Function or class method.
|
||||
py::function function_;
|
||||
|
||||
py::object ast_tokens_;
|
||||
py::object ast_tree_;
|
||||
py::object parser_;
|
||||
py::module module_;
|
||||
|
|
|
@ -63,7 +63,8 @@ const char PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER[] = "is_class_member";
|
|||
const char PYTHON_MOD_RESOLVE_GET_OBJ_TYPE[] = "get_obj_type";
|
||||
const char PYTHON_MOD_GET_OBJ_ID[] = "get_obj_id";
|
||||
const char PYTHON_MOD_GET_CLASS_INSTANCE_TYPE[] = "get_class_instance_type";
|
||||
const char PYTHON_MOD_CREATE_OBJ_INSTANCE[] = "create_obj_instance";
|
||||
const char PYTHON_MOD_CREATE_INSTANCE[] = "create_instance";
|
||||
const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type";
|
||||
const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
|
||||
const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods";
|
||||
const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
|
||||
|
@ -72,6 +73,7 @@ const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
|
|||
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
|
||||
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
|
||||
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
|
||||
const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
|
||||
|
||||
const char PYTHON_PARSE_GET_ARGS[] = "get_args";
|
||||
const char PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES[] = "get_args_default_values";
|
||||
|
@ -80,6 +82,7 @@ const char PYTHON_PARSE_GET_AST_TYPE[] = "get_ast_type";
|
|||
const char PYTHON_PARSE_GET_NAMESPACE_SYMBOL[] = "get_namespace_symbol";
|
||||
const char PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL[] = "get_ast_namespace_symbol";
|
||||
const char PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL[] = "get_operation_namespace_symbol";
|
||||
const char PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL[] = "get_builtin_namespace_symbol";
|
||||
const char PYTHON_PARSE_GET_LOCATION[] = "get_location";
|
||||
const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement";
|
||||
const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
|
||||
|
|
|
@ -33,7 +33,7 @@ static const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRI
|
|||
parse::NAMED_PRIMITIVE_NAMECONSTANT,
|
||||
parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};
|
||||
|
||||
std::string DynamicParser::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
|
||||
std::string DynamicParser::ParseNodeName(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node,
|
||||
parse::AstMainType type) {
|
||||
MS_EXCEPTION_IF_NULL(ast);
|
||||
if (py::isinstance<py::none>(node)) {
|
||||
|
@ -53,7 +53,7 @@ std::string DynamicParser::ParseNodeName(const std::shared_ptr<parse::ParseAst>
|
|||
return node_name;
|
||||
}
|
||||
|
||||
void DynamicParser::ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
|
||||
void DynamicParser::ParseInputArgs(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node) {
|
||||
MS_EXCEPTION_IF_NULL(ast);
|
||||
py::list args = ast->GetArgs(fn_node);
|
||||
for (size_t i = 1; i < args.size(); i++) {
|
||||
|
@ -63,7 +63,7 @@ void DynamicParser::ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast,
|
|||
}
|
||||
}
|
||||
|
||||
bool DynamicParser::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
|
||||
bool DynamicParser::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Parse if/while expr";
|
||||
py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
|
||||
const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR);
|
||||
|
@ -112,7 +112,7 @@ bool DynamicParser::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst>
|
|||
return false;
|
||||
}
|
||||
|
||||
bool DynamicParser::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
|
||||
bool DynamicParser::ParseAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Parse assign expr";
|
||||
py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE);
|
||||
const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR);
|
||||
|
@ -140,7 +140,7 @@ bool DynamicParser::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &
|
|||
return false;
|
||||
}
|
||||
|
||||
bool DynamicParser::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
|
||||
bool DynamicParser::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node,
|
||||
const std::vector<std::string> &compare_prim) {
|
||||
MS_LOG(DEBUG) << "Parse augassign expr";
|
||||
bool ret = false;
|
||||
|
@ -168,7 +168,7 @@ bool DynamicParser::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst
|
|||
return ret;
|
||||
}
|
||||
|
||||
bool DynamicParser::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
|
||||
bool DynamicParser::ParseForExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Parse for expr";
|
||||
py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
|
||||
if (py::isinstance<py::none>(body_node)) {
|
||||
|
@ -188,7 +188,7 @@ bool DynamicParser::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast
|
|||
return false;
|
||||
}
|
||||
|
||||
bool DynamicParser::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
|
||||
bool DynamicParser::ParseBodyContext(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node,
|
||||
const std::vector<std::string> &compare_prim) {
|
||||
MS_EXCEPTION_IF_NULL(ast);
|
||||
py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
|
||||
|
@ -236,7 +236,7 @@ bool DynamicParser::IsDynamicCell(const py::object &cell) {
|
|||
return false;
|
||||
}
|
||||
// Using ast parse to check whether the construct of cell will be changed
|
||||
auto ast = std::make_shared<parse::ParseAst>(cell);
|
||||
auto ast = std::make_shared<parse::ParseFunctionAst>(cell);
|
||||
bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
|
||||
if (!success) {
|
||||
MS_LOG(ERROR) << "Parse code to ast tree failed";
|
||||
|
|
|
@ -36,15 +36,15 @@ class DynamicParser {
|
|||
|
||||
private:
|
||||
static std::string GetCellInfo(const py::object &cell);
|
||||
static void ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
|
||||
static bool ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
|
||||
static void ParseInputArgs(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node);
|
||||
static bool ParseBodyContext(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node,
|
||||
const std::vector<std::string> &compare_prim = {});
|
||||
static bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
|
||||
static bool ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
|
||||
static bool ParseAugAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
|
||||
static bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node);
|
||||
static bool ParseAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node);
|
||||
static bool ParseAugAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node,
|
||||
const std::vector<std::string> &compare_prim = {});
|
||||
static bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
|
||||
static std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
|
||||
static bool ParseForExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node);
|
||||
static std::string ParseNodeName(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node,
|
||||
parse::AstMainType type);
|
||||
};
|
||||
} // namespace mindspore::parse
|
||||
|
|
|
@ -43,9 +43,22 @@ abstract::AbstractBasePtr ClassObject::ToAbstract() {
|
|||
return std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
|
||||
}
|
||||
|
||||
static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj);
|
||||
if (!py::isinstance<py::bool_>(res)) {
|
||||
MS_LOG(ERROR) << "Expect a bool type, but got " << py::str(res);
|
||||
return false;
|
||||
}
|
||||
return res.cast<bool>();
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr ClassType::ToAbstract() {
|
||||
auto abs_scalar =
|
||||
std::make_shared<abstract::AbstractScalar>(shared_from_base<ClassType>(), std::make_shared<TypeType>());
|
||||
if (!IsSupportedCreateInstanceType(obj())) {
|
||||
return abs_scalar;
|
||||
}
|
||||
AbstractBasePtrList args_spec_list = {abs_scalar};
|
||||
|
||||
auto func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
|
||||
|
@ -333,6 +346,7 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
|
|||
auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
|
||||
std::string attr_as_string = GetValueNode<StringImmPtr>(attr)->value();
|
||||
auto new_symbol = std::make_shared<Symbol>(attr_as_string);
|
||||
MS_LOG(DEBUG) << "name_space: " << new_namespace->ToString() << ", symbol: " << new_symbol->ToString();
|
||||
|
||||
AnfNodePtrList inputs = {NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)};
|
||||
AnfNodePtr resolved_node = node->func_graph()->NewCNode(inputs);
|
||||
|
|
|
@ -36,15 +36,16 @@ using ResourceBasePtr = std::shared_ptr<ResourceBase>;
|
|||
|
||||
namespace mindspore {
|
||||
namespace parse {
|
||||
|
||||
// NameSpace class for resolving python code.
|
||||
class NameSpace : public Named {
|
||||
public:
|
||||
NameSpace(const std::string &module, const py::object &obj) : Named(module), module_(module), obj_(obj) {}
|
||||
NameSpace(const std::string &module, const py::object &obj, const py::object &module_obj = py::object())
|
||||
: Named(module), module_(module), obj_(obj), module_obj_(module_obj) {}
|
||||
~NameSpace() override = default;
|
||||
MS_DECLARE_PARENT(NameSpace, Named);
|
||||
|
||||
py::object obj() { return obj_; }
|
||||
py::object module_obj() { return module_obj_; }
|
||||
std::string module() { return module_; }
|
||||
abstract::AbstractBasePtr ToAbstract() override {
|
||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<NameSpace>(), std::make_shared<External>());
|
||||
|
@ -55,6 +56,8 @@ class NameSpace : public Named {
|
|||
std::string module_;
|
||||
// namespace object
|
||||
py::object obj_;
|
||||
// module object
|
||||
py::object module_obj_;
|
||||
};
|
||||
using NameSpacePtr = std::shared_ptr<NameSpace>;
|
||||
|
||||
|
@ -62,7 +65,7 @@ using NameSpacePtr = std::shared_ptr<NameSpace>;
|
|||
class Symbol : public Named {
|
||||
public:
|
||||
explicit Symbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {}
|
||||
explicit Symbol(const std::string &symbol, const std::string &name) : Named(name), symbol_(symbol) {}
|
||||
Symbol(const std::string &symbol, const std::string &name) : Named(name), symbol_(symbol) {}
|
||||
|
||||
~Symbol() override = default;
|
||||
MS_DECLARE_PARENT(Symbol, Named);
|
||||
|
@ -77,6 +80,25 @@ class Symbol : public Named {
|
|||
};
|
||||
using SymbolPtr = std::shared_ptr<Symbol>;
|
||||
|
||||
class Script : public Named {
|
||||
public:
|
||||
explicit Script(const std::string &script) : Named(script), script_(script) {}
|
||||
Script(const std::string &script, const std::string &name) : Named(name), script_(script) {}
|
||||
|
||||
~Script() override = default;
|
||||
MS_DECLARE_PARENT(Script, Named);
|
||||
|
||||
std::string script() { return script_; }
|
||||
abstract::AbstractBasePtr ToAbstract() override {
|
||||
return std::make_shared<abstract::AbstractScript>(shared_from_base<Script>());
|
||||
}
|
||||
std::string ToString() const override { return "`" + name() + "`"; }
|
||||
|
||||
private:
|
||||
std::string script_;
|
||||
};
|
||||
using ScriptPtr = std::shared_ptr<Script>;
|
||||
|
||||
// PyObjectWrapper class wrappers resolved python object for further processing.
|
||||
class PyObjectWrapper : public Named {
|
||||
public:
|
||||
|
|
|
@ -359,7 +359,7 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
|
|||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
|
||||
if (output_node->isa<ValueNode>()) {
|
||||
return ValuePtrToPyData(GetValueNode(output_node));
|
||||
return ValueToPyData(GetValueNode(output_node));
|
||||
}
|
||||
|
||||
if (output_node->isa<Parameter>()) {
|
||||
|
|
|
@ -402,15 +402,15 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
|
||||
EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
if (args_conf_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Size should greater than 0";
|
||||
}
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
return conf->ObtainEvalResult()->abstract();
|
||||
});
|
||||
if (args_conf_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Size should greater than 0";
|
||||
}
|
||||
EvalResultPtr res = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
|
||||
// No need to cache.
|
||||
return res;
|
||||
|
|
|
@ -298,7 +298,7 @@ py::object BuildValue(const ValuePtr &value_ptr) {
|
|||
if (value_ptr == nullptr) {
|
||||
return py::none();
|
||||
} else {
|
||||
return ValuePtrToPyData(value_ptr);
|
||||
return ValueToPyData(value_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -786,23 +786,6 @@ EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveIm
|
|||
return uniform_primitive_evaluator;
|
||||
}
|
||||
|
||||
const int64_t kResolveCaseUserDefineClass = 1;
|
||||
const int64_t kResolveCaseBuiltInType = 2;
|
||||
const int64_t kResolveCaseFunction = 3;
|
||||
int64_t GetResolveCase(const TypePtr &data_type) {
|
||||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
if (data_type->type_id() == kObjectTypeClass) {
|
||||
return kResolveCaseUserDefineClass;
|
||||
}
|
||||
|
||||
// try method map, if not in method map, the data_type should be External type.
|
||||
if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
|
||||
return kResolveCaseBuiltInType;
|
||||
}
|
||||
|
||||
return kResolveCaseFunction;
|
||||
}
|
||||
|
||||
FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
MS_EXCEPTION_IF_NULL(method);
|
||||
|
@ -883,18 +866,18 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con
|
|||
MS_LOG(EXCEPTION) << "Data is not NameSpace : " << data_v->ToString();
|
||||
}
|
||||
|
||||
auto item_v = args_spec_list[1]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(item_v);
|
||||
if (item_v->isa<StringImm>()) {
|
||||
item_v = std::make_shared<parse::Symbol>(item_v->cast<StringImmPtr>()->value());
|
||||
auto item_value = args_spec_list[1]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(item_value);
|
||||
if (item_value->isa<StringImm>()) {
|
||||
item_value = std::make_shared<parse::Symbol>(item_value->cast<StringImmPtr>()->value());
|
||||
}
|
||||
|
||||
if (!item_v->isa<parse::Symbol>()) {
|
||||
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_v->ToString();
|
||||
if (!item_value->isa<parse::Symbol>()) {
|
||||
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
|
||||
}
|
||||
|
||||
// item_name to func addr from obj_map
|
||||
parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>();
|
||||
parse::SymbolPtr symbol = item_value->cast<parse::SymbolPtr>();
|
||||
parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
auto out_node = out_conf->node();
|
||||
|
@ -915,19 +898,20 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con
|
|||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
|
||||
const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
|
||||
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
|
||||
const AbstractBasePtrList &args_spec_list,
|
||||
const ValuePtr &item_value, const ConfigPtr &data_conf,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
if (args_spec_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "args_spec_list is empty";
|
||||
}
|
||||
AbstractClassPtr cls = CheckArg<AbstractClass>("__FUNC__", args_spec_list, 0);
|
||||
|
||||
// If item_v is an attribute, get abstract value from AbstractClass
|
||||
MS_EXCEPTION_IF_NULL(item_v);
|
||||
if (!item_v->isa<StringImm>()) {
|
||||
// If item_value is an attribute, get abstract value from AbstractClass
|
||||
MS_EXCEPTION_IF_NULL(item_value);
|
||||
if (!item_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << "Attribute type error";
|
||||
}
|
||||
std::string item_name = item_v->cast<StringImmPtr>()->value();
|
||||
std::string item_name = item_value->cast<StringImmPtr>()->value();
|
||||
MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name();
|
||||
MS_LOG(DEBUG) << "Resolve item: " << item_name;
|
||||
MS_EXCEPTION_IF_NULL(cls);
|
||||
|
@ -941,25 +925,25 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
|
|||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]->BuildType());
|
||||
MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
|
||||
<< ", item value: " << item_v->ToString();
|
||||
<< ", item value: " << item_value->ToString();
|
||||
}
|
||||
|
||||
// Infer class method
|
||||
ValuePtr converted_v = PyObjToGraph(engine, method);
|
||||
return StaticGetterInferred(converted_v, data_conf, out_conf);
|
||||
ValuePtr converted_value = PyObjToGraph(engine, method);
|
||||
return StaticGetterInferred(converted_value, data_conf, out_conf);
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
|
||||
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
|
||||
const TypePtr &data_type, const ConfigPtr &data_conf,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
MS_EXCEPTION_IF_NULL(item_v);
|
||||
MS_EXCEPTION_IF_NULL(item_value);
|
||||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
// The method maybe a Primitive or Composite
|
||||
if (!item_v->isa<StringImm>()) {
|
||||
if (!item_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << "Error item is not string";
|
||||
}
|
||||
|
||||
std::string item_name = item_v->cast<StringImmPtr>()->value();
|
||||
std::string item_name = item_value->cast<StringImmPtr>()->value();
|
||||
REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
|
||||
Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
|
||||
if (require.empty()) {
|
||||
|
@ -971,20 +955,38 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
|
|||
require_type = REQUIRE_TYPE::ATTR;
|
||||
}
|
||||
|
||||
ValuePtr converted_v = nullptr;
|
||||
ValuePtr converted_value = nullptr;
|
||||
if (require.is<std::string>()) {
|
||||
// composite registered in standard_method_map go to this branch
|
||||
converted_v = prim::GetPythonOps(require.cast<std::string>());
|
||||
MS_EXCEPTION_IF_NULL(converted_v);
|
||||
if (!converted_v->isa<Primitive>()) {
|
||||
AddToManager(engine, converted_v->cast<FuncGraphPtr>());
|
||||
converted_value = prim::GetPythonOps(require.cast<std::string>());
|
||||
MS_EXCEPTION_IF_NULL(converted_value);
|
||||
if (!converted_value->isa<Primitive>()) {
|
||||
AddToManager(engine, converted_value->cast<FuncGraphPtr>());
|
||||
}
|
||||
} else if (require.is<PrimitivePtr>()) {
|
||||
converted_v = require.cast<PrimitivePtr>();
|
||||
converted_value = require.cast<PrimitivePtr>();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString();
|
||||
}
|
||||
return StaticGetterInferred(converted_v, data_conf, out_conf, require_type);
|
||||
return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
|
||||
}
|
||||
|
||||
enum ResolveType : int64_t {
|
||||
kResolveTypeUserDefineClass = 1,
|
||||
kResolveTypeBuiltInType,
|
||||
kResolveTypeFunction,
|
||||
};
|
||||
|
||||
int64_t GetResolveType(const TypePtr &data_type) {
|
||||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
if (data_type->type_id() == kObjectTypeClass) {
|
||||
return kResolveTypeUserDefineClass;
|
||||
}
|
||||
// Try to search method map, if not found, the data_type should be External type.
|
||||
if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
|
||||
return kResolveTypeBuiltInType;
|
||||
}
|
||||
return kResolveTypeFunction;
|
||||
}
|
||||
|
||||
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||
|
@ -1006,10 +1008,10 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
|
||||
}
|
||||
|
||||
int64_t case_v = GetResolveCase(data_type);
|
||||
if (case_v == kResolveCaseUserDefineClass) {
|
||||
int64_t resolve_type = GetResolveType(data_type);
|
||||
if (resolve_type == kResolveTypeUserDefineClass) {
|
||||
return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
|
||||
} else if (case_v == kResolveCaseBuiltInType) {
|
||||
} else if (resolve_type == kResolveTypeBuiltInType) {
|
||||
return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf);
|
||||
} else {
|
||||
return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf);
|
||||
|
@ -1234,12 +1236,12 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
return infer_result;
|
||||
}
|
||||
|
||||
pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
|
||||
py::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
|
||||
// Exclude class type by minus 1;
|
||||
std::size_t params_size = args_spec_list.size() - 1;
|
||||
auto params = py::tuple(params_size);
|
||||
if (params_size > params.size()) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected params_size:" << params_size << ",params.size():" << params.size();
|
||||
MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size << ", params.size():" << params.size();
|
||||
}
|
||||
if (params_size > 0) {
|
||||
for (size_t i = 0; i < params_size; i++) {
|
||||
|
@ -1248,7 +1250,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
MS_EXCEPTION_IF_NULL(arg);
|
||||
// Because the Tensor's AbstractTensor can't get value from GetValueTrack.
|
||||
ValuePtr param_value = arg->BuildValue();
|
||||
py::object param = ValuePtrToPyData(param_value);
|
||||
py::object param = ValueToPyData(param_value);
|
||||
params[i] = param;
|
||||
}
|
||||
}
|
||||
|
@ -1256,6 +1258,80 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
}
|
||||
};
|
||||
|
||||
class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
||||
public:
|
||||
PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
|
||||
~PyInterpretEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(PyInterpretEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) override {
|
||||
if (args_spec_list.empty()) {
|
||||
MS_LOG(ERROR) << "'args_spec_list' should not be empty";
|
||||
}
|
||||
|
||||
// Get the type parameter.
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
|
||||
std::shared_ptr<parse::Script> script_obj = dyn_cast<parse::Script>(value_track);
|
||||
if (script_obj == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
|
||||
}
|
||||
|
||||
// Make global and local parameters.
|
||||
py::tuple params = MakeParameters(args_spec_list);
|
||||
|
||||
// Call python script string.
|
||||
MS_LOG(DEBUG) << "Call script: " << script_obj->script() << ", params: " << py::str(params);
|
||||
auto obj = parse::data_converter::CallPythonScript(py::str(script_obj->script()), params);
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to call python script: `" << script_obj->script() << "`";
|
||||
}
|
||||
|
||||
ValuePtr converted_val = nullptr;
|
||||
bool converted = parse::ConvertData(obj, &converted_val, true);
|
||||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Convert the python object failed";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(converted_val);
|
||||
|
||||
AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
|
||||
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
|
||||
return infer_result;
|
||||
}
|
||||
|
||||
py::tuple MakeParameters(const AbstractBasePtrList &args_spec_list) const {
|
||||
constexpr int params_size = 3;
|
||||
if (params_size != args_spec_list.size()) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size
|
||||
<< ", not equal to arguments.size:" << args_spec_list.size();
|
||||
}
|
||||
// The first argument is script string, ignore it.
|
||||
auto params = py::tuple(params_size - 1);
|
||||
|
||||
// Make the global parameters.
|
||||
auto global_dict = dyn_cast<AbstractDictionary>(args_spec_list[1]); // Global parameters dict.
|
||||
MS_EXCEPTION_IF_NULL(global_dict);
|
||||
MS_LOG(DEBUG) << "arg_1, global_dict: " << global_dict->ToString() << ", [" << global_dict->type_name() << "]";
|
||||
ValuePtr global_dict_value = global_dict->BuildValue();
|
||||
py::object global_params_dict = ValueToPyData(global_dict_value);
|
||||
MS_LOG(DEBUG) << "arg_1, python global_params_dict: " << py::str(global_params_dict);
|
||||
params[0] = global_params_dict;
|
||||
|
||||
// Make the local parameters.
|
||||
auto local_dict = dyn_cast<AbstractDictionary>(args_spec_list[2]); // Local parameters dict.
|
||||
MS_EXCEPTION_IF_NULL(local_dict);
|
||||
MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString() << ", [" << local_dict->type_name() << "]";
|
||||
ValuePtr local_dict_value = local_dict->BuildValue();
|
||||
py::object local_params_dict = ValueToPyData(local_dict_value);
|
||||
MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << py::str(local_params_dict);
|
||||
params[1] = local_params_dict;
|
||||
return params;
|
||||
}
|
||||
};
|
||||
|
||||
class PartialEvaluator : public Evaluator {
|
||||
public:
|
||||
PartialEvaluator() : Evaluator("PartialEvaluator") {}
|
||||
|
@ -1399,6 +1475,7 @@ void InitPrimEvaluatorConstructors() {
|
|||
constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
|
||||
constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
|
||||
constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
|
||||
constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -286,7 +286,8 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
|
||||
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(possible_func);
|
||||
if (func == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << possible_func->ToString() << ".";
|
||||
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << possible_func->ToString() << ".";
|
||||
MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
|
||||
MS_EXCEPTION(ValueError) << "This may be not defined, and it can't be a operator. Please check code.";
|
||||
}
|
||||
|
||||
|
|
|
@ -291,7 +291,7 @@ py::function PrimitivePy::GetComputeFunction() const {
|
|||
py::dict PrimitivePy::GetAttrDict() {
|
||||
py::dict attr_dict;
|
||||
for (auto &attr : attrs_) {
|
||||
attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
|
||||
attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
|
||||
}
|
||||
return attr_dict;
|
||||
}
|
||||
|
@ -430,7 +430,7 @@ py::dict PrimitivePyAdapter::GetAttrDict() {
|
|||
|
||||
py::dict attr_dict;
|
||||
for (auto &attr : attrs_) {
|
||||
attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
|
||||
attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
|
||||
}
|
||||
return attr_dict;
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "abstract/utils.h"
|
||||
#include "pipeline/jit/parse/parse.h"
|
||||
#include "pipeline/jit/parse/parse_base.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "ir/value.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/param_info.h"
|
||||
|
@ -106,80 +107,125 @@ py::object ScalarPtrToPyData(const ScalarPtr &value) {
|
|||
}
|
||||
}
|
||||
|
||||
py::object ValuePtrToPyData(const ValuePtr &value) {
|
||||
if (value == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "value is null";
|
||||
}
|
||||
py::object ret;
|
||||
if (value->isa<Scalar>()) {
|
||||
ret = ScalarPtrToPyData(value->cast<ScalarPtr>());
|
||||
} else if (value->isa<StringImm>()) {
|
||||
MS_LOG(DEBUG) << "String";
|
||||
py::str v = value->cast<StringImmPtr>()->value();
|
||||
ret = v;
|
||||
} else if (value->isa<tensor::Tensor>()) {
|
||||
MS_LOG(DEBUG) << "tensor";
|
||||
auto tensor_ptr = value->cast<tensor::TensorPtr>();
|
||||
ret = TensorToPyData(tensor_ptr);
|
||||
} else if (value->isa<tensor::MetaTensor>()) {
|
||||
MS_LOG(DEBUG) << "MetaTensor";
|
||||
py::tuple v(1);
|
||||
v[0] = value->cast<tensor::MetaTensorPtr>();
|
||||
ret = v[0];
|
||||
} else if (value->isa<RefKey>()) {
|
||||
MS_LOG(DEBUG) << "RefKey";
|
||||
py::tuple v(1);
|
||||
v[0] = value->cast<RefKeyPtr>();
|
||||
ret = v[0];
|
||||
} else if (value->isa<ValueSequeue>()) {
|
||||
MS_LOG(DEBUG) << "tuple or list";
|
||||
auto value_sequeue = value->cast<ValueSequeuePtr>()->value();
|
||||
py::tuple ret_sequeue(value_sequeue.size());
|
||||
using ConverterFunction = std::function<py::object(const ValuePtr &value)>;
|
||||
using ValueNameToConverterVector = std::vector<std::pair<const char *, ConverterFunction>>;
|
||||
|
||||
// (Value Type Name) -> (Converter Function)
|
||||
// The converter function is used to convert Value object to Python data object.
|
||||
static ValueNameToConverterVector value_name_to_converter = {
|
||||
// Scalar
|
||||
{typeid(Scalar).name(),
|
||||
[](const ValuePtr &value) -> py::object { return ScalarPtrToPyData(value->cast<ScalarPtr>()); }},
|
||||
// Tensor
|
||||
{typeid(tensor::Tensor).name(),
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
auto tensor_ptr = value->cast<tensor::TensorPtr>();
|
||||
return TensorToPyData(tensor_ptr);
|
||||
}},
|
||||
// MetaTenser
|
||||
{typeid(tensor::MetaTensor).name(),
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
py::tuple tuple_container(1);
|
||||
tuple_container[0] = value->cast<tensor::MetaTensorPtr>();
|
||||
return tuple_container[0];
|
||||
}},
|
||||
// RefKey
|
||||
{typeid(RefKey).name(),
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
py::tuple tuple_container(1);
|
||||
tuple_container[0] = value->cast<RefKeyPtr>();
|
||||
return tuple_container[0];
|
||||
}},
|
||||
// Type
|
||||
{typeid(Type).name(),
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
py::tuple tuple_container(1);
|
||||
tuple_container[0] = value->cast<TypePtr>();
|
||||
return tuple_container[0];
|
||||
}},
|
||||
// StringImm
|
||||
{typeid(StringImm).name(),
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
py::str res = value->cast<StringImmPtr>()->value();
|
||||
return res;
|
||||
}},
|
||||
// ValueSequeue
|
||||
{typeid(ValueSequeue).name(),
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
auto value_sequeue = value->cast<ValueSequeuePtr>()->value();
|
||||
py::tuple res_sequeue(value_sequeue.size());
|
||||
for (size_t i = 0; i < value_sequeue.size(); i++) {
|
||||
ret_sequeue[i] = ValuePtrToPyData(value_sequeue[i]);
|
||||
res_sequeue[i] = ValueToPyData(value_sequeue[i]);
|
||||
}
|
||||
if (value->isa<ValueTuple>()) {
|
||||
ret = ret_sequeue;
|
||||
} else {
|
||||
ret = ret_sequeue.cast<py::list>();
|
||||
return res_sequeue;
|
||||
}
|
||||
} else if (value->isa<ValueDictionary>()) {
|
||||
MS_LOG(DEBUG) << "dict";
|
||||
return res_sequeue.cast<py::list>();
|
||||
}},
|
||||
// ValueDictionary
|
||||
{typeid(ValueDictionary).name(),
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
auto value_list = value->cast<ValueDictionaryPtr>()->value();
|
||||
py::dict ret_dict;
|
||||
for (const auto &v : value_list) {
|
||||
ret_dict[py::str(v.first)] = ValuePtrToPyData(v.second);
|
||||
py::dict res_dict;
|
||||
for (const auto &value : value_list) {
|
||||
res_dict[py::str(value.first)] = ValueToPyData(value.second);
|
||||
}
|
||||
ret = ret_dict;
|
||||
} else if (value->isa<Ellipsis>()) {
|
||||
ret = py::ellipsis();
|
||||
} else if (value->isa<ValueSlice>()) {
|
||||
return res_dict;
|
||||
}},
|
||||
// ValueSlice
|
||||
{typeid(ValueSlice).name(),
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
auto slice = value->cast<ValueSlicePtr>();
|
||||
auto start = ValuePtrToPyData(slice->start());
|
||||
auto end = ValuePtrToPyData(slice->stop());
|
||||
auto step = ValuePtrToPyData(slice->step());
|
||||
ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
|
||||
auto start = ValueToPyData(slice->start());
|
||||
auto end = ValueToPyData(slice->stop());
|
||||
auto step = ValueToPyData(slice->step());
|
||||
return parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
|
||||
step);
|
||||
} else if (value->isa<Type>()) {
|
||||
py::tuple v(1);
|
||||
v[0] = value->cast<TypePtr>();
|
||||
ret = v[0];
|
||||
} else if (value->isa<AnyValue>() || value->isa<None>() || value->isa<Monad>() || value->isa<FuncGraph>()) {
|
||||
// FuncGraph is not used in the backend, return None
|
||||
ret = py::none();
|
||||
} else if (value->isa<KeywordArg>()) {
|
||||
}},
|
||||
// KeywordArg
|
||||
{typeid(KeywordArg).name(),
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
|
||||
auto key = abs_keyword_arg->get_key();
|
||||
auto val = abs_keyword_arg->get_arg()->BuildValue();
|
||||
auto py_value = ValuePtrToPyData(val);
|
||||
auto py_value = ValueToPyData(val);
|
||||
auto kwargs = py::kwargs();
|
||||
kwargs[key.c_str()] = py_value;
|
||||
ret = kwargs;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
|
||||
return kwargs;
|
||||
}},
|
||||
// parse::NameSpace
|
||||
{typeid(parse::NameSpace).name(),
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
auto ns = value->cast<parse::NameSpacePtr>();
|
||||
return ns->module_obj();
|
||||
}},
|
||||
// parse::ClassType
|
||||
{typeid(parse::ClassType).name(),
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
auto class_type = value->cast<parse::ClassTypePtr>();
|
||||
return class_type->obj();
|
||||
}},
|
||||
// None
|
||||
{typeid(None).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
|
||||
// AnyValue
|
||||
{typeid(AnyValue).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
|
||||
// FuncGraph
|
||||
{typeid(FuncGraph).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
|
||||
// Monad
|
||||
{typeid(Monad).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
|
||||
// Ellipsis
|
||||
{typeid(Ellipsis).name(), [](const ValuePtr &value) -> py::object { return py::ellipsis(); }}};
|
||||
|
||||
py::object ValueToPyData(const ValuePtr &value) {
|
||||
if (value == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The `value` should not be null";
|
||||
}
|
||||
return ret;
|
||||
for (auto &iter : value_name_to_converter) {
|
||||
if (value->IsFromTypeId(Base::GetTypeId(iter.first))) {
|
||||
return iter.second(value);
|
||||
}
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Unsupported to convert " << value->ToString() << "[" << value->type_name() << "] to a PyData";
|
||||
}
|
||||
|
||||
py::object AnyToPyData(const Any &value) {
|
||||
|
@ -190,7 +236,7 @@ py::object AnyToPyData(const Any &value) {
|
|||
} else if (value.is<ValuePtr>()) {
|
||||
MS_LOG(DEBUG) << "ValuePtr";
|
||||
ValuePtr v = value.cast<ValuePtr>();
|
||||
ret = ValuePtrToPyData(v);
|
||||
ret = ValueToPyData(v);
|
||||
} else if (value.is<tensor::TensorPtr>()) {
|
||||
MS_LOG(DEBUG) << "tensor";
|
||||
auto tensor_ptr = value.cast<tensor::TensorPtr>();
|
||||
|
@ -233,7 +279,7 @@ py::object BaseRefToPyData(const BaseRef &value) {
|
|||
} else if (utils::isa<ValuePtr>(value)) {
|
||||
MS_LOG(DEBUG) << "ValuePtr";
|
||||
ValuePtr v = utils::cast<ValuePtr>(value);
|
||||
ret = ValuePtrToPyData(v);
|
||||
ret = ValueToPyData(v);
|
||||
} else if (utils::isa<tensor::TensorPtr>(value)) {
|
||||
MS_LOG(DEBUG) << "tensor";
|
||||
auto tensor_ptr = utils::cast<tensor::TensorPtr>(value);
|
||||
|
@ -459,7 +505,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
|
|||
if (output->isa<ValueNode>()) {
|
||||
MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
|
||||
ValuePtr value = GetValueNode(output);
|
||||
*ret_val = ValuePtrToPyData(value);
|
||||
*ret_val = ValueToPyData(value);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace py = pybind11;
|
|||
namespace mindspore {
|
||||
py::object AnyToPyData(const Any &value);
|
||||
py::object BaseRefToPyData(const BaseRef &value);
|
||||
py::object ValuePtrToPyData(const ValuePtr &value);
|
||||
py::object ValueToPyData(const ValuePtr &value);
|
||||
|
||||
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
|
||||
const std::shared_ptr<py::object> &ret_val);
|
||||
|
|
|
@ -320,6 +320,38 @@ std::string TypedPrimitiveAbstractClosure::ToString() const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
bool PyInterpretAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||
if (!other.isa<PyInterpretAbstractClosure>()) {
|
||||
return false;
|
||||
}
|
||||
auto other_partial = static_cast<const PyInterpretAbstractClosure *>(&other);
|
||||
if (fn_ != other_partial->fn_) {
|
||||
return false;
|
||||
}
|
||||
if (args_spec_list_.size() != other_partial->args_spec_list_.size()) {
|
||||
return false;
|
||||
}
|
||||
return args_spec_list_ == other_partial->args_spec_list_;
|
||||
}
|
||||
|
||||
std::size_t PyInterpretAbstractClosure::hash() const {
|
||||
MS_EXCEPTION_IF_NULL(fn_);
|
||||
auto hash_value = hash_combine(tid(), fn_->hash());
|
||||
hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
|
||||
return hash_value;
|
||||
}
|
||||
|
||||
std::string PyInterpretAbstractClosure::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
buffer << "PyInterpretAbstractClosure(" << fn_->ToString() << "(";
|
||||
for (const auto &arg : args_spec_list_) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
buffer << arg->ToString() << ", ";
|
||||
}
|
||||
buffer << "))";
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
bool DummyAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||
return !other.isa<DummyAbstractClosure>();
|
||||
}
|
||||
|
|
|
@ -192,7 +192,7 @@ class MS_CORE_API PartialAbstractClosure : public AbstractFuncAtom {
|
|||
MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom)
|
||||
|
||||
AbstractFunctionPtr fn() { return fn_; }
|
||||
AbstractBasePtrList args() { return args_spec_list_; }
|
||||
AbstractBasePtrList &args() { return args_spec_list_; }
|
||||
ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
|
||||
AnfNodePtr node() { return node_.lock(); }
|
||||
void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
|
||||
|
@ -287,6 +287,34 @@ class MS_CORE_API TypedPrimitiveAbstractClosure : public AbstractFuncAtom {
|
|||
AbstractBasePtr output_;
|
||||
};
|
||||
|
||||
class PyInterpretAbstractClosure : public AbstractFuncAtom {
|
||||
public:
|
||||
PyInterpretAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list,
|
||||
const AnfNodePtr &node = nullptr)
|
||||
: fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {}
|
||||
~PyInterpretAbstractClosure() override = default;
|
||||
MS_DECLARE_PARENT(PyInterpretAbstractClosure, AbstractFuncAtom)
|
||||
|
||||
AbstractFunctionPtr fn() { return fn_; }
|
||||
AbstractBasePtrList args() { return args_spec_list_; }
|
||||
ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
|
||||
AnfNodePtr node() { return node_.lock(); }
|
||||
void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
|
||||
AbstractFunctionPtr Copy() const override {
|
||||
return std::make_shared<PyInterpretAbstractClosure>(fn_, args_spec_list_, node_.lock());
|
||||
}
|
||||
bool operator==(const AbstractFunction &other) const override;
|
||||
std::size_t hash() const override;
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
private:
|
||||
AbstractFuncAtomPtr fn_;
|
||||
AbstractBasePtrList args_spec_list_;
|
||||
AnfNodeWeakPtr node_;
|
||||
};
|
||||
using PyInterpretAbstractClosurePtr = std::shared_ptr<PyInterpretAbstractClosure>;
|
||||
|
||||
// Represents a function that can't be called.
|
||||
class MS_CORE_API DummyAbstractClosure : public AbstractFuncAtom {
|
||||
public:
|
||||
|
|
|
@ -104,7 +104,7 @@ class MS_CORE_API AbstractBase : public Base {
|
|||
class MS_CORE_API AbstractScalar : public AbstractBase {
|
||||
public:
|
||||
AbstractScalar() : AbstractBase(kAnyValue, kAnyType) {}
|
||||
explicit AbstractScalar(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
|
||||
AbstractScalar(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
|
||||
explicit AbstractScalar(const ValuePtr &value) : AbstractBase(value, value->type()) {}
|
||||
explicit AbstractScalar(int value) : AbstractBase(MakeValue(value), kInt32) {}
|
||||
explicit AbstractScalar(int64_t value) : AbstractBase(MakeValue(value), kInt64) {}
|
||||
|
@ -148,7 +148,7 @@ using AbstractTypePtr = std::shared_ptr<AbstractType>;
|
|||
|
||||
class MS_CORE_API AbstractError : public AbstractBase {
|
||||
public:
|
||||
explicit AbstractError(const StringImmPtr &err, const AnfNodePtr &node) : AbstractBase(err), node_(node) {
|
||||
AbstractError(const StringImmPtr &err, const AnfNodePtr &node) : AbstractBase(err), node_(node) {
|
||||
if (err == nullptr || node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "err or node is nullptr";
|
||||
}
|
||||
|
@ -170,6 +170,25 @@ class MS_CORE_API AbstractError : public AbstractBase {
|
|||
const AnfNodePtr node_;
|
||||
};
|
||||
|
||||
class MS_CORE_API AbstractScript : public AbstractBase {
|
||||
public:
|
||||
AbstractScript() : AbstractBase(kAnyValue, kAnyType) {}
|
||||
AbstractScript(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {}
|
||||
explicit AbstractScript(const ValuePtr &value) : AbstractBase(value, kString) {}
|
||||
// explicit AbstractScript(const std::string &value) : AbstractBase(MakeValue(value), kString) {}
|
||||
~AbstractScript() override = default;
|
||||
MS_DECLARE_PARENT(AbstractScript, AbstractBase)
|
||||
|
||||
std::size_t hash() const override { return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()}); }
|
||||
|
||||
TypePtr BuildType() const override { return GetTypeTrack(); }
|
||||
AbstractBasePtr Clone() const override {
|
||||
return std::make_shared<AbstractScript>(GetValueTrack(), GetTypeTrack()->Clone());
|
||||
}
|
||||
AbstractBasePtr Broaden() const override { return Clone(); }
|
||||
};
|
||||
using AbstractScriptPtr = std::shared_ptr<AbstractScript>;
|
||||
|
||||
class Evaluator;
|
||||
using EvaluatorPtr = std::shared_ptr<Evaluator>;
|
||||
class AnalysisEngine;
|
||||
|
|
|
@ -604,6 +604,9 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_
|
|||
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
|
||||
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
|
||||
|
||||
// Python interpreter runner
|
||||
inline const PrimitivePtr kPrimPyInterpret = std::make_shared<Primitive>("PyInterpret");
|
||||
|
||||
// Other primitive not used by backend but used in core;
|
||||
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
|
||||
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J", kSideEffectPropagate);
|
||||
|
|
|
@ -107,7 +107,9 @@ class MS_CORE_API AnfNode : public Base {
|
|||
hash_(std::hash<const AnfNode *>()),
|
||||
kernel_info_(nullptr),
|
||||
stage_(-1),
|
||||
need_grad_(false) {
|
||||
need_grad_(false),
|
||||
interpret_(false),
|
||||
interpreted_node_(nullptr) {
|
||||
scope_ = ScopeManager::GetInstance().GetCurrentScope();
|
||||
}
|
||||
|
||||
|
@ -204,6 +206,12 @@ class MS_CORE_API AnfNode : public Base {
|
|||
bool grad() { return need_grad_; }
|
||||
void set_grad(const bool &need_grad) { need_grad_ = need_grad; }
|
||||
|
||||
bool interpret() { return interpret_; }
|
||||
void set_interpret(const bool interpret) { interpret_ = interpret; }
|
||||
|
||||
AnfNodePtr interpreted_node() { return interpreted_node_; }
|
||||
void set_interpreted_node(const AnfNodePtr &node) { interpreted_node_ = node; }
|
||||
|
||||
protected:
|
||||
// Hold a weak ref to Graph as Graph also hold ref to AnfNode.
|
||||
// Otherwise, func_graph_ and AnfNode will make a reference cycle.
|
||||
|
@ -220,6 +228,8 @@ class MS_CORE_API AnfNode : public Base {
|
|||
UserData user_data_;
|
||||
int64_t stage_;
|
||||
bool need_grad_;
|
||||
bool interpret_;
|
||||
AnfNodePtr interpreted_node_;
|
||||
};
|
||||
|
||||
// CNode represents the complex node with a set of arguments.
|
||||
|
|
|
@ -291,7 +291,7 @@ TEST_F(TestStepParallel, CreatOpInstance) {
|
|||
|
||||
std::vector<py::object> arglist;
|
||||
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arglist),
|
||||
[](Attr attr) { return ValuePtrToPyData(attr.second); });
|
||||
[](Attr attr) { return ValueToPyData(attr.second); });
|
||||
py::object allreduce_pyobj = parse::python_adapter::CallPyFn(
|
||||
"mindspore.parallel._utils", "_get_python_op", "AllReduce", "mindspore.ops.operations", "test", arglist);
|
||||
py::dict opAttr = py::getattr(allreduce_pyobj, "attrs");
|
||||
|
|
|
@ -65,7 +65,7 @@ TEST_F(TestParser, TestParseApi) {
|
|||
TEST_F(TestParser, TestParseAst) {
|
||||
GetPythonFunction("test_f");
|
||||
|
||||
ParseAst ast = ParseAst(fn);
|
||||
ParseFunctionAst ast = ParseFunctionAst(fn);
|
||||
bool succ = ast.InitParseAstInfo();
|
||||
ASSERT_TRUE(succ = true);
|
||||
|
||||
|
|
|
@ -47,7 +47,8 @@ def test_use_numpy_method():
|
|||
net = Net()
|
||||
with pytest.raises(NotImplementedError) as err:
|
||||
net()
|
||||
assert "MindSpore does not support to use the numpy methods in the function construct with the graph mode." \
|
||||
assert "Mindspore does not support to use the numpy methods " \
|
||||
"within the construct() or @ms_function decorated function in graph mode." \
|
||||
in str(err.value)
|
||||
|
||||
|
||||
|
@ -63,5 +64,6 @@ def test_use_numpy_module():
|
|||
net = Net()
|
||||
with pytest.raises(NotImplementedError) as err:
|
||||
net()
|
||||
assert "MindSpore does not support to use the numpy methods in the function construct with the graph mode." \
|
||||
assert "Mindspore does not support to use the numpy methods " \
|
||||
"within the construct() or @ms_function decorated function in graph mode." \
|
||||
in str(err.value)
|
||||
|
|
Loading…
Reference in New Issue