Do not interpret elements in global parameters and local parameters.
This commit is contained in:
parent
5dd23f3632
commit
ea72869729
|
@ -1881,6 +1881,29 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
|
|||
}
|
||||
}
|
||||
|
||||
bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
|
||||
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) {
|
||||
// Check global parameters.
|
||||
if (global_dict.contains(script_text)) {
|
||||
MS_LOG(DEBUG) << "[" << func_graph->ToString() << "] Found `" << script_text << "` in global params.";
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check local parameters.
|
||||
auto in_local_params = std::any_of(local_keys.begin(), local_keys.end(), [&script_text](const AnfNodePtr &node) {
|
||||
const auto value_node = dyn_cast<ValueNode>(node);
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
const StringImmPtr &str_imm = dyn_cast<StringImm>(value_node->value());
|
||||
MS_EXCEPTION_IF_NULL(str_imm);
|
||||
return script_text == str_imm->value();
|
||||
});
|
||||
if (in_local_params) {
|
||||
MS_LOG(DEBUG) << "[" << func_graph->ToString() << "] Found `" << script_text << "` in local params.";
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
|
||||
const py::object &value_object) {
|
||||
// The fallback feature is enabled in default.
|
||||
|
@ -1891,15 +1914,20 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
|
|||
}
|
||||
|
||||
const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
|
||||
// Prepare global parameters.
|
||||
// Check if script_text is in global/local params.
|
||||
py::dict global_dict = block->global_py_params();
|
||||
auto [keys, values] = block->local_py_params();
|
||||
if (IsScriptInParams(script_text, global_dict, keys, block->func_graph())) {
|
||||
return value_node;
|
||||
}
|
||||
|
||||
// Prepare global parameters.
|
||||
ValuePtr globals_converted_value = nullptr;
|
||||
if (!ConvertData(global_dict, &globals_converted_value)) {
|
||||
MS_LOG(EXCEPTION) << "Convert data failed";
|
||||
}
|
||||
auto global_dict_node = NewValueNode(globals_converted_value);
|
||||
// Prepare local parameters.
|
||||
auto [keys, values] = block->local_py_params();
|
||||
// Filter the func_graph node where the current node is located.
|
||||
auto current_fg = value_node->func_graph();
|
||||
std::vector<AnfNodePtr> filter_keys;
|
||||
|
|
|
@ -187,6 +187,9 @@ class Parser {
|
|||
AnfNodePtr ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
|
||||
const py::object &node, const py::object &generator_node);
|
||||
|
||||
// Check if script_text is in global/local params.
|
||||
bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
|
||||
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
|
||||
// Check if the node need interpreting.
|
||||
AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
|
||||
const py::object &value_object);
|
||||
|
|
Loading…
Reference in New Issue