forked from mindspore-Ecosystem/mindspore
!32465 [Fallback][Control_flow] Ensure the information of the Interpret node is transmitted normally when function block jumps.
Merge pull request !32465 from Margaret_wangrui/fallback_control_flow_2
This commit is contained in:
commit
41f1f48ed9
|
@ -128,6 +128,9 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
|
|||
if (node != nullptr) {
|
||||
return node;
|
||||
}
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (parser_.support_fallback() != "0");
|
||||
// Get var from predecessor block, if can't get then make a resolve node to it
|
||||
if (matured_) {
|
||||
// If only one predecessor block, read the definition of var from it.
|
||||
|
@ -135,10 +138,6 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
|
|||
auto block = prev_blocks_[0];
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
auto res = block->ReadVariable(var_name);
|
||||
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (parser_.support_fallback() != "0");
|
||||
if (use_fallback) {
|
||||
MS_LOG(DEBUG) << "Update global params of block: " << ToString()
|
||||
<< ", with previous block: " << block->ToString()
|
||||
|
@ -166,6 +165,11 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
|
|||
ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
|
||||
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
|
||||
<< phi_param->ToString() << " for " << var_name;
|
||||
// If information transform by phi, need remove the var in interpret dict in fallback feature.
|
||||
if (use_fallback) {
|
||||
EraseLocalPyParam(var_name);
|
||||
}
|
||||
|
||||
func_graph()->add_parameter(phi_param);
|
||||
phi_nodes_[phi_param] = var_name;
|
||||
WriteVariable(var_name, phi_param);
|
||||
|
|
|
@ -100,49 +100,52 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
}
|
||||
}
|
||||
|
||||
std::tuple<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> local_py_params() {
|
||||
std::tuple<std::map<std::string, AnfNodePtr>, std::map<std::string, AnfNodePtr>> local_py_params() {
|
||||
return {local_py_params_keys_, local_py_params_values_};
|
||||
}
|
||||
void AddLocalPyParam(const std::string &name, const AnfNodePtr &node) {
|
||||
MS_LOG(DEBUG) << "Add '" << name << "', " << node->DebugString();
|
||||
local_py_params_keys_.emplace_back(NewValueNode(name));
|
||||
local_py_params_values_.emplace_back(node);
|
||||
}
|
||||
// Call this methon only if you need update a variable. Usually variable override.
|
||||
void UpdateLocalPyParam(const std::string &name, const AnfNodePtr &node) {
|
||||
auto iter = std::find_if(local_py_params_keys_.cbegin(), local_py_params_keys_.cend(),
|
||||
[&name](const AnfNodePtr node) -> bool {
|
||||
const auto value_node = dyn_cast<ValueNode>(node);
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
const StringImmPtr &str_imm = dyn_cast<StringImm>(value_node->value());
|
||||
MS_EXCEPTION_IF_NULL(str_imm);
|
||||
return name == str_imm->value();
|
||||
});
|
||||
if (iter == local_py_params_keys_.cend()) {
|
||||
MS_LOG(EXCEPTION) << "Only for updating. Should not call this method if 'name' not exist.";
|
||||
}
|
||||
// Find the same position in 'values', and update the node.
|
||||
auto distance = std::distance(local_py_params_keys_.cbegin(), iter);
|
||||
auto values_pos_iter = local_py_params_values_.begin() + distance;
|
||||
MS_LOG(DEBUG) << "Update '" << name << "', " << (*values_pos_iter)->DebugString() << " -> " << node->DebugString();
|
||||
*values_pos_iter = node;
|
||||
(void)local_py_params_keys_.insert(std::pair<std::string, AnfNodePtr>(name, NewValueNode(name)));
|
||||
(void)local_py_params_values_.insert(std::pair<std::string, AnfNodePtr>(name, node));
|
||||
}
|
||||
|
||||
void UpdateLocalPyParam(const std::vector<AnfNodePtr> &keys, const std::vector<AnfNodePtr> &values) {
|
||||
// Call this methon only if you need update a variable. Usually variable override.
|
||||
void UpdateLocalPyParam(const std::string &name, const AnfNodePtr &node) {
|
||||
auto key_iter = local_py_params_keys_.find(name);
|
||||
if (key_iter == local_py_params_keys_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Only for updating. Should not call this method if '" << name << "' not exist.";
|
||||
}
|
||||
// Find the same position in 'values', and update the node.
|
||||
MS_LOG(DEBUG) << "Update '" << name << "', " << local_py_params_values_[name]->DebugString() << " -> "
|
||||
<< node->DebugString();
|
||||
local_py_params_values_[name] = node;
|
||||
}
|
||||
|
||||
void EraseLocalPyParam(const std::string &name) {
|
||||
auto key_iter = local_py_params_keys_.find(name);
|
||||
auto value_iter = local_py_params_values_.find(name);
|
||||
if (key_iter != local_py_params_keys_.end() && value_iter != local_py_params_values_.end()) {
|
||||
MS_LOG(DEBUG) << "Erase '" << name
|
||||
<< "' from local_py_params, the key node:" << local_py_params_keys_[name]->DebugString()
|
||||
<< ", the value node:" << local_py_params_values_[name]->DebugString();
|
||||
local_py_params_keys_.erase(key_iter);
|
||||
local_py_params_values_.erase(value_iter);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateLocalPyParam(const std::map<std::string, AnfNodePtr> &keys, std::map<std::string, AnfNodePtr> values) {
|
||||
if (keys.size() != values.size()) {
|
||||
MS_LOG(EXCEPTION) << "keys size should be equal to values size.";
|
||||
}
|
||||
for (size_t index = 0; index < keys.size(); ++index) {
|
||||
auto iter = std::find(local_py_params_keys_.cbegin(), local_py_params_keys_.cend(), keys[index]);
|
||||
if (iter == local_py_params_keys_.cend()) {
|
||||
local_py_params_keys_.emplace_back(keys[index]);
|
||||
local_py_params_values_.emplace_back(values[index]);
|
||||
MS_LOG(DEBUG) << "Add '" << keys[index]->DebugString() << "', " << values[index]->DebugString();
|
||||
for (auto iter = keys.begin(); iter != keys.end(); ++iter) {
|
||||
const std::string &cur_key_name = iter->first;
|
||||
if (local_py_params_keys_.find(cur_key_name) == local_py_params_keys_.end()) {
|
||||
(void)local_py_params_keys_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, iter->second));
|
||||
(void)local_py_params_values_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, values[cur_key_name]));
|
||||
MS_LOG(DEBUG) << "Add '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
|
||||
} else {
|
||||
auto distance = std::distance(local_py_params_keys_.cbegin(), iter);
|
||||
auto values_pos_iter = local_py_params_values_.begin() + distance;
|
||||
MS_LOG(DEBUG) << "Update '" << keys[index]->DebugString() << "', " << values[index]->DebugString();
|
||||
*values_pos_iter = values[index];
|
||||
MS_LOG(DEBUG) << "Update '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
|
||||
local_py_params_values_[cur_key_name] = values[cur_key_name];
|
||||
}
|
||||
}
|
||||
if (local_py_params_keys_.size() != local_py_params_values_.size()) {
|
||||
|
@ -186,8 +189,8 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
// Collect all python symbols in the block.
|
||||
// We treat both global symbols and local symbols declared previously as global symbols.
|
||||
py::dict global_py_params_;
|
||||
std::vector<AnfNodePtr> local_py_params_keys_;
|
||||
std::vector<AnfNodePtr> local_py_params_values_;
|
||||
std::map<std::string, AnfNodePtr> local_py_params_keys_;
|
||||
std::map<std::string, AnfNodePtr> local_py_params_values_;
|
||||
|
||||
// Isolated nodes.
|
||||
OrderedSet<AnfNodePtr> isolated_nodes_;
|
||||
|
|
|
@ -1613,7 +1613,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
MS_EXCEPTION_IF_NULL(after_block->func_graph());
|
||||
after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
|
||||
}
|
||||
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
// Process the if-true branch
|
||||
std::pair<FunctionBlockPtr, FunctionBlockPtr> true_branch_graphs;
|
||||
py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
|
||||
|
@ -1630,6 +1630,10 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
}
|
||||
MS_LOG(DEBUG) << "The true_end block jump to after, true_block: " << true_block->ToString()
|
||||
<< ", true_end: " << true_end->ToString();
|
||||
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(after_block, true_end);
|
||||
}
|
||||
}
|
||||
|
||||
// Process the orelse branch
|
||||
|
@ -1648,6 +1652,9 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
}
|
||||
MS_LOG(DEBUG) << "The false_end block jump to after, false_block: " << false_block->ToString()
|
||||
<< ", false_end: " << false_end->ToString();
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(after_block, false_end);
|
||||
}
|
||||
}
|
||||
auto switch_app = block->ConditionalJump(bool_node, true_block, false_block);
|
||||
|
||||
|
@ -2364,7 +2371,7 @@ void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::
|
|||
}
|
||||
|
||||
bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
|
||||
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) {
|
||||
const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
// Check global parameters.
|
||||
if (global_dict.contains(script_text)) {
|
||||
|
@ -2373,14 +2380,7 @@ bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &gl
|
|||
}
|
||||
|
||||
// Check local parameters.
|
||||
auto in_local_params = std::any_of(local_keys.begin(), local_keys.end(), [&script_text](const AnfNodePtr &node) {
|
||||
const auto value_node = dyn_cast<ValueNode>(node);
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
const StringImmPtr &str_imm = dyn_cast<StringImm>(value_node->value());
|
||||
MS_EXCEPTION_IF_NULL(str_imm);
|
||||
return script_text == str_imm->value();
|
||||
});
|
||||
if (in_local_params) {
|
||||
if (local_keys.find(script_text) != local_keys.end()) {
|
||||
MS_LOG(DEBUG) << "[" << func_graph->ToString() << "] Found `" << script_text << "` in local params.";
|
||||
return true;
|
||||
}
|
||||
|
@ -2414,7 +2414,7 @@ AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNod
|
|||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
// Check if script_text is in global/local params.
|
||||
py::dict global_dict = block->global_py_params();
|
||||
const auto &[keys, values] = block->local_py_params();
|
||||
auto [keys, values] = block->local_py_params();
|
||||
if (IsTensorType(value_node, script_text)) {
|
||||
return value_node;
|
||||
}
|
||||
|
@ -2434,13 +2434,14 @@ AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNod
|
|||
auto current_fg = value_node->func_graph();
|
||||
std::vector<AnfNodePtr> filter_keys;
|
||||
std::vector<AnfNodePtr> filter_values;
|
||||
for (size_t index = 0; index < values.size(); ++index) {
|
||||
auto value = values[index];
|
||||
for (auto iter = values.begin(); iter != values.end(); ++iter) {
|
||||
auto value = iter->second;
|
||||
auto fg = GetValueNode<FuncGraphPtr>(value);
|
||||
if (fg == current_fg) {
|
||||
continue;
|
||||
}
|
||||
(void)filter_keys.emplace_back(keys[index]);
|
||||
const std::string &name = iter->first;
|
||||
(void)filter_keys.emplace_back(keys[name]);
|
||||
(void)filter_values.emplace_back(value);
|
||||
}
|
||||
auto local_dict_node = ParseDictByKeysAndValues(block, filter_keys, filter_values);
|
||||
|
|
|
@ -219,7 +219,7 @@ class Parser {
|
|||
|
||||
// Check if script_text is in global/local params.
|
||||
bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
|
||||
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
|
||||
const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
|
||||
// Set the interpret flag for the node calling the interpret node.
|
||||
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node);
|
||||
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::vector<AnfNodePtr> &nodes);
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback control flow."""
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore.nn import Cell
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_single_if_no_else_type():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class FalseNet(Cell):
|
||||
def __init__(self):
|
||||
super(FalseNet, self).__init__()
|
||||
self.cond = False
|
||||
|
||||
def construct(self):
|
||||
x = np.array(1)
|
||||
if self.cond:
|
||||
return type(2).mro()
|
||||
return type(x).mro()
|
||||
|
||||
test_net = FalseNet()
|
||||
res = test_net()
|
||||
assert str(res) == "(<class 'numpy.ndarray'>, <class 'object'>)"
|
||||
|
||||
|
||||
def test_single_if_no_else_type_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class TrueNet(Cell):
|
||||
def __init__(self):
|
||||
super(TrueNet, self).__init__()
|
||||
self.cond = True
|
||||
|
||||
def construct(self):
|
||||
x = np.array(2)
|
||||
y = 2
|
||||
if self.cond:
|
||||
return type(y).mro()
|
||||
return type(x).mro()
|
||||
|
||||
test_net = TrueNet()
|
||||
res = test_net()
|
||||
assert str(res) == "(<class 'int'>, <class 'object'>)"
|
Loading…
Reference in New Issue