!32465 [Fallback][Control_flow] Ensure the information of the Interpret node is transmitted normally when function block jumps.

Merge pull request !32465 from Margaret_wangrui/fallback_control_flow_2
This commit is contained in:
i-robot 2022-04-06 07:06:45 +00:00 committed by Gitee
commit 41f1f48ed9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 127 additions and 54 deletions

View File

@ -128,6 +128,9 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
if (node != nullptr) {
return node;
}
// The fallback feature is enabled in default.
// Not support change the flag during the process is alive.
static const auto use_fallback = (parser_.support_fallback() != "0");
// Get var from predecessor block, if can't get then make a resolve node to it
if (matured_) {
// If only one predecessor block, read the definition of var from it.
@ -135,10 +138,6 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
auto block = prev_blocks_[0];
MS_EXCEPTION_IF_NULL(block);
auto res = block->ReadVariable(var_name);
// The fallback feature is enabled in default.
// Not support change the flag during the process is alive.
static const auto use_fallback = (parser_.support_fallback() != "0");
if (use_fallback) {
MS_LOG(DEBUG) << "Update global params of block: " << ToString()
<< ", with previous block: " << block->ToString()
@ -166,6 +165,11 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
<< phi_param->ToString() << " for " << var_name;
// If information transform by phi, need remove the var in interpret dict in fallback feature.
if (use_fallback) {
EraseLocalPyParam(var_name);
}
func_graph()->add_parameter(phi_param);
phi_nodes_[phi_param] = var_name;
WriteVariable(var_name, phi_param);

View File

@ -100,49 +100,52 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
}
}
std::tuple<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> local_py_params() {
std::tuple<std::map<std::string, AnfNodePtr>, std::map<std::string, AnfNodePtr>> local_py_params() {
return {local_py_params_keys_, local_py_params_values_};
}
void AddLocalPyParam(const std::string &name, const AnfNodePtr &node) {
MS_LOG(DEBUG) << "Add '" << name << "', " << node->DebugString();
local_py_params_keys_.emplace_back(NewValueNode(name));
local_py_params_values_.emplace_back(node);
}
// Call this methon only if you need update a variable. Usually variable override.
void UpdateLocalPyParam(const std::string &name, const AnfNodePtr &node) {
auto iter = std::find_if(local_py_params_keys_.cbegin(), local_py_params_keys_.cend(),
[&name](const AnfNodePtr node) -> bool {
const auto value_node = dyn_cast<ValueNode>(node);
MS_EXCEPTION_IF_NULL(value_node);
const StringImmPtr &str_imm = dyn_cast<StringImm>(value_node->value());
MS_EXCEPTION_IF_NULL(str_imm);
return name == str_imm->value();
});
if (iter == local_py_params_keys_.cend()) {
MS_LOG(EXCEPTION) << "Only for updating. Should not call this method if 'name' not exist.";
}
// Find the same position in 'values', and update the node.
auto distance = std::distance(local_py_params_keys_.cbegin(), iter);
auto values_pos_iter = local_py_params_values_.begin() + distance;
MS_LOG(DEBUG) << "Update '" << name << "', " << (*values_pos_iter)->DebugString() << " -> " << node->DebugString();
*values_pos_iter = node;
(void)local_py_params_keys_.insert(std::pair<std::string, AnfNodePtr>(name, NewValueNode(name)));
(void)local_py_params_values_.insert(std::pair<std::string, AnfNodePtr>(name, node));
}
void UpdateLocalPyParam(const std::vector<AnfNodePtr> &keys, const std::vector<AnfNodePtr> &values) {
// Call this methon only if you need update a variable. Usually variable override.
void UpdateLocalPyParam(const std::string &name, const AnfNodePtr &node) {
auto key_iter = local_py_params_keys_.find(name);
if (key_iter == local_py_params_keys_.end()) {
MS_LOG(EXCEPTION) << "Only for updating. Should not call this method if '" << name << "' not exist.";
}
// Find the same position in 'values', and update the node.
MS_LOG(DEBUG) << "Update '" << name << "', " << local_py_params_values_[name]->DebugString() << " -> "
<< node->DebugString();
local_py_params_values_[name] = node;
}
void EraseLocalPyParam(const std::string &name) {
auto key_iter = local_py_params_keys_.find(name);
auto value_iter = local_py_params_values_.find(name);
if (key_iter != local_py_params_keys_.end() && value_iter != local_py_params_values_.end()) {
MS_LOG(DEBUG) << "Erase '" << name
<< "' from local_py_params, the key node:" << local_py_params_keys_[name]->DebugString()
<< ", the value node:" << local_py_params_values_[name]->DebugString();
local_py_params_keys_.erase(key_iter);
local_py_params_values_.erase(value_iter);
}
}
void UpdateLocalPyParam(const std::map<std::string, AnfNodePtr> &keys, std::map<std::string, AnfNodePtr> values) {
if (keys.size() != values.size()) {
MS_LOG(EXCEPTION) << "keys size should be equal to values size.";
}
for (size_t index = 0; index < keys.size(); ++index) {
auto iter = std::find(local_py_params_keys_.cbegin(), local_py_params_keys_.cend(), keys[index]);
if (iter == local_py_params_keys_.cend()) {
local_py_params_keys_.emplace_back(keys[index]);
local_py_params_values_.emplace_back(values[index]);
MS_LOG(DEBUG) << "Add '" << keys[index]->DebugString() << "', " << values[index]->DebugString();
for (auto iter = keys.begin(); iter != keys.end(); ++iter) {
const std::string &cur_key_name = iter->first;
if (local_py_params_keys_.find(cur_key_name) == local_py_params_keys_.end()) {
(void)local_py_params_keys_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, iter->second));
(void)local_py_params_values_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, values[cur_key_name]));
MS_LOG(DEBUG) << "Add '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
} else {
auto distance = std::distance(local_py_params_keys_.cbegin(), iter);
auto values_pos_iter = local_py_params_values_.begin() + distance;
MS_LOG(DEBUG) << "Update '" << keys[index]->DebugString() << "', " << values[index]->DebugString();
*values_pos_iter = values[index];
MS_LOG(DEBUG) << "Update '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
local_py_params_values_[cur_key_name] = values[cur_key_name];
}
}
if (local_py_params_keys_.size() != local_py_params_values_.size()) {
@ -186,8 +189,8 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
// Collect all python symbols in the block.
// We treat both global symbols and local symbols declared previously as global symbols.
py::dict global_py_params_;
std::vector<AnfNodePtr> local_py_params_keys_;
std::vector<AnfNodePtr> local_py_params_values_;
std::map<std::string, AnfNodePtr> local_py_params_keys_;
std::map<std::string, AnfNodePtr> local_py_params_values_;
// Isolated nodes.
OrderedSet<AnfNodePtr> isolated_nodes_;

View File

@ -1613,7 +1613,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
MS_EXCEPTION_IF_NULL(after_block->func_graph());
after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
}
static const auto use_fallback = (support_fallback() != "0");
// Process the if-true branch
std::pair<FunctionBlockPtr, FunctionBlockPtr> true_branch_graphs;
py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
@ -1630,6 +1630,10 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
}
MS_LOG(DEBUG) << "The true_end block jump to after, true_block: " << true_block->ToString()
<< ", true_end: " << true_end->ToString();
if (use_fallback) {
UpdateBlockPyParams(after_block, true_end);
}
}
// Process the orelse branch
@ -1648,6 +1652,9 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
}
MS_LOG(DEBUG) << "The false_end block jump to after, false_block: " << false_block->ToString()
<< ", false_end: " << false_end->ToString();
if (use_fallback) {
UpdateBlockPyParams(after_block, false_end);
}
}
auto switch_app = block->ConditionalJump(bool_node, true_block, false_block);
@ -2364,7 +2371,7 @@ void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::
}
bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) {
const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
// Check global parameters.
if (global_dict.contains(script_text)) {
@ -2373,14 +2380,7 @@ bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &gl
}
// Check local parameters.
auto in_local_params = std::any_of(local_keys.begin(), local_keys.end(), [&script_text](const AnfNodePtr &node) {
const auto value_node = dyn_cast<ValueNode>(node);
MS_EXCEPTION_IF_NULL(value_node);
const StringImmPtr &str_imm = dyn_cast<StringImm>(value_node->value());
MS_EXCEPTION_IF_NULL(str_imm);
return script_text == str_imm->value();
});
if (in_local_params) {
if (local_keys.find(script_text) != local_keys.end()) {
MS_LOG(DEBUG) << "[" << func_graph->ToString() << "] Found `" << script_text << "` in local params.";
return true;
}
@ -2414,7 +2414,7 @@ AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNod
MS_EXCEPTION_IF_NULL(value_node);
// Check if script_text is in global/local params.
py::dict global_dict = block->global_py_params();
const auto &[keys, values] = block->local_py_params();
auto [keys, values] = block->local_py_params();
if (IsTensorType(value_node, script_text)) {
return value_node;
}
@ -2434,13 +2434,14 @@ AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNod
auto current_fg = value_node->func_graph();
std::vector<AnfNodePtr> filter_keys;
std::vector<AnfNodePtr> filter_values;
for (size_t index = 0; index < values.size(); ++index) {
auto value = values[index];
for (auto iter = values.begin(); iter != values.end(); ++iter) {
auto value = iter->second;
auto fg = GetValueNode<FuncGraphPtr>(value);
if (fg == current_fg) {
continue;
}
(void)filter_keys.emplace_back(keys[index]);
const std::string &name = iter->first;
(void)filter_keys.emplace_back(keys[name]);
(void)filter_values.emplace_back(value);
}
auto local_dict_node = ParseDictByKeysAndValues(block, filter_keys, filter_values);

View File

@ -219,7 +219,7 @@ class Parser {
// Check if script_text is in global/local params.
bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
// Set the interpret flag for the node calling the interpret node.
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node);
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::vector<AnfNodePtr> &nodes);

View File

@ -0,0 +1,65 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback control flow."""
import numpy as np
from mindspore import context
from mindspore.nn import Cell
context.set_context(mode=context.GRAPH_MODE)
def test_single_if_no_else_type():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
class FalseNet(Cell):
def __init__(self):
super(FalseNet, self).__init__()
self.cond = False
def construct(self):
x = np.array(1)
if self.cond:
return type(2).mro()
return type(x).mro()
test_net = FalseNet()
res = test_net()
assert str(res) == "(<class 'numpy.ndarray'>, <class 'object'>)"
def test_single_if_no_else_type_2():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
class TrueNet(Cell):
def __init__(self):
super(TrueNet, self).__init__()
self.cond = True
def construct(self):
x = np.array(2)
y = 2
if self.cond:
return type(y).mro()
return type(x).mro()
test_net = TrueNet()
res = test_net()
assert str(res) == "(<class 'int'>, <class 'object'>)"