!23039 [Fallback] Fix previous block parameters problem and add test cases.

Merge pull request !23039 from 张清华/opt_fallback
This commit is contained in:
i-robot 2021-09-11 07:39:02 +00:00 committed by Gitee
commit 0071667155
11 changed files with 217 additions and 88 deletions

View File

@ -76,6 +76,7 @@ parse_expr_statement_white_list = (
"append", "append",
) )
_builtin_function_or_method_type = type(abs)
def create_slice_obj(start, end, step): def create_slice_obj(start, end, step):
"""Create slice object""" """Create slice object"""
@ -248,6 +249,7 @@ def get_obj_id(obj):
def get_obj_type(obj): def get_obj_type(obj):
"""Get the obj type.""" """Get the obj type."""
logger.debug("Get object type: %r", obj)
obj_type = RESOLVE_TYPE_INVALID obj_type = RESOLVE_TYPE_INVALID
if obj is None: if obj is None:
obj_type = RESOLVE_TYPE_NONE obj_type = RESOLVE_TYPE_NONE
@ -529,9 +531,9 @@ class Parser:
# Used to resolve mindspore builtin ops namespace. # Used to resolve mindspore builtin ops namespace.
self.ms_common_ns = CellNamespace('mindspore.common') self.ms_common_ns = CellNamespace('mindspore.common')
self.ms_ops_ns = CellNamespace('mindspore.ops') self.ms_ops_ns = CellNamespace('mindspore.ops')
self.ms_ops_c = CellNamespace('mindspore.ops.composite') self.ms_ops_c_ns = CellNamespace('mindspore.ops.composite')
self.ms_ops_c_multitype = CellNamespace('mindspore.ops.composite.multitype_ops') self.ms_ops_c_multitype_ns = CellNamespace('mindspore.ops.composite.multitype_ops')
self.ms_ops_p = CellNamespace('mindspore.ops.operations') self.ms_ops_p_ns = CellNamespace('mindspore.ops.operations')
# Used to resolve the function's globals namespace. # Used to resolve the function's globals namespace.
self.global_namespace = CellNamespace(fn.__module__) self.global_namespace = CellNamespace(fn.__module__)
self.function_module = fn.__module__ self.function_module = fn.__module__
@ -567,6 +569,11 @@ class Parser:
logger.error("Fn type is invalid") logger.error("Fn type is invalid")
return None, None return None, None
def is_unsupported_namespace(self, value):
unsupported = isinstance(value, _builtin_function_or_method_type) and value not in convert_object_map
logger.debug(f'`{value}` unsupported: {unsupported}.')
return unsupported
def get_namespace_symbol(self, var: str): def get_namespace_symbol(self, var: str):
"""Get symbol type and namespace and symbol.""" """Get symbol type and namespace and symbol."""
if var in self.closure_namespace: if var in self.closure_namespace:
@ -575,7 +582,7 @@ class Parser:
if var in self.global_namespace: if var in self.global_namespace:
logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}") logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}")
value = self.global_namespace[var] value = self.global_namespace[var]
if isinstance(value, type(abs)) and self.global_namespace[var] not in convert_object_map: if self.is_unsupported_namespace(value):
error_info = f"The builtin function '{var}' is not supported in graph mode." error_info = f"The builtin function '{var}' is not supported in graph mode."
return None, var, error_info return None, var, error_info
return self.global_namespace, var return self.global_namespace, var
@ -604,6 +611,11 @@ class Parser:
logger.debug(f'Found `{name}` in mindspore root namespace.') logger.debug(f'Found `{name}` in mindspore root namespace.')
return True return True
# Check `Tensor` namespace.
if value == Tensor:
logger.debug(f'Not support `{name}`.')
return False
# Check `builtins` namespace. # Check `builtins` namespace.
if hasattr(value, '__module__'): # Not types.ModuleType if hasattr(value, '__module__'): # Not types.ModuleType
mod = value.__module__ mod = value.__module__
@ -613,25 +625,29 @@ class Parser:
# We suppose it's supported if not a Module. # We suppose it's supported if not a Module.
if not isinstance(value, types.ModuleType): if not isinstance(value, types.ModuleType):
logger.debug(f'Found `{name}`, not a module.')
return True return True
# Check supported Module namespace. # Check supported Module namespace.
rightmost_name = name.split('.')[-1] rightmost_name = name.split('.')[-1]
# By now, we don't check `self.ms_common_ns`.
if rightmost_name in self.ms_ops_ns: if rightmost_name in self.ms_ops_ns:
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_ns.__str__()}.') logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {str(self.ms_ops_ns)}.')
return True return True
if rightmost_name in self.ms_ops_c: if rightmost_name in self.ms_ops_c_ns:
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_c.__str__()}.') logger.debug(f'Found `{name}`({rightmost_name}) in C namespace: {str(self.ms_ops_c_ns)}.')
return True return True
if rightmost_name in self.ms_ops_c_multitype: if rightmost_name in self.ms_ops_c_multitype_ns:
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_c_multitype.__str__()}.') logger.debug(
f'Found `{name}`({rightmost_name}) in C.multitype namespace: {str(self.ms_ops_c_multitype_ns)}.')
return True return True
if rightmost_name in self.ms_ops_p: if rightmost_name in self.ms_ops_p_ns:
logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_p.__str__()}.') logger.debug(f'Found `{name}`({rightmost_name}) in P namespace: {str(self.ms_ops_p_ns)}.')
return True
if rightmost_name in self.ms_common_ns:
logger.debug(f'Found `{name}`({rightmost_name}) in P namespace: {str(self.ms_common_ns)}.')
return True return True
if rightmost_name in trope_ns: if rightmost_name in trope_ns:
logger.debug(f'Found `{name}`({rightmost_name}) in trope namespace: {self.trope_ns.__str__()}.') logger.debug(f'Found `{name}`({rightmost_name}) in trope namespace: {str(trope_ns)}.')
return True return True
logger.error(f'Not found `{name}` in mindspore supported namespace.') logger.error(f'Not found `{name}` in mindspore supported namespace.')
@ -648,14 +664,16 @@ class Parser:
value_str = value.__name__ if hasattr(value, '__name__') else str(value) value_str = value.__name__ if hasattr(value, '__name__') else str(value)
logger.debug(f"value: {type(value)}, `{value_str}`, hasattr(__name__): {hasattr(value, '__name__')}.") logger.debug(f"value: {type(value)}, `{value_str}`, hasattr(__name__): {hasattr(value, '__name__')}.")
# To check if allowed to support. # To check if allowed to support.
if self.is_unsupported_namespace(value):
return self.global_namespace, var, value
if self.is_unsupported_builtin_type(value): if self.is_unsupported_builtin_type(value):
return self.global_namespace, var, value return self.global_namespace, var, value
if not self.is_supported_namespace_module(value): # Check if support including instance of types.ModuleType if not self.is_supported_namespace_module(value): # Check if support including instance of types.ModuleType
return self.global_namespace, var, value return self.global_namespace, var, value
return self.global_namespace, var return self.global_namespace, var
error_info = f"The symbol '{var}' is not supported in graph mode." error_info = f"The name '{var}' is not defined, or not supported in graph mode."
logger.debug(error_info) logger.debug(f'error info: {error_info}')
return None, var, error_info return None, var, error_info
def analyze_super(self, class_type_node, subclass_instance): def analyze_super(self, class_type_node, subclass_instance):

View File

@ -284,7 +284,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceB
} }
} }
if (!fn || py::isinstance<py::none>(fn)) { if (!fn || py::isinstance<py::none>(fn)) {
MS_LOG(ERROR) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn); MS_LOG(WARNING) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn);
return nullptr; return nullptr;
} }
func_graph = parse::ParsePythonCode(fn); func_graph = parse::ParsePythonCode(fn);

View File

@ -579,10 +579,18 @@ std::vector<std::string> GetObjKey(const py::object &obj) {
// Get obj detail type // Get obj detail type
ResolveTypeDef GetObjType(const py::object &obj) { ResolveTypeDef GetObjType(const py::object &obj) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); try {
auto obj_type = py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>()); auto obj_type =
return obj_type; ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
return obj_type;
} catch (const py::error_already_set &ex) {
MS_LOG(ERROR) << "Meet a exception from Python when get the type of `" << py::str(obj) << "`.\n" << ex.what();
std::rethrow_exception(std::current_exception());
} catch (const py::type_error &ex) {
MS_LOG(ERROR) << "Meet a exception when get the type of `" << py::str(obj) << "`.\n" << ex.what();
std::rethrow_exception(std::current_exception());
}
} }
// Get class instance detail type. // Get class instance detail type.

View File

@ -70,7 +70,7 @@ static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &nod
// Write variable records the variable name to corresponding node // Write variable records the variable name to corresponding node
void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) { void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var " << var_name << " with node " MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var `" << var_name << "` with node "
<< node->DebugString(); << node->DebugString();
auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false)); auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
if (!is_new_name) { if (!is_new_name) {
@ -97,7 +97,7 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
// Read variable from predecessors // Read variable from predecessors
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) { AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
MS_LOG(DEBUG) << "Read begin, var: " << var << ", block id: " << func_graph_->debug_info()->debug_id(); MS_LOG(DEBUG) << "Read begin, var: " << var << ", block: " << ToString();
// Get var node if it is found // Get var node if it is found
auto found = assigned_vars_.find(var); auto found = assigned_vars_.find(var);
if (found != assigned_vars_.end()) { if (found != assigned_vars_.end()) {
@ -117,7 +117,12 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
if (prev_blocks_.size() == 1) { if (prev_blocks_.size() == 1) {
auto block = prev_blocks_[0]; auto block = prev_blocks_[0];
MS_EXCEPTION_IF_NULL(block); MS_EXCEPTION_IF_NULL(block);
return block->ReadVariable(var); auto res = block->ReadVariable(var);
MS_LOG(INFO) << "Update global params of block: " << ToString() << ", with previous block: " << block->ToString()
<< ",\nCurrent: " << py::str(global_py_params())
<< "\nInsert: " << py::str(block->global_py_params());
CopyGlobalPyParam(block->global_py_params());
return res;
} else if (prev_blocks_.empty()) { } else if (prev_blocks_.empty()) {
// Get namespace and make Resolve // Get namespace and make Resolve
auto it = var_to_resolve_.find(var); auto it = var_to_resolve_.find(var);
@ -190,13 +195,14 @@ AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &namespace_info) {
} }
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]); NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>()); SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString() MS_LOG(DEBUG) << "[" << func_graph()->ToString() << "] name_space: " << name_space->ToString()
<< ", unsupported: " << unsupported; << ", symbol: " << symbol->ToString() << ", unsupported: " << unsupported;
auto resolved_node = MakeResolve(name_space, symbol); auto resolved_node = MakeResolve(name_space, symbol);
if (unsupported) { if (unsupported) {
resolved_node->set_interpret(true); resolved_node->set_interpret(true);
AddGlobalPyParam(symbol->name(), py_obj); AddGlobalPyParam(symbol->name(), py_obj);
MS_LOG(INFO) << "Added global python symblol: {" << symbol->name() << " : " << py::str(py_obj) << "}"; MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symblol: {" << symbol->name() << " : "
<< py::str(py_obj) << "}";
} }
return resolved_node; return resolved_node;
} }
@ -218,7 +224,7 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
// The fallback feature is enabled in default. // The fallback feature is enabled in default.
// Not support change the flag during the process is alive. // Not support change the flag during the process is alive.
static const auto use_fallback = (parser_.support_fallback() != "1" ? false : true); static const auto use_fallback = (parser_.support_fallback() == "1");
if (!use_fallback) { if (!use_fallback) {
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value); py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
return HandleNamespaceInfo(namespace_info); return HandleNamespaceInfo(namespace_info);
@ -268,12 +274,12 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
TraceGuard trace_guard(std::make_shared<TraceResolve>(phi->debug_info())); TraceGuard trace_guard(std::make_shared<TraceResolve>(phi->debug_info()));
std::string var = phi_nodes_[phi]; std::string var = phi_nodes_[phi];
MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " set phi " << phi->ToString() MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " set phi " << phi->ToString()
<< " for var " << var; << " for var `" << var << "`";
auto removable = CollectRemovablePhi(phi); auto removable = CollectRemovablePhi(phi);
// If the phi node is not necessary, not need to add to jumps_ of the prev blocks. // If the phi node is not necessary, not need to add to jumps_ of the prev blocks.
if (removable) { if (removable) {
MS_LOG(DEBUG) << "remove the phi when call graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") MS_LOG(DEBUG) << "remove the phi when call graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
<< " var " << var; << " var `" << var << "`";
return; return;
} }
for (auto &pred : prev_blocks_) { for (auto &pred : prev_blocks_) {
@ -402,12 +408,10 @@ CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
// Perform a jump from this block to target block // Perform a jump from this block to target block
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) { void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) {
MS_LOG(DEBUG) << "Jump from " << func_graph_->debug_info()->debug_id() << " to " MS_LOG(DEBUG) << "Jump from block: " << ToString() << " to block: " << target_block->ToString();
<< target_block->func_graph()->debug_info()->debug_id();
MS_EXCEPTION_IF_NULL(target_block); MS_EXCEPTION_IF_NULL(target_block);
if (is_dead_block_) { if (is_dead_block_) {
MS_LOG(DEBUG) << "Dead code block should not jump to other block! Block id:" MS_LOG(DEBUG) << "Dead code block should not jump to other block! block: " << ToString();
<< func_graph_->debug_info()->debug_id();
return; return;
} }
if (func_graph_->get_return() != nullptr) { if (func_graph_->get_return() != nullptr) {

View File

@ -49,6 +49,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
virtual ~FunctionBlock() = default; virtual ~FunctionBlock() = default;
FuncGraphPtr func_graph() { return func_graph_; } FuncGraphPtr func_graph() { return func_graph_; }
std::string ToString() const { return func_graph_->ToString(); }
void WriteVariable(const std::string &var_name, const AnfNodePtr &node); void WriteVariable(const std::string &var_name, const AnfNodePtr &node);
AnfNodePtr ReadVariable(const std::string &var_name); AnfNodePtr ReadVariable(const std::string &var_name);
void AddPrevBlock(const FunctionBlockPtr &block); void AddPrevBlock(const FunctionBlockPtr &block);
@ -85,6 +86,13 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
py::dict &global_py_params() { return global_py_params_; } py::dict &global_py_params() { return global_py_params_; }
void set_global_py_params(const py::dict &symbols) { global_py_params_ = symbols; } void set_global_py_params(const py::dict &symbols) { global_py_params_ = symbols; }
void AddGlobalPyParam(const std::string &name, const py::object &obj) { global_py_params_[py::str(name)] = obj; } void AddGlobalPyParam(const std::string &name, const py::object &obj) { global_py_params_[py::str(name)] = obj; }
void CopyGlobalPyParam(const py::dict &symbols) {
for (auto &param : symbols) {
if (!global_py_params_.contains(param.first)) {
global_py_params_[param.first] = param.second;
}
}
}
std::tuple<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> local_py_params() { std::tuple<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> local_py_params() {
return {local_py_params_keys_, local_py_params_values_}; return {local_py_params_keys_, local_py_params_values_};

View File

@ -167,7 +167,7 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunction
py::str desc = py::str desc =
python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]); python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ". " MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ". "
<< "Func graph id: " << func_graph->debug_info()->debug_id(); << "FuncGraph: " << func_graph->ToString();
} }
} }
@ -733,7 +733,9 @@ bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
} else { } else {
auto kw_key_c = kw_key.cast<std::string>(); auto kw_key_c = kw_key.cast<std::string>();
keys.push_back(NewValueNode(kw_key_c)); keys.push_back(NewValueNode(kw_key_c));
values.push_back(ParseExprNode(block, kw_value)); auto node = ParseExprNode(block, kw_value);
node = HandleInterpret(block, node, kw_value);
values.push_back(node);
} }
} }
auto keys_tuple = GenerateMakeTuple(block, keys); auto keys_tuple = GenerateMakeTuple(block, keys);
@ -1070,20 +1072,21 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py:
MS_LOG(DEBUG) << "Process ast AugAssign"; MS_LOG(DEBUG) << "Process ast AugAssign";
MS_EXCEPTION_IF_NULL(block); MS_EXCEPTION_IF_NULL(block);
MS_EXCEPTION_IF_NULL(ast_); MS_EXCEPTION_IF_NULL(ast_);
py::object target_obj = python_adapter::GetPyObjAttr(node, "target");
py::object op_obj = python_adapter::GetPyObjAttr(node, "op"); py::object target_object = python_adapter::GetPyObjAttr(node, "target");
py::object value_obj = python_adapter::GetPyObjAttr(node, "value"); py::object op_object = python_adapter::GetPyObjAttr(node, "op");
py::object value_object = python_adapter::GetPyObjAttr(node, "value");
AnfNodePtr target_node = nullptr; AnfNodePtr target_node = nullptr;
AnfNodePtr op_node = block->MakeResolveAstOp(op_obj); AnfNodePtr op_node = block->MakeResolveAstOp(op_object);
AnfNodePtr value_node = ParseExprNode(block, value_obj); AnfNodePtr value_node = ParseExprNode(block, value_object);
auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_obj))); auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object)));
if (ast_type == AST_SUB_TYPE_NAME) { if (ast_type == AST_SUB_TYPE_NAME) {
target_node = ParseName(block, target_obj); target_node = ParseName(block, target_object);
} else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) { } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
target_node = ParseSubscript(block, target_obj); target_node = ParseSubscript(block, target_object);
} else if (ast_->IsClassMember(target_obj)) { } else if (ast_->IsClassMember(target_object)) {
target_node = ParseAttribute(block, target_obj); target_node = ParseAttribute(block, target_object);
} else { } else {
MS_LOG(EXCEPTION) << "Not supported augassign"; MS_LOG(EXCEPTION) << "Not supported augassign";
} }
@ -1091,8 +1094,7 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py:
MS_LOG(EXCEPTION) << "Can not get target node "; MS_LOG(EXCEPTION) << "Can not get target node ";
} }
CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node}); CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
WriteAssignVars(block, target_obj, augassign_app); WriteAssignVars(block, target_object, augassign_app);
return block; return block;
} }
// Process global declaration such as 'global x'; // Process global declaration such as 'global x';
@ -1668,11 +1670,11 @@ AnfNodePtr Parser::ParseListComp(const FunctionBlockPtr &block, const py::object
return output; return output;
} }
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_obj, void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_object,
const AnfNodePtr &assigned_node) { const AnfNodePtr &assigned_node) {
MS_EXCEPTION_IF_NULL(block); MS_EXCEPTION_IF_NULL(block);
MS_EXCEPTION_IF_NULL(assigned_node); MS_EXCEPTION_IF_NULL(assigned_node);
py::str name = python_adapter::GetPyObjAttr(target_obj, "id"); py::str name = python_adapter::GetPyObjAttr(target_object, "id");
std::string name_id = name; std::string name_id = name;
assigned_node->debug_info()->set_name(name_id); assigned_node->debug_info()->set_name(name_id);
// Set the debug name of the constant graph // Set the debug name of the constant graph
@ -1683,16 +1685,16 @@ void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &t
fg->debug_info()->set_name(name_id); fg->debug_info()->set_name(name_id);
} }
} }
MS_LOG(DEBUG) << "Assign name: " << name_id << " to node: " << assigned_node; MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
block->AddLocalPyParam(name_id, assigned_node); block->AddLocalPyParam(name_id, assigned_node);
block->WriteVariable(name_id, assigned_node); block->WriteVariable(name_id, assigned_node);
} }
void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_obj, void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_object,
const AnfNodePtr &assigned_node) { const AnfNodePtr &assigned_node) {
MS_EXCEPTION_IF_NULL(block); MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
py::list items = python_adapter::GetPyObjAttr(target_obj, "elts"); py::list items = python_adapter::GetPyObjAttr(target_object, "elts");
for (size_t i = 0; i < items.size(); i++) { for (size_t i = 0; i < items.size(); i++) {
// Use the Primitive replace the operation resolve node (getitem), // Use the Primitive replace the operation resolve node (getitem),
// because the getitem will eventually be converted to Primitive node // because the getitem will eventually be converted to Primitive node
@ -1704,13 +1706,13 @@ void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &
} }
} }
void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_obj, void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_object,
const AnfNodePtr &assigned_node) { const AnfNodePtr &assigned_node) {
// Now only support the self.xx = xxxxx, can't support x.y = xxxx // Now only support the self.xx = xxxxx, can't support x.y = xxxx
AnfNodePtr target_node = ParseExprNode(block, target_obj); AnfNodePtr target_node = ParseExprNode(block, target_object);
MS_EXCEPTION_IF_NULL(target_node); MS_EXCEPTION_IF_NULL(target_node);
auto attr_name = target_obj.attr("attr").cast<std::string>(); auto attr_name = target_object.attr("attr").cast<std::string>();
std::string var_name = "self." + attr_name; std::string var_name = "self." + attr_name;
// Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
@ -1733,12 +1735,12 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob
block->SetStateAssign(target_node, assigned_node); block->SetStateAssign(target_node, assigned_node);
} }
void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_obj, void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_object,
const AnfNodePtr &assigned_node) { const AnfNodePtr &assigned_node) {
MS_EXCEPTION_IF_NULL(block); MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM); AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
py::object value_obj = python_adapter::GetPyObjAttr(target_obj, "value"); py::object value_obj = python_adapter::GetPyObjAttr(target_object, "value");
py::object slice_obj = python_adapter::GetPyObjAttr(target_obj, "slice"); py::object slice_obj = python_adapter::GetPyObjAttr(target_object, "slice");
AnfNodePtr value_node = ParseExprNode(block, value_obj); AnfNodePtr value_node = ParseExprNode(block, value_obj);
AnfNodePtr slice_node = ParseExprNode(block, slice_obj); AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node}); CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node});
@ -1776,19 +1778,19 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje
block->WriteVariable(var_name, setitem_app); block->WriteVariable(var_name, setitem_app);
} }
void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_obj, void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_object,
const AnfNodePtr &value_node) { const AnfNodePtr &value_node) {
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value_node);
MS_LOG(DEBUG) << "Process WriteAssignVars"; MS_LOG(DEBUG) << "Process WriteAssignVars";
auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_obj))); auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object)));
if (ast_type == AST_SUB_TYPE_NAME) { if (ast_type == AST_SUB_TYPE_NAME) {
HandleAssignName(block, target_obj, value_node); HandleAssignName(block, target_object, value_node);
} else if (ast_type == AST_SUB_TYPE_TUPLE) { } else if (ast_type == AST_SUB_TYPE_TUPLE) {
HandleAssignTuple(block, target_obj, value_node); HandleAssignTuple(block, target_object, value_node);
} else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) { } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
HandleAssignSubscript(block, target_obj, value_node); HandleAssignSubscript(block, target_object, value_node);
} else if (ast_->IsClassMember(target_obj)) { } else if (ast_->IsClassMember(target_object)) {
HandleAssignClassMember(block, target_obj, value_node); HandleAssignClassMember(block, target_object, value_node);
} else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) { } else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
MS_LOG(EXCEPTION) << "The subnet attributes cannot be changed in the network. \n\n" MS_LOG(EXCEPTION) << "The subnet attributes cannot be changed in the network. \n\n"
<< trace::GetDebugInfo(value_node->debug_info()); << trace::GetDebugInfo(value_node->debug_info());
@ -1802,7 +1804,7 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
const py::object &value_object) { const py::object &value_object) {
// The fallback feature is enabled in default. // The fallback feature is enabled in default.
// Not support change the flag during the process is alive. // Not support change the flag during the process is alive.
static const auto use_fallback = (support_fallback() != "1" ? false : true); static const auto use_fallback = (support_fallback() == "1");
if (!use_fallback) { if (!use_fallback) {
return value_node; return value_node;
} }
@ -1810,9 +1812,10 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
AnfNodePtr interpreted_node = value_node; AnfNodePtr interpreted_node = value_node;
if (value_node->interpret()) { if (value_node->interpret()) {
const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object)); const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
MS_LOG(INFO) << "script_text: " << script_text << ", value_node: " << value_node->DebugString(2);
// Prepare global parameters.
py::dict global_dict = block->global_py_params(); py::dict global_dict = block->global_py_params();
MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: " << script_text
<< ", value_node: " << value_node->DebugString(3) << ", global_dict: " << py::str(global_dict);
// Prepare global parameters.
ValuePtr globals_converted_value = nullptr; ValuePtr globals_converted_value = nullptr;
if (!ConvertData(global_dict, &globals_converted_value)) { if (!ConvertData(global_dict, &globals_converted_value)) {
MS_LOG(EXCEPTION) << "Convert data failed"; MS_LOG(EXCEPTION) << "Convert data failed";

View File

@ -60,7 +60,7 @@ abstract::AbstractBasePtr ClassType::ToAbstract() {
// The fallback feature is enabled in default. // The fallback feature is enabled in default.
// Not support change the flag during the process is alive. // Not support change the flag during the process is alive.
static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK"); static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
static const auto use_fallback = (support_fallback != "1" ? false : true); static const auto use_fallback = (support_fallback == "1");
if (use_fallback && !IsSupportedCreateInstanceType(obj())) { if (use_fallback && !IsSupportedCreateInstanceType(obj())) {
return abs_scalar; return abs_scalar;
} }

View File

@ -112,6 +112,16 @@ class PyObjectWrapper : public Named {
py::object obj_; py::object obj_;
}; };
// InterpretedObject class wrappers interpreted python object.
class InterpretedObject : public PyObjectWrapper {
public:
explicit InterpretedObject(const py::object &obj, const std::string &name = "Interpreted object")
: PyObjectWrapper(obj, name) {}
~InterpretedObject() override = default;
MS_DECLARE_PARENT(InterpretedObject, PyObjectWrapper);
abstract::AbstractBasePtr ToAbstract() override;
};
// ClassObject class wrappers dataclass // ClassObject class wrappers dataclass
class ClassObject : public PyObjectWrapper { class ClassObject : public PyObjectWrapper {
public: public:

View File

@ -207,7 +207,7 @@ class MS_CORE_API AnfNode : public Base {
void set_grad(const bool &need_grad) { need_grad_ = need_grad; } void set_grad(const bool &need_grad) { need_grad_ = need_grad; }
bool interpret() { return interpret_; } bool interpret() { return interpret_; }
void set_interpret(const bool interpret) { interpret_ = interpret; } void set_interpret(const bool &interpret) { interpret_ = interpret; }
AnfNodePtr interpreted_node() { return interpreted_node_; } AnfNodePtr interpreted_node() { return interpreted_node_; }
void set_interpreted_node(const AnfNodePtr &node) { interpreted_node_ = node; } void set_interpreted_node(const AnfNodePtr &node) { interpreted_node_ = node; }

View File

@ -176,17 +176,23 @@ class MS_CORE_API ValueDictionary : public Value {
keys.push_back(kv.first); keys.push_back(kv.first);
values.push_back(kv.second); values.push_back(kv.second);
} }
buffer << "(Dict: " buffer << "dict: {keys: (";
<< " keys:("; for (size_t i = 0; i < keys.size(); i++) {
for (const auto &key : keys) { buffer << keys[i];
buffer << key << ", "; if (i != keys.size() - 1) {
buffer << ", ";
}
} }
buffer << ") values:("; buffer << "), values: (";
for (const auto &value : values) { for (size_t i = 0; i < values.size(); i++) {
const auto &value = values[i];
MS_EXCEPTION_IF_NULL(value); MS_EXCEPTION_IF_NULL(value);
buffer << value->ToString() << ", "; buffer << value->ToString();
if (i != values.size() - 1) {
buffer << ", ";
}
} }
buffer << ")"; buffer << ")}";
return buffer.str(); return buffer.str();
} }
abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToAbstract() override;

View File

@ -18,8 +18,10 @@ import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, ms_function, context from mindspore import Tensor, ms_function, context
from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
import mindspore.common.dtype as mstype
import mindspore.common._monad as monad
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
@ -38,16 +40,15 @@ def test_increment():
@ms_function @ms_function
def np_fallback_func(): def use_monad(x, y):
array_x = [2, 3, 4, 5] res = P.Mul()(x, y)
np_x = np.array(array_x).astype(np.float32) res = F.depend(res, monad.U)
me_x = Tensor(np_x) return res
me_x = me_x + me_x
return me_x
@pytest.mark.skip(reason='Graph fallback feature is not supported yet') def test_use_monad():
def test_np_fallback_func(): x = Tensor(1.0, mstype.float32)
print(np_fallback_func()) y = Tensor(1.0, mstype.float32)
print(use_monad(x, y))
class Net(nn.Cell): class Net(nn.Cell):
@ -64,3 +65,74 @@ class Net(nn.Cell):
def test_builtins_len(): def test_builtins_len():
net = Net() net = Net()
net() net()
@ms_function
def np_fallback_func():
array_x = tuple([2, 3, 4, 5])
np_x = np.array(array_x).astype(np.float32)
me_x = Tensor(np_x)
me_x = me_x + me_x
return me_x
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_fallback_func():
print(np_fallback_func())
@ms_function
def div_mod_func(x, y):
a = divmod(x, y)
return Tensor(a)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_div_mod_func():
print(div_mod_func(8, 3)) # (2, 2)
# NameError: name 'Tensor' is not defined.
@ms_function
def select_func(cond, x, y):
if isinstance(cond, (tuple, list)):
output = y
elif isinstance(cond, Tensor):
output = F.select(cond, x, y)
else:
output = x
return output
def test_select_func():
cond = Tensor([True, False])
x = Tensor([2, 3], mstype.float32)
y = Tensor([1, 2], mstype.float32)
print(select_func(cond, x, y))
# Not interpret 'Tensor'.
@ms_function
def select_func2(cond, x, y):
if isinstance(cond, (tuple, list)):
output = y
if isinstance(cond, Tensor):
output = F.select(cond, x, y)
else:
output = x
return output
def test_select_func2():
cond = Tensor([True, False])
x = Tensor([2, 3], mstype.float32)
y = Tensor([1, 2], mstype.float32)
print(select_func2(cond, x, y))
# NameError: name 'Tensor' is not defined.
@ms_function
def slice_func(a, b):
a[1:3, ::] = b
return a
def test_slice_func():
a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
b = Tensor([1], dtype=mstype.float32)
print(slice_func(a, b))