forked from mindspore-Ecosystem/mindspore
!31489 [ME][Fallback] Handling the problem of misidentifying Tensor and functools.
Merge pull request !31489 from Margaret_wangrui/fallback
This commit is contained in:
commit
dd079cec90
|
@ -2202,6 +2202,14 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
|
|||
return MakeInterpretNode(block, value_node, script_text);
|
||||
}
|
||||
|
||||
bool Parser::IsTensorType(const AnfNodePtr &node, const std::string &script_text) const {
|
||||
if (node->interpret_internal_type() && script_text.find("(") == std::string::npos) {
|
||||
MS_LOG(DEBUG) << "The Tensor is present as type.";
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
|
||||
const string &script_text) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
|
@ -2209,6 +2217,9 @@ AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNod
|
|||
// Check if script_text is in global/local params.
|
||||
py::dict global_dict = block->global_py_params();
|
||||
auto [keys, values] = block->local_py_params();
|
||||
if (IsTensorType(value_node, script_text)) {
|
||||
return value_node;
|
||||
}
|
||||
bool is_special_node = value_node->interpret_special_type();
|
||||
if (IsScriptInParams(script_text, global_dict, keys, block->func_graph()) && !is_special_node) {
|
||||
return value_node;
|
||||
|
|
|
@ -201,6 +201,9 @@ class Parser {
|
|||
// Transform tail call to parallel call.
|
||||
void TransformParallelCall();
|
||||
|
||||
// If Tensor is present as type, not Tensor(xxx), should not make InterpretNode.
|
||||
bool IsTensorType(const AnfNodePtr &node, const std::string &script_text) const;
|
||||
|
||||
// Check if script_text is in global/local params.
|
||||
bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
|
||||
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
|
||||
|
|
|
@ -777,6 +777,10 @@ class Parser:
|
|||
logger.debug(f"Found 'mindspore.context' namespace.")
|
||||
return True
|
||||
|
||||
if name == 'functools':
|
||||
logger.debug(f"Found 'functools' namespace.")
|
||||
return True
|
||||
|
||||
# Check `builtins` namespace.
|
||||
if hasattr(value, '__module__'): # Not types.ModuleType
|
||||
mod = value.__module__
|
||||
|
|
Loading…
Reference in New Issue