diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py
index 654f3317417..74cd3168f2f 100644
--- a/mindspore/_extends/parse/parser.py
+++ b/mindspore/_extends/parse/parser.py
@@ -76,6 +76,7 @@ parse_expr_statement_white_list = (
     "append",
 )
 
+_builtin_function_or_method_type = type(abs)
 
 def create_slice_obj(start, end, step):
     """Create slice object"""
@@ -248,6 +249,7 @@ def get_obj_id(obj):
 
 def get_obj_type(obj):
     """Get the obj type."""
+    logger.debug("Get object type: %r", obj)
     obj_type = RESOLVE_TYPE_INVALID
     if obj is None:
         obj_type = RESOLVE_TYPE_NONE
@@ -529,9 +531,9 @@ class Parser:
         # Used to resolve mindspore builtin ops namespace.
         self.ms_common_ns = CellNamespace('mindspore.common')
         self.ms_ops_ns = CellNamespace('mindspore.ops')
-        self.ms_ops_c = CellNamespace('mindspore.ops.composite')
-        self.ms_ops_c_multitype = CellNamespace('mindspore.ops.composite.multitype_ops')
-        self.ms_ops_p = CellNamespace('mindspore.ops.operations')
+        self.ms_ops_c_ns = CellNamespace('mindspore.ops.composite')
+        self.ms_ops_c_multitype_ns = CellNamespace('mindspore.ops.composite.multitype_ops')
+        self.ms_ops_p_ns = CellNamespace('mindspore.ops.operations')
         # Used to resolve the function's globals namespace.
         self.global_namespace = CellNamespace(fn.__module__)
         self.function_module = fn.__module__
@@ -567,6 +569,11 @@ class Parser:
             logger.error("Fn type is invalid")
         return None, None
 
+    def is_unsupported_namespace(self, value):
+        unsupported = isinstance(value, _builtin_function_or_method_type) and value not in convert_object_map
+        logger.debug(f'`{value}` unsupported: {unsupported}.')
+        return unsupported
+
     def get_namespace_symbol(self, var: str):
         """Get symbol type and namespace and symbol."""
         if var in self.closure_namespace:
@@ -575,7 +582,7 @@ class Parser:
         if var in self.global_namespace:
             logger.debug(f"Found `{var}` in global_namespace {self.global_namespace.__str__()}")
             value = self.global_namespace[var]
-            if isinstance(value, type(abs)) and self.global_namespace[var] not in convert_object_map:
+            if self.is_unsupported_namespace(value):
                 error_info = f"The builtin function '{var}' is not supported in graph mode."
                 return None, var, error_info
             return self.global_namespace, var
@@ -604,6 +611,11 @@ class Parser:
             logger.debug(f'Found `{name}` in mindspore root namespace.')
             return True
 
+        # Check `Tensor` namespace.
+        if value == Tensor:
+            logger.debug(f'Not support `{name}`.')
+            return False
+
         # Check `builtins` namespace.
         if hasattr(value, '__module__'):  # Not types.ModuleType
             mod = value.__module__
@@ -613,25 +625,29 @@ class Parser:
         # We suppose it's supported if not a Module.
         if not isinstance(value, types.ModuleType):
+            logger.debug(f'Found `{name}`, not a module.')
             return True
 
         # Check supported Module namespace.
         rightmost_name = name.split('.')[-1]
-        # By now, we don't check `self.ms_common_ns`.
         if rightmost_name in self.ms_ops_ns:
-            logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_ns.__str__()}.')
+            logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {str(self.ms_ops_ns)}.')
             return True
-        if rightmost_name in self.ms_ops_c:
-            logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_c.__str__()}.')
+        if rightmost_name in self.ms_ops_c_ns:
+            logger.debug(f'Found `{name}`({rightmost_name}) in C namespace: {str(self.ms_ops_c_ns)}.')
             return True
-        if rightmost_name in self.ms_ops_c_multitype:
-            logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_c_multitype.__str__()}.')
+        if rightmost_name in self.ms_ops_c_multitype_ns:
+            logger.debug(
+                f'Found `{name}`({rightmost_name}) in C.multitype namespace: {str(self.ms_ops_c_multitype_ns)}.')
             return True
-        if rightmost_name in self.ms_ops_p:
-            logger.debug(f'Found `{name}`({rightmost_name}) in ops namespace: {self.ms_ops_p.__str__()}.')
+        if rightmost_name in self.ms_ops_p_ns:
+            logger.debug(f'Found `{name}`({rightmost_name}) in P namespace: {str(self.ms_ops_p_ns)}.')
+            return True
+        if rightmost_name in self.ms_common_ns:
+            logger.debug(f'Found `{name}`({rightmost_name}) in common namespace: {str(self.ms_common_ns)}.')
             return True
         if rightmost_name in trope_ns:
-            logger.debug(f'Found `{name}`({rightmost_name}) in trope namespace: {self.trope_ns.__str__()}.')
+            logger.debug(f'Found `{name}`({rightmost_name}) in trope namespace: {str(trope_ns)}.')
             return True
 
         logger.error(f'Not found `{name}` in mindspore supported namespace.')
@@ -648,14 +664,16 @@ class Parser:
             value_str = value.__name__ if hasattr(value, '__name__') else str(value)
             logger.debug(f"value: {type(value)}, `{value_str}`, hasattr(__name__): {hasattr(value, '__name__')}.")
             # To check if allowed to support.
+            if self.is_unsupported_namespace(value):
+                return self.global_namespace, var, value
             if self.is_unsupported_builtin_type(value):
                 return self.global_namespace, var, value
             if not self.is_supported_namespace_module(value):  # Check if support including instance of types.ModuleType
                 return self.global_namespace, var, value
             return self.global_namespace, var
 
-        error_info = f"The symbol '{var}' is not supported in graph mode."
-        logger.debug(error_info)
+        error_info = f"The name '{var}' is not defined, or not supported in graph mode."
+        logger.debug(f'error info: {error_info}')
         return None, var, error_info
 
     def analyze_super(self, class_type_node, subclass_instance):
diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
index a9a4a1444a0..09d9b1f5c93 100644
--- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
+++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
@@ -284,7 +284,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceB
     }
   }
   if (!fn || py::isinstance<py::none>(fn)) {
-    MS_LOG(ERROR) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn);
+    MS_LOG(WARNING) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn);
     return nullptr;
   }
   func_graph = parse::ParsePythonCode(fn);
diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
index aac10b31038..ef1d7f147f7 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
@@ -579,10 +579,18 @@ std::vector<std::string> GetObjKey(const py::object &obj) {
 
 // Get obj detail type
 ResolveTypeDef GetObjType(const py::object &obj) {
-  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
-  auto obj_type =
-    ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
-  return obj_type;
+  try {
+    py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
+    auto obj_type =
+      ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
+    return obj_type;
+  } catch (const py::error_already_set &ex) {
+    MS_LOG(ERROR) << "Met an exception from Python when getting the type of `" << py::str(obj) << "`.\n" << ex.what();
+    std::rethrow_exception(std::current_exception());
+  } catch (const py::type_error &ex) {
+    MS_LOG(ERROR) << "Met an exception when getting the type of `" << py::str(obj) << "`.\n" << ex.what();
+    std::rethrow_exception(std::current_exception());
+  }
 }
 
 // Get class instance detail type.
diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc
index 4ad8bfb454a..760c6ccf35b 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc
@@ -70,7 +70,7 @@ static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &nod
 // Write variable records the variable name to corresponding node
 void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
-  MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var " << var_name << " with node "
+  MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var `" << var_name << "` with node "
                 << node->DebugString();
   auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
   if (!is_new_name) {
@@ -97,7 +97,7 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
 
 // Read variable from predecessors
 AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
-  MS_LOG(DEBUG) << "Read begin, var: " << var << ", block id: " << func_graph_->debug_info()->debug_id();
+  MS_LOG(DEBUG) << "Read begin, var: " << var << ", block: " << ToString();
   // Get var node if it is found
   auto found = assigned_vars_.find(var);
   if (found != assigned_vars_.end()) {
@@ -117,7 +117,12 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
   if (prev_blocks_.size() == 1) {
     auto block = prev_blocks_[0];
     MS_EXCEPTION_IF_NULL(block);
-    return block->ReadVariable(var);
+    auto res = block->ReadVariable(var);
+    MS_LOG(INFO) << "Update global params of block: " << ToString() << ", with previous block: " << block->ToString()
+                 << ",\nCurrent: " << py::str(global_py_params())
+                 << "\nInsert: " << py::str(block->global_py_params());
+    CopyGlobalPyParam(block->global_py_params());
+    return res;
   } else if (prev_blocks_.empty()) {
     // Get namespace and make Resolve
     auto it = var_to_resolve_.find(var);
@@ -190,13 +195,14 @@ AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &namespace_info) {
   }
   NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
   SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
-  MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString()
-                << ", unsupported: " << unsupported;
+  MS_LOG(DEBUG) << "[" << func_graph()->ToString() << "] name_space: " << name_space->ToString()
+                << ", symbol: " << symbol->ToString() << ", unsupported: " << unsupported;
   auto resolved_node = MakeResolve(name_space, symbol);
   if (unsupported) {
     resolved_node->set_interpret(true);
     AddGlobalPyParam(symbol->name(), py_obj);
-    MS_LOG(INFO) << "Added global python symblol: {" << symbol->name() << " : " << py::str(py_obj) << "}";
+    MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symbol: {" << symbol->name() << " : "
+                 << py::str(py_obj) << "}";
   }
   return resolved_node;
 }
@@ -218,7 +224,7 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
 
   // The fallback feature is enabled in default.
   // Not support change the flag during the process is alive.
-  static const auto use_fallback = (parser_.support_fallback() != "1" ? false : true);
+  static const auto use_fallback = (parser_.support_fallback() == "1");
   if (!use_fallback) {
     py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
     return HandleNamespaceInfo(namespace_info);
@@ -268,12 +274,12 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
   TraceGuard trace_guard(std::make_shared(phi->debug_info()));
   std::string var = phi_nodes_[phi];
   MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " set phi " << phi->ToString()
-                << " for var " << var;
+                << " for var `" << var << "`";
   auto removable = CollectRemovablePhi(phi);
   // If the phi node is not necessary, not need to add to jumps_ of the prev blocks.
   if (removable) {
     MS_LOG(DEBUG) << "remove the phi when call graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
-                  << " var " << var;
+                  << " var `" << var << "`";
     return;
   }
   for (auto &pred : prev_blocks_) {
@@ -402,12 +408,10 @@ CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
 
 // Perform a jump from this block to target block
 void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) {
-  MS_LOG(DEBUG) << "Jump from " << func_graph_->debug_info()->debug_id() << " to "
-                << target_block->func_graph()->debug_info()->debug_id();
+  MS_LOG(DEBUG) << "Jump from block: " << ToString() << " to block: " << target_block->ToString();
   MS_EXCEPTION_IF_NULL(target_block);
   if (is_dead_block_) {
-    MS_LOG(DEBUG) << "Dead code block should not jump to other block! Block id:"
-                  << func_graph_->debug_info()->debug_id();
+    MS_LOG(DEBUG) << "Dead code block should not jump to other block! block: " << ToString();
     return;
   }
   if (func_graph_->get_return() != nullptr) {
diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.h b/mindspore/ccsrc/pipeline/jit/parse/function_block.h
index 3945012ba87..59a26a41965 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/function_block.h
+++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.h
@@ -49,6 +49,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
   virtual ~FunctionBlock() = default;
 
   FuncGraphPtr func_graph() { return func_graph_; }
+  std::string ToString() const { return func_graph_->ToString(); }
   void WriteVariable(const std::string &var_name, const AnfNodePtr &node);
   AnfNodePtr ReadVariable(const std::string &var_name);
   void AddPrevBlock(const FunctionBlockPtr &block);
@@ -85,6 +86,13 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
   py::dict &global_py_params() { return global_py_params_; }
   void set_global_py_params(const py::dict &symbols) { global_py_params_ = symbols; }
   void AddGlobalPyParam(const std::string &name, const py::object &obj) { global_py_params_[py::str(name)] = obj; }
+  void CopyGlobalPyParam(const py::dict &symbols) {
+    for (auto &param : symbols) {
+      if (!global_py_params_.contains(param.first)) {
+        global_py_params_[param.first] = param.second;
+      }
+    }
+  }
 
   std::tuple<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> local_py_params() {
     return {local_py_params_keys_, local_py_params_values_};
diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
index 6b9105b8baa..cbc4540cb3c 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
@@ -167,7 +167,7 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr
       python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
     MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ". "
-                            << "Func graph id: " << func_graph->debug_info()->debug_id();
+                            << "FuncGraph: " << func_graph->ToString();
   }
 }
 
@@ -733,7 +733,9 @@ bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
     } else {
      auto kw_key_c = kw_key.cast<std::string>();
       keys.push_back(NewValueNode(kw_key_c));
-      values.push_back(ParseExprNode(block, kw_value));
+      auto node = ParseExprNode(block, kw_value);
+      node = HandleInterpret(block, node, kw_value);
+      values.push_back(node);
     }
   }
   auto keys_tuple = GenerateMakeTuple(block, keys);
@@ -1070,20 +1072,21 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py:
   MS_LOG(DEBUG) << "Process ast AugAssign";
   MS_EXCEPTION_IF_NULL(block);
   MS_EXCEPTION_IF_NULL(ast_);
-  py::object target_obj = python_adapter::GetPyObjAttr(node, "target");
-  py::object op_obj = python_adapter::GetPyObjAttr(node, "op");
-  py::object value_obj = python_adapter::GetPyObjAttr(node, "value");
+
+  py::object target_object = python_adapter::GetPyObjAttr(node, "target");
+  py::object op_object = python_adapter::GetPyObjAttr(node, "op");
+  py::object value_object = python_adapter::GetPyObjAttr(node, "value");
   AnfNodePtr target_node = nullptr;
-  AnfNodePtr op_node = block->MakeResolveAstOp(op_obj);
-  AnfNodePtr value_node = ParseExprNode(block, value_obj);
-  auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_obj)));
+  AnfNodePtr op_node = block->MakeResolveAstOp(op_object);
+  AnfNodePtr value_node = ParseExprNode(block, value_object);
+  auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object)));
   if (ast_type == AST_SUB_TYPE_NAME) {
-    target_node = ParseName(block, target_obj);
+    target_node = ParseName(block, target_object);
   } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
-    target_node = ParseSubscript(block, target_obj);
-  } else if (ast_->IsClassMember(target_obj)) {
-    target_node = ParseAttribute(block, target_obj);
+    target_node = ParseSubscript(block, target_object);
+  } else if (ast_->IsClassMember(target_object)) {
+    target_node = ParseAttribute(block, target_object);
   } else {
     MS_LOG(EXCEPTION) << "Not supported augassign";
   }
@@ -1091,8 +1094,7 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py:
     MS_LOG(EXCEPTION) << "Can not get target node ";
   }
   CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
-  WriteAssignVars(block, target_obj, augassign_app);
-
+  WriteAssignVars(block, target_object, augassign_app);
   return block;
 }
 // Process global declaration such as 'global x';
@@ -1668,11 +1670,11 @@ AnfNodePtr Parser::ParseListComp(const FunctionBlockPtr &block, const py::object
   return output;
 }
 
-void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_obj,
+void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_object,
                               const AnfNodePtr &assigned_node) {
   MS_EXCEPTION_IF_NULL(block);
   MS_EXCEPTION_IF_NULL(assigned_node);
-  py::str name = python_adapter::GetPyObjAttr(target_obj, "id");
+  py::str name = python_adapter::GetPyObjAttr(target_object, "id");
   std::string name_id = name;
   assigned_node->debug_info()->set_name(name_id);
   // Set the debug name of the constant graph
@@ -1683,16 +1685,16 @@ void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &t
       fg->debug_info()->set_name(name_id);
     }
   }
-  MS_LOG(DEBUG) << "Assign name: " << name_id << " to node: " << assigned_node;
+  MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
   block->AddLocalPyParam(name_id, assigned_node);
   block->WriteVariable(name_id, assigned_node);
 }
 
-void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_obj,
+void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_object,
                                const AnfNodePtr &assigned_node) {
   MS_EXCEPTION_IF_NULL(block);
   AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
-  py::list items = python_adapter::GetPyObjAttr(target_obj, "elts");
+  py::list items = python_adapter::GetPyObjAttr(target_object, "elts");
   for (size_t i = 0; i < items.size(); i++) {
     // Use the Primitive replace the operation resolve node (getitem),
     // because the getitem will eventually be converted to Primitive node
@@ -1704,13 +1706,13 @@ void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &
   }
 }
 
-void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_obj,
+void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_object,
                                      const AnfNodePtr &assigned_node) {
   // Now only support the self.xx = xxxxx, can't support x.y = xxxx
-  AnfNodePtr target_node = ParseExprNode(block, target_obj);
+  AnfNodePtr target_node = ParseExprNode(block, target_object);
   MS_EXCEPTION_IF_NULL(target_node);
-  auto attr_name = target_obj.attr("attr").cast<std::string>();
+  auto attr_name = target_object.attr("attr").cast<std::string>();
   std::string var_name = "self." + attr_name;
 
   // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
@@ -1733,12 +1735,12 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob
   block->SetStateAssign(target_node, assigned_node);
 }
 
-void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_obj,
+void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_object,
                                    const AnfNodePtr &assigned_node) {
   MS_EXCEPTION_IF_NULL(block);
   AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
-  py::object value_obj = python_adapter::GetPyObjAttr(target_obj, "value");
-  py::object slice_obj = python_adapter::GetPyObjAttr(target_obj, "slice");
+  py::object value_obj = python_adapter::GetPyObjAttr(target_object, "value");
+  py::object slice_obj = python_adapter::GetPyObjAttr(target_object, "slice");
   AnfNodePtr value_node = ParseExprNode(block, value_obj);
   AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
   CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node});
@@ -1776,19 +1778,19 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje
   block->WriteVariable(var_name, setitem_app);
 }
 
-void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_obj,
+void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_object,
                              const AnfNodePtr &value_node) {
   MS_EXCEPTION_IF_NULL(value_node);
   MS_LOG(DEBUG) << "Process WriteAssignVars";
-  auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_obj)));
+  auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object)));
   if (ast_type == AST_SUB_TYPE_NAME) {
-    HandleAssignName(block, target_obj, value_node);
+    HandleAssignName(block, target_object, value_node);
   } else if (ast_type == AST_SUB_TYPE_TUPLE) {
-    HandleAssignTuple(block, target_obj, value_node);
+    HandleAssignTuple(block, target_object, value_node);
   } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
-    HandleAssignSubscript(block, target_obj, value_node);
-  } else if (ast_->IsClassMember(target_obj)) {
-    HandleAssignClassMember(block, target_obj, value_node);
+    HandleAssignSubscript(block, target_object, value_node);
+  } else if (ast_->IsClassMember(target_object)) {
+    HandleAssignClassMember(block, target_object, value_node);
   } else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
     MS_LOG(EXCEPTION) << "The subnet attributes cannot be changed in the network. \n\n"
                       << trace::GetDebugInfo(value_node->debug_info());
@@ -1802,7 +1804,7 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
                                    const py::object &value_object) {
   // The fallback feature is enabled in default.
   // Not support change the flag during the process is alive.
-  static const auto use_fallback = (support_fallback() != "1" ? false : true);
+  static const auto use_fallback = (support_fallback() == "1");
   if (!use_fallback) {
     return value_node;
   }
@@ -1810,9 +1812,10 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
   AnfNodePtr interpreted_node = value_node;
   if (value_node->interpret()) {
     const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
-    MS_LOG(INFO) << "script_text: " << script_text << ", value_node: " << value_node->DebugString(2);
-    // Prepare global parameters.
     py::dict global_dict = block->global_py_params();
+    MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: " << script_text
+                 << ", value_node: " << value_node->DebugString(3) << ", global_dict: " << py::str(global_dict);
+    // Prepare global parameters.
     ValuePtr globals_converted_value = nullptr;
     if (!ConvertData(global_dict, &globals_converted_value)) {
       MS_LOG(EXCEPTION) << "Convert data failed";
diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
index c3df379b053..341fb214afd 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
@@ -60,7 +60,7 @@ abstract::AbstractBasePtr ClassType::ToAbstract() {
   // The fallback feature is enabled in default.
   // Not support change the flag during the process is alive.
   static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
-  static const auto use_fallback = (support_fallback != "1" ? false : true);
+  static const auto use_fallback = (support_fallback == "1");
   if (use_fallback && !IsSupportedCreateInstanceType(obj())) {
     return abs_scalar;
   }
diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.h b/mindspore/ccsrc/pipeline/jit/parse/resolve.h
index 2f627915ea3..7c5bcc03310 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/resolve.h
+++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.h
@@ -112,6 +112,16 @@ class PyObjectWrapper : public Named {
   py::object obj_;
 };
 
+// InterpretedObject class wraps interpreted python object.
+class InterpretedObject : public PyObjectWrapper {
+ public:
+  explicit InterpretedObject(const py::object &obj, const std::string &name = "Interpreted object")
+      : PyObjectWrapper(obj, name) {}
+  ~InterpretedObject() override = default;
+  MS_DECLARE_PARENT(InterpretedObject, PyObjectWrapper);
+  abstract::AbstractBasePtr ToAbstract() override;
+};
+
 // ClassObject class wrappers dataclass
 class ClassObject : public PyObjectWrapper {
  public:
diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h
index 16d4c8bc61b..dedf768b68e 100644
--- a/mindspore/core/ir/anf.h
+++ b/mindspore/core/ir/anf.h
@@ -207,7 +207,7 @@ class MS_CORE_API AnfNode : public Base {
   void set_grad(const bool &need_grad) { need_grad_ = need_grad; }
   bool interpret() { return interpret_; }
-  void set_interpret(const bool interpret) { interpret_ = interpret; }
+  void set_interpret(const bool &interpret) { interpret_ = interpret; }
   AnfNodePtr interpreted_node() { return interpreted_node_; }
   void set_interpreted_node(const AnfNodePtr &node) { interpreted_node_ = node; }
 
diff --git a/mindspore/core/ir/value.h b/mindspore/core/ir/value.h
index 4da4474b008..f02da6cace7 100644
--- a/mindspore/core/ir/value.h
+++ b/mindspore/core/ir/value.h
@@ -176,17 +176,23 @@ class MS_CORE_API ValueDictionary : public Value {
       keys.push_back(kv.first);
       values.push_back(kv.second);
     }
-    buffer << "(Dict: "
-           << " keys:(";
-    for (const auto &key : keys) {
-      buffer << key << ", ";
+    buffer << "dict: {keys: (";
+    for (size_t i = 0; i < keys.size(); i++) {
+      buffer << keys[i];
+      if (i != keys.size() - 1) {
+        buffer << ", ";
+      }
     }
-    buffer << ") values:(";
-    for (const auto &value : values) {
+    buffer << "), values: (";
+    for (size_t i = 0; i < values.size(); i++) {
+      const auto &value = values[i];
       MS_EXCEPTION_IF_NULL(value);
-      buffer << value->ToString() << ", ";
+      buffer << value->ToString();
+      if (i != values.size() - 1) {
+        buffer << ", ";
+      }
     }
-    buffer << ")";
+    buffer << ")}";
     return buffer.str();
   }
   abstract::AbstractBasePtr ToAbstract() override;
diff --git a/tests/ut/python/fallback/test_graph_fallback.py b/tests/ut/python/fallback/test_graph_fallback.py
index 79301933da9..e4ca09ba585 100644
--- a/tests/ut/python/fallback/test_graph_fallback.py
+++ b/tests/ut/python/fallback/test_graph_fallback.py
@@ -18,8 +18,10 @@
 import numpy as np
 
 import mindspore.nn as nn
 from mindspore import Tensor, ms_function, context
+from mindspore.ops import operations as P
 from mindspore.ops import functional as F
-
+import mindspore.common.dtype as mstype
+import mindspore.common._monad as monad
 
 context.set_context(mode=context.GRAPH_MODE)
@@ -38,16 +40,15 @@ def test_increment():
 
 
 @ms_function
-def np_fallback_func():
-    array_x = [2, 3, 4, 5]
-    np_x = np.array(array_x).astype(np.float32)
-    me_x = Tensor(np_x)
-    me_x = me_x + me_x
-    return me_x
+def use_monad(x, y):
+    res = P.Mul()(x, y)
+    res = F.depend(res, monad.U)
+    return res
 
-@pytest.mark.skip(reason='Graph fallback feature is not supported yet')
-def test_np_fallback_func():
-    print(np_fallback_func())
+def test_use_monad():
+    x = Tensor(1.0, mstype.float32)
+    y = Tensor(1.0, mstype.float32)
+    print(use_monad(x, y))
 
 
 class Net(nn.Cell):
@@ -64,3 +65,74 @@ def test_builtins_len():
     net = Net()
     net()
+
+
+@ms_function
+def np_fallback_func():
+    array_x = tuple([2, 3, 4, 5])
+    np_x = np.array(array_x).astype(np.float32)
+    me_x = Tensor(np_x)
+    me_x = me_x + me_x
+    return me_x
+
+@pytest.mark.skip(reason='Graph fallback feature is not supported yet')
+def test_np_fallback_func():
+    print(np_fallback_func())
+
+
+@ms_function
+def div_mod_func(x, y):
+    a = divmod(x, y)
+    return Tensor(a)
+
+@pytest.mark.skip(reason='Graph fallback feature is not supported yet')
+def test_div_mod_func():
+    print(div_mod_func(8, 3))  # (2, 2)
+
+
+# NameError: name 'Tensor' is not defined.
+@ms_function
+def select_func(cond, x, y):
+    if isinstance(cond, (tuple, list)):
+        output = y
+    elif isinstance(cond, Tensor):
+        output = F.select(cond, x, y)
+    else:
+        output = x
+    return output
+
+def test_select_func():
+    cond = Tensor([True, False])
+    x = Tensor([2, 3], mstype.float32)
+    y = Tensor([1, 2], mstype.float32)
+    print(select_func(cond, x, y))
+
+
+# Not interpret 'Tensor'.
+@ms_function
+def select_func2(cond, x, y):
+    if isinstance(cond, (tuple, list)):
+        output = y
+    if isinstance(cond, Tensor):
+        output = F.select(cond, x, y)
+    else:
+        output = x
+    return output
+
+def test_select_func2():
+    cond = Tensor([True, False])
+    x = Tensor([2, 3], mstype.float32)
+    y = Tensor([1, 2], mstype.float32)
+    print(select_func2(cond, x, y))
+
+
+# NameError: name 'Tensor' is not defined.
+@ms_function
+def slice_func(a, b):
+    a[1:3, ::] = b
+    return a
+
+def test_slice_func():
+    a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
+    b = Tensor([1], dtype=mstype.float32)
+    print(slice_func(a, b))