Do not interpret elements in global parameters and local parameters.

This commit is contained in:
huangbingjian 2021-12-01 15:15:30 +08:00
parent 5dd23f3632
commit ea72869729
2 changed files with 33 additions and 2 deletions

View File

@ -1881,6 +1881,29 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
}
}
bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) {
// Check global parameters.
if (global_dict.contains(script_text)) {
MS_LOG(DEBUG) << "[" << func_graph->ToString() << "] Found `" << script_text << "` in global params.";
return true;
}
// Check local parameters.
auto in_local_params = std::any_of(local_keys.begin(), local_keys.end(), [&script_text](const AnfNodePtr &node) {
const auto value_node = dyn_cast<ValueNode>(node);
MS_EXCEPTION_IF_NULL(value_node);
const StringImmPtr &str_imm = dyn_cast<StringImm>(value_node->value());
MS_EXCEPTION_IF_NULL(str_imm);
return script_text == str_imm->value();
});
if (in_local_params) {
MS_LOG(DEBUG) << "[" << func_graph->ToString() << "] Found `" << script_text << "` in local params.";
return true;
}
return false;
}
AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
const py::object &value_object) {
// The fallback feature is enabled in default.
@ -1891,15 +1914,20 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
}
const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
// Prepare global parameters.
// Check if script_text is in global/local params.
py::dict global_dict = block->global_py_params();
auto [keys, values] = block->local_py_params();
if (IsScriptInParams(script_text, global_dict, keys, block->func_graph())) {
return value_node;
}
// Prepare global parameters.
ValuePtr globals_converted_value = nullptr;
if (!ConvertData(global_dict, &globals_converted_value)) {
MS_LOG(EXCEPTION) << "Convert data failed";
}
auto global_dict_node = NewValueNode(globals_converted_value);
// Prepare local parameters.
auto [keys, values] = block->local_py_params();
// Filter the func_graph node where the current node is located.
auto current_fg = value_node->func_graph();
std::vector<AnfNodePtr> filter_keys;

View File

@ -187,6 +187,9 @@ class Parser {
AnfNodePtr ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
const py::object &node, const py::object &generator_node);
// Check if script_text is in global/local params.
bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
// Check if the node need interpreting.
AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
const py::object &value_object);