!35454 Change the logic of inheriting local Python parameters between blocks.

Merge pull request !35454 from LiangZhibo/master
i-robot 2022-06-09 01:34:52 +00:00 committed by Gitee
commit ef555956f2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 118 additions and 125 deletions

View File

@@ -21,6 +21,8 @@
#include <string>
#include <memory>
#include <algorithm>
#include <unordered_set>
#include <queue>
#include "pybind11/pybind11.h"
#include "pipeline/jit/parse/resolve.h"
@@ -125,6 +127,34 @@ AnfNodePtr FunctionBlock::ReadLocalVariable(const std::string &var_name) {
return nullptr;
}
AnfNodePtr FunctionBlock::FindPredInterpretNode(const std::string &var_name) {
// Search the predecessors of the current block for the local parameter. If one of the predecessors' local
// parameters is an interpret node, the phi_param needs to have its interpret flag set to true.
std::unordered_set<FunctionBlock *> visited_block;
std::queue<FunctionBlock *> block_queue;
block_queue.push(this);
while (!block_queue.empty()) {
const auto &cur_block = block_queue.front();
block_queue.pop();
visited_block.insert(cur_block);
auto pred_node = cur_block->ReadLocalVariable(var_name);
if (pred_node != nullptr) {
bool interpret_without_internal =
IsPrimitiveCNode(pred_node, prim::kPrimPyInterpret) && !pred_node->interpret_internal_type();
if (pred_node->interpret() || interpret_without_internal) {
return pred_node;
}
} else {
for (const auto &cur_pred_block : cur_block->prev_blocks()) {
if (visited_block.count(cur_pred_block) == 0) {
block_queue.push(cur_pred_block);
}
}
}
}
return nullptr;
}
// Read variable from predecessors
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
MS_LOG(DEBUG) << "Read begin, var: " << var_name << ", block: " << ToString();
@@ -175,24 +205,14 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
<< phi_param->ToString() << " for " << var_name;
if (use_fallback) {
// Check whether the phi needs the interpret flag.
// If an Interpret node named var_name exists in prev_blocks_, the phi's interpret flag needs to be set to true.
for (auto &pred : prev_blocks_) {
MS_EXCEPTION_IF_NULL(pred);
auto iter = pred->local_py_params_values_.find(var_name);
if (iter != pred->local_py_params_values_.end()) {
auto pred_node = iter->second;
if (pred_node->interpret_special_type()) {
phi_param->set_interpret_special_type(true);
}
bool interpret_without_internal =
IsPrimitiveCNode(pred_node, prim::kPrimPyInterpret) && !pred_node->interpret_internal_type();
if (pred_node->interpret() || interpret_without_internal) {
phi_param->set_interpret(true);
break;
}
const auto &pred_node = FindPredInterpretNode(var_name);
if (pred_node != nullptr) {
if (pred_node->interpret_special_type()) {
phi_param->set_interpret_special_type(true);
}
phi_param->set_interpret(true);
}
}
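For reference, a small sketch (reusing the illustrative helper above) of how ReadVariable now consumes the search result when it generates a phi parameter, instead of looping over prev_blocks_ and local_py_params_values_ itself; PhiParam and its attributes are again hypothetical names, not the actual MindSpore classes.

def mark_phi_interpret(phi_param, block, var_name):
    # If any reachable predecessor defines var_name as an interpret node,
    # the generated phi parameter inherits the interpret flags.
    pred_node = find_pred_interpret_node(block, var_name)
    if pred_node is not None:
        if pred_node.interpret_special_type:
            phi_param.interpret_special_type = True
        phi_param.interpret = True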

View File

@@ -203,6 +203,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
bool is_dead_block_{false};
AnfNodePtr ReadLocalVariable(const std::string &var_name);
AnfNodePtr FindPredInterpretNode(const std::string &var_name);
};
} // namespace parse
} // namespace mindspore

View File

@@ -822,12 +822,6 @@ LocationPtr Parser::GetLocation(const py::object &node) const {
return location;
}
void Parser::UpdateBlockPyParams(const FunctionBlockPtr &block, const FunctionBlockPtr &pre_block) {
block->UpdateGlobalPyParam(pre_block->global_py_params());
const auto &[keys, values] = pre_block->local_py_params();
block->UpdateLocalPyParam(keys, values);
}
void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
const FunctionBlockPtr &false_block) {
MS_EXCEPTION_IF_NULL(true_block);
@@ -840,8 +834,8 @@ void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const Functi
static const auto use_fallback = (support_fallback() != "0");
if (use_fallback) {
UpdateBlockPyParams(true_block, pre_block);
UpdateBlockPyParams(false_block, pre_block);
true_block->UpdateGlobalPyParam(pre_block->global_py_params());
false_block->UpdateGlobalPyParam(pre_block->global_py_params());
}
}
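The parser-side changes follow from this: the removed UpdateBlockPyParams helper copied both the global and the local Python parameters from the predecessor block into its successors, whereas the call sites now copy only the globals and leave the local parameters to be found on demand through the predecessor search in FunctionBlock. A hedged before/after sketch with hypothetical method names:

def update_block_py_params_old(block, pre_block):
    # Removed helper: successors eagerly inherited globals and locals.
    block.update_global_py_param(pre_block.global_py_params)
    keys, values = pre_block.local_py_params
    block.update_local_py_param(keys, values)

def update_block_py_params_new(block, pre_block):
    # Current call sites: only the global Python parameters are propagated.
    block.update_global_py_param(pre_block.global_py_params)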
@@ -1772,7 +1766,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
<< true_block->ToString() << ", true_end: " << true_end->ToString();
}
if (use_fallback) {
UpdateBlockPyParams(after_block, true_end);
after_block->UpdateGlobalPyParam(true_end->global_py_params());
}
}
@@ -1794,7 +1788,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
<< false_block->ToString() << ", false_end: " << false_end->ToString();
}
if (use_fallback) {
UpdateBlockPyParams(after_block, false_end);
after_block->UpdateGlobalPyParam(false_end->global_py_params());
}
}
auto switch_app = block->ConditionalJump(bool_node, true_block, false_block);
@@ -1862,9 +1856,9 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
py::object test_node = python_adapter::GetPyObjAttr(node, "test");
static const auto use_fallback = (support_fallback() != "0");
if (use_fallback) {
UpdateBlockPyParams(header_block, block);
UpdateBlockPyParams(body_block, block);
UpdateBlockPyParams(after_block, block);
header_block->UpdateGlobalPyParam(block->global_py_params());
body_block->UpdateGlobalPyParam(block->global_py_params());
after_block->UpdateGlobalPyParam(block->global_py_params());
}
AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
condition_node = header_block->ForceToWhileCond(condition_node);
@@ -1955,8 +1949,8 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
CNodePtr target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var});
static const auto use_fallback = (support_fallback() != "0");
if (use_fallback) {
UpdateBlockPyParams(header_block, block);
UpdateBlockPyParams(body_block, block);
header_block->UpdateGlobalPyParam(block->global_py_params());
body_block->UpdateGlobalPyParam(block->global_py_params());
}
WriteAssignVars(body_block, target_node, target_var);
@@ -1981,7 +1975,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
block->Jump(header_block, {NewValueNode(static_cast<int64_t>(0))});
body_block->Mature();
if (use_fallback) {
UpdateBlockPyParams(after_block, block);
after_block->UpdateGlobalPyParam(block->global_py_params());
}
header_block->ConditionalJump(cond_node, body_block, after_block);
@@ -1991,7 +1985,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
if (use_fallback) {
UpdateBlockPyParams(after_body_block, block);
after_body_block->UpdateGlobalPyParam(block->global_py_params());
}
if (after_body_block->func_graph()->get_return() == nullptr) {
after_body_block->Jump(header_block, {loop_var_inc});
@@ -2044,8 +2038,8 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
static const auto use_fallback = (support_fallback() != "0");
if (use_fallback) {
UpdateBlockPyParams(header_block, block);
UpdateBlockPyParams(body_block, block);
header_block->UpdateGlobalPyParam(block->global_py_params());
body_block->UpdateGlobalPyParam(block->global_py_params());
}
// Get variable name of 'x' in statement 'for x in xs'
@@ -2073,7 +2067,7 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
block->Jump(header_block, {iter_node, NewValueNode(static_cast<int64_t>(0))});
body_block->Mature();
if (use_fallback) {
UpdateBlockPyParams(after_block, block);
after_block->UpdateGlobalPyParam(block->global_py_params());
}
header_block->ConditionalJump(cond_node, body_block, after_block);
@@ -2087,7 +2081,7 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
body_block->Jump(rolled_body_block, {});
auto rolled_body_call = dyn_cast<CNode>(body_block->func_graph()->output());
if (use_fallback) {
UpdateBlockPyParams(rolled_body_block, block);
rolled_body_block->UpdateGlobalPyParam(block->global_py_params());
}
// Parse loop body statements with loop context.
@@ -2095,7 +2089,7 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
FunctionBlockPtr after_body_block = ParseStatements(rolled_body_block, body_node);
if (use_fallback) {
UpdateBlockPyParams(after_body_block, block);
after_body_block->UpdateGlobalPyParam(block->global_py_params());
}
MS_LOG(DEBUG) << "Finish rolled block, after_body_block: " << after_body_block->ToString()
<< ", rolled_body_block: " << rolled_body_block->ToString();

View File

@@ -313,8 +313,6 @@ class Parser {
// Return a make tuple for input elements list
AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
void UpdateBlockPyParams(const FunctionBlockPtr &block, const FunctionBlockPtr &pre_block);
// The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
static FuncGraphWeakPtr top_func_graph_;
// Python function id, used to indicate whether two CNodes come from the same Python function

View File

@@ -102,7 +102,11 @@ def test_while_after_if_numpy():
assert (res.asnumpy() == [-3, -4]).all()
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_after_if_numpy_2():
"""
Feature: JIT Fallback

View File

@@ -53,31 +53,11 @@ def test_while_after_if_in_while_tensor():
assert res == 33
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_while_after_if_in_while_numpy():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_while_after_if_in_while():
x = np.array([1])
y = np.array([10])
while Tensor(x) < Tensor(y):
z = Tensor([-2])
if Tensor(x) < z:
y = 0
else:
x = y + x
while Tensor(y[0]) < Tensor(x[0]):
y += x
return Tensor(y)
res = control_flow_while_after_if_in_while()
assert res == 21
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_after_if_in_while_numpy_2():
"""
Feature: JIT Fallback
@@ -97,7 +77,7 @@ def test_while_after_if_in_while_numpy_2():
else:
y = x + y
while y[0] > x[0]:
y[0] -= x[0]
y -= x[0]
return Tensor(x), Tensor(y)
res_x, res_y = control_flow_while_after_if_in_while()
assert res_x == 1

View File

@@ -83,30 +83,6 @@ def test_while_after_if_in_for_tensor_2():
assert res_y == 5
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_while_after_if_in_for_numpy():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_while_after_if_in_for():
x = np.array([1])
y = np.array([10])
for _ in range(3):
z = Tensor([-2])
if Tensor(x) < z:
y = 0
else:
x = y - x
while Tensor(y[0]) > Tensor(x[0]):
y -= x
return Tensor(y)
res = control_flow_while_after_if_in_for()
assert res == 1
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_while_after_if_in_for_numpy_2():
"""

View File

@@ -0,0 +1,46 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_while_after_while_in_while_numpy():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_while_after_while_in_while():
x = Tensor([-1])
y = Tensor([-2])
while abs(x) <= abs(y):
z = np.array([3, 4, 5])
index = 0
z_sum = 0
while index < 3:
z_sum += z[index]
index += 1
x = x + Tensor(z_sum)
while y < x:
y += x
return x, y
res = control_flow_while_after_while_in_while()
assert res == (11, 20)

View File

@@ -118,7 +118,11 @@ def test_while_after_for_in_if_3():
assert (res.asnumpy() == [-3, -4]).all()
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_after_for_in_if_4():
"""
Feature: JIT Fallback

View File

@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
""" test graph fallback control flow if after for scenario"""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
@@ -79,7 +78,6 @@ def test_if_after_for_tensor_3():
assert res == 20
@pytest.mark.skip(reason="Currently, a can not be parsed in if statement.")
def test_if_after_for_tensor_zip():
"""
Feature: JIT Fallback
@@ -90,6 +88,7 @@ def test_if_after_for_tensor_zip():
def control_flow_for():
tuple_x = (Tensor(1), Tensor(3), Tensor(5))
sum_x = Tensor(0)
a = Tensor(0)
for x in zip(tuple_x):
sum_x += x
a = x

View File

@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
@@ -69,32 +68,6 @@ def test_while_after_while_in_while_tensor_2():
assert res == (2, -1, -3)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_while_after_while_in_while_numpy():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_while_after_while_in_while():
x = Tensor([-1])
y = Tensor([-2])
while abs(x) <= abs(y):
z = np.array([3, 4, 5])
index = 0
z_sum = 0
while index < len(z):
z_sum += z[index]
index += 1
x = x + Tensor(z_sum)
while y < x:
y += x
return x, y
res = control_flow_while_after_while_in_while()
assert res == (11, 20)
def test_while_after_while_in_while_numpy_2():
"""
Feature: JIT Fallback

View File

@@ -54,7 +54,7 @@ def test_for_after_if_4():
x = x + min(x, y)
z = (Tensor(1), Tensor(2), Tensor(3))
for i in zip(z):
for i in z:
x = x * i
return x

View File

@@ -70,7 +70,6 @@ def test_for_after_if_in_if_tensor_2():
assert res == 19
@pytest.mark.skip(reason='Failed to find parent context.')
def test_for_after_if_in_if_numpy():
"""
Feature: JIT Fallback

View File

@@ -43,7 +43,6 @@ def test_for_after_for_in_if_1():
assert res == 8
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_for_after_for_in_if_2():
"""
Feature: JIT Fallback
@@ -64,7 +63,7 @@ def test_for_after_for_in_if_2():
for i in range(3):
y = y + i
return x, y
return Tensor(x), Tensor(y)
res_x, res_y = func3302()
assert res_x == 4