!31489 [ME][Fallback] Handling the problem of misidentifying Tensor and functools.

Merge pull request !31489 from Margaret_wangrui/fallback
This commit is contained in:
i-robot 2022-03-21 01:09:30 +00:00 committed by Gitee
commit dd079cec90
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 18 additions and 0 deletions

View File

@ -2202,6 +2202,14 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
return MakeInterpretNode(block, value_node, script_text);
}
bool Parser::IsTensorType(const AnfNodePtr &node, const std::string &script_text) const {
if (node->interpret_internal_type() && script_text.find("(") == std::string::npos) {
MS_LOG(DEBUG) << "The Tensor is present as type.";
return true;
}
return false;
}
AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
const string &script_text) {
MS_EXCEPTION_IF_NULL(block);
@ -2209,6 +2217,9 @@ AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNod
// Check if script_text is in global/local params.
py::dict global_dict = block->global_py_params();
auto [keys, values] = block->local_py_params();
if (IsTensorType(value_node, script_text)) {
return value_node;
}
bool is_special_node = value_node->interpret_special_type();
if (IsScriptInParams(script_text, global_dict, keys, block->func_graph()) && !is_special_node) {
return value_node;

View File

@ -201,6 +201,9 @@ class Parser {
// Transform tail call to parallel call.
void TransformParallelCall();
// If Tensor is present as type, not Tensor(xxx), should not make InterpretNode.
bool IsTensorType(const AnfNodePtr &node, const std::string &script_text) const;
// Check if script_text is in global/local params.
bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);

View File

@ -777,6 +777,10 @@ class Parser:
logger.debug(f"Found 'mindspore.context' namespace.")
return True
if name == 'functools':
logger.debug(f"Found 'functools' namespace.")
return True
# Check `builtins` namespace.
if hasattr(value, '__module__'): # Not types.ModuleType
mod = value.__module__