forked from mindspore-Ecosystem/mindspore
!35454 Change the logic of inherit local python parameter between block.
Merge pull request !35454 from LiangZhibo/master
This commit is contained in:
commit
ef555956f2
|
@ -21,6 +21,8 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
#include <queue>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
|
@ -125,6 +127,34 @@ AnfNodePtr FunctionBlock::ReadLocalVariable(const std::string &var_name) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr FunctionBlock::FindPredInterpretNode(const std::string &var_name) {
|
||||
// Search the predecessors of the current block for the local parameter. If one of the local parameter of the
|
||||
// predecessors is interpret node, the phi_param needs to set the interpret true.
|
||||
std::unordered_set<FunctionBlock *> visited_block;
|
||||
std::queue<FunctionBlock *> block_queue;
|
||||
block_queue.push(this);
|
||||
while (!block_queue.empty()) {
|
||||
const auto &cur_block = block_queue.front();
|
||||
block_queue.pop();
|
||||
visited_block.insert(cur_block);
|
||||
auto pred_node = cur_block->ReadLocalVariable(var_name);
|
||||
if (pred_node != nullptr) {
|
||||
bool interpret_without_internal =
|
||||
IsPrimitiveCNode(pred_node, prim::kPrimPyInterpret) && !pred_node->interpret_internal_type();
|
||||
if (pred_node->interpret() || interpret_without_internal) {
|
||||
return pred_node;
|
||||
}
|
||||
} else {
|
||||
for (const auto &cur_pred_block : cur_block->prev_blocks()) {
|
||||
if (visited_block.count(cur_pred_block) == 0) {
|
||||
block_queue.push(cur_pred_block);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Read variable from predecessors
|
||||
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
|
||||
MS_LOG(DEBUG) << "Read begin, var: " << var_name << ", block: " << ToString();
|
||||
|
@ -175,24 +205,14 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
|
|||
ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
|
||||
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
|
||||
<< phi_param->ToString() << " for " << var_name;
|
||||
|
||||
if (use_fallback) {
|
||||
// Check the phi whether need interpret flag.
|
||||
// If has Interpret node which name is var_name in prev_blocks_, means the phi need set interpret true.
|
||||
for (auto &pred : prev_blocks_) {
|
||||
MS_EXCEPTION_IF_NULL(pred);
|
||||
auto iter = pred->local_py_params_values_.find(var_name);
|
||||
if (iter != pred->local_py_params_values_.end()) {
|
||||
auto pred_node = iter->second;
|
||||
if (pred_node->interpret_special_type()) {
|
||||
phi_param->set_interpret_special_type(true);
|
||||
}
|
||||
bool interpret_without_internal =
|
||||
IsPrimitiveCNode(pred_node, prim::kPrimPyInterpret) && !pred_node->interpret_internal_type();
|
||||
if (pred_node->interpret() || interpret_without_internal) {
|
||||
phi_param->set_interpret(true);
|
||||
break;
|
||||
}
|
||||
const auto &pred_node = FindPredInterpretNode(var_name);
|
||||
if (pred_node != nullptr) {
|
||||
if (pred_node->interpret_special_type()) {
|
||||
phi_param->set_interpret_special_type(true);
|
||||
}
|
||||
phi_param->set_interpret(true);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -203,6 +203,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
bool is_dead_block_{false};
|
||||
|
||||
AnfNodePtr ReadLocalVariable(const std::string &var_name);
|
||||
AnfNodePtr FindPredInterpretNode(const std::string &var_name);
|
||||
};
|
||||
} // namespace parse
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -822,12 +822,6 @@ LocationPtr Parser::GetLocation(const py::object &node) const {
|
|||
return location;
|
||||
}
|
||||
|
||||
void Parser::UpdateBlockPyParams(const FunctionBlockPtr &block, const FunctionBlockPtr &pre_block) {
|
||||
block->UpdateGlobalPyParam(pre_block->global_py_params());
|
||||
const auto &[keys, values] = pre_block->local_py_params();
|
||||
block->UpdateLocalPyParam(keys, values);
|
||||
}
|
||||
|
||||
void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
|
||||
const FunctionBlockPtr &false_block) {
|
||||
MS_EXCEPTION_IF_NULL(true_block);
|
||||
|
@ -840,8 +834,8 @@ void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const Functi
|
|||
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(true_block, pre_block);
|
||||
UpdateBlockPyParams(false_block, pre_block);
|
||||
true_block->UpdateGlobalPyParam(pre_block->global_py_params());
|
||||
false_block->UpdateGlobalPyParam(pre_block->global_py_params());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1772,7 +1766,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
<< true_block->ToString() << ", true_end: " << true_end->ToString();
|
||||
}
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(after_block, true_end);
|
||||
after_block->UpdateGlobalPyParam(true_end->global_py_params());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1794,7 +1788,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
<< false_block->ToString() << ", false_end: " << false_end->ToString();
|
||||
}
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(after_block, false_end);
|
||||
after_block->UpdateGlobalPyParam(false_end->global_py_params());
|
||||
}
|
||||
}
|
||||
auto switch_app = block->ConditionalJump(bool_node, true_block, false_block);
|
||||
|
@ -1862,9 +1856,9 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
|
|||
py::object test_node = python_adapter::GetPyObjAttr(node, "test");
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(header_block, block);
|
||||
UpdateBlockPyParams(body_block, block);
|
||||
UpdateBlockPyParams(after_block, block);
|
||||
header_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
body_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
after_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
}
|
||||
AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
|
||||
condition_node = header_block->ForceToWhileCond(condition_node);
|
||||
|
@ -1955,8 +1949,8 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
|
|||
CNodePtr target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var});
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(header_block, block);
|
||||
UpdateBlockPyParams(body_block, block);
|
||||
header_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
body_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
}
|
||||
WriteAssignVars(body_block, target_node, target_var);
|
||||
|
||||
|
@ -1981,7 +1975,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
|
|||
block->Jump(header_block, {NewValueNode(static_cast<int64_t>(0))});
|
||||
body_block->Mature();
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(after_block, block);
|
||||
after_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
}
|
||||
|
||||
header_block->ConditionalJump(cond_node, body_block, after_block);
|
||||
|
@ -1991,7 +1985,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
|
|||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(after_body_block, block);
|
||||
after_body_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
}
|
||||
if (after_body_block->func_graph()->get_return() == nullptr) {
|
||||
after_body_block->Jump(header_block, {loop_var_inc});
|
||||
|
@ -2044,8 +2038,8 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
|
|||
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(header_block, block);
|
||||
UpdateBlockPyParams(body_block, block);
|
||||
header_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
body_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
}
|
||||
|
||||
// Get variable name of 'x' in statement 'for x in xs'
|
||||
|
@ -2073,7 +2067,7 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
|
|||
block->Jump(header_block, {iter_node, NewValueNode(static_cast<int64_t>(0))});
|
||||
body_block->Mature();
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(after_block, block);
|
||||
after_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
}
|
||||
header_block->ConditionalJump(cond_node, body_block, after_block);
|
||||
|
||||
|
@ -2087,7 +2081,7 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
|
|||
body_block->Jump(rolled_body_block, {});
|
||||
auto rolled_body_call = dyn_cast<CNode>(body_block->func_graph()->output());
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(rolled_body_block, block);
|
||||
rolled_body_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
}
|
||||
|
||||
// Parse loop body statements with loop context.
|
||||
|
@ -2095,7 +2089,7 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
|
|||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||
FunctionBlockPtr after_body_block = ParseStatements(rolled_body_block, body_node);
|
||||
if (use_fallback) {
|
||||
UpdateBlockPyParams(after_body_block, block);
|
||||
after_body_block->UpdateGlobalPyParam(block->global_py_params());
|
||||
}
|
||||
MS_LOG(DEBUG) << "Finish rolled block, after_body_block: " << after_body_block->ToString()
|
||||
<< ", rolled_body_block: " << rolled_body_block->ToString();
|
||||
|
|
|
@ -313,8 +313,6 @@ class Parser {
|
|||
// Return a make tuple for input elements list
|
||||
AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
|
||||
|
||||
void UpdateBlockPyParams(const FunctionBlockPtr &block, const FunctionBlockPtr &pre_block);
|
||||
|
||||
// The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
|
||||
static FuncGraphWeakPtr top_func_graph_;
|
||||
// Python function id, used to indicate whether two CNodes come from the same Python function
|
||||
|
|
|
@ -102,7 +102,11 @@ def test_while_after_if_numpy():
|
|||
assert (res.asnumpy() == [-3, -4]).all()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_after_if_numpy_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
|
@ -53,31 +53,11 @@ def test_while_after_if_in_while_tensor():
|
|||
assert res == 33
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_while_after_if_in_while_numpy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def control_flow_while_after_if_in_while():
|
||||
x = np.array([1])
|
||||
y = np.array([10])
|
||||
while Tensor(x) < Tensor(y):
|
||||
z = Tensor([-2])
|
||||
if Tensor(x) < z:
|
||||
y = 0
|
||||
else:
|
||||
x = y + x
|
||||
while Tensor(y[0]) < Tensor(x[0]):
|
||||
y += x
|
||||
return Tensor(y)
|
||||
res = control_flow_while_after_if_in_while()
|
||||
assert res == 21
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_after_if_in_while_numpy_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -97,7 +77,7 @@ def test_while_after_if_in_while_numpy_2():
|
|||
else:
|
||||
y = x + y
|
||||
while y[0] > x[0]:
|
||||
y[0] -= x[0]
|
||||
y -= x[0]
|
||||
return Tensor(x), Tensor(y)
|
||||
res_x, res_y = control_flow_while_after_if_in_while()
|
||||
assert res_x == 1
|
||||
|
|
|
@ -83,30 +83,6 @@ def test_while_after_if_in_for_tensor_2():
|
|||
assert res_y == 5
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_while_after_if_in_for_numpy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def control_flow_while_after_if_in_for():
|
||||
x = np.array([1])
|
||||
y = np.array([10])
|
||||
for _ in range(3):
|
||||
z = Tensor([-2])
|
||||
if Tensor(x) < z:
|
||||
y = 0
|
||||
else:
|
||||
x = y - x
|
||||
while Tensor(y[0]) > Tensor(x[0]):
|
||||
y -= x
|
||||
return Tensor(y)
|
||||
res = control_flow_while_after_if_in_for()
|
||||
assert res == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_while_after_if_in_for_numpy_2():
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_while_after_while_in_while_numpy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def control_flow_while_after_while_in_while():
|
||||
x = Tensor([-1])
|
||||
y = Tensor([-2])
|
||||
while abs(x) <= abs(y):
|
||||
z = np.array([3, 4, 5])
|
||||
index = 0
|
||||
z_sum = 0
|
||||
while index < 3:
|
||||
z_sum += z[index]
|
||||
index += 1
|
||||
x = x + Tensor(z_sum)
|
||||
while y < x:
|
||||
y += x
|
||||
return x, y
|
||||
res = control_flow_while_after_while_in_while()
|
||||
assert res == (11, 20)
|
|
@ -118,7 +118,11 @@ def test_while_after_for_in_if_3():
|
|||
assert (res.asnumpy() == [-3, -4]).all()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_after_for_in_if_4():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback control flow if after for scenario"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
|
@ -79,7 +78,6 @@ def test_if_after_for_tensor_3():
|
|||
assert res == 20
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Currently, a can not be parsed in if statement.")
|
||||
def test_if_after_for_tensor_zip():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -90,6 +88,7 @@ def test_if_after_for_tensor_zip():
|
|||
def control_flow_for():
|
||||
tuple_x = (Tensor(1), Tensor(3), Tensor(5))
|
||||
sum_x = Tensor(0)
|
||||
a = Tensor(0)
|
||||
for x in zip(tuple_x):
|
||||
sum_x += x
|
||||
a = x
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
|
@ -69,32 +68,6 @@ def test_while_after_while_in_while_tensor_2():
|
|||
assert res == (2, -1, -3)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_while_after_while_in_while_numpy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def control_flow_while_after_while_in_while():
|
||||
x = Tensor([-1])
|
||||
y = Tensor([-2])
|
||||
while abs(x) <= abs(y):
|
||||
z = np.array([3, 4, 5])
|
||||
index = 0
|
||||
z_sum = 0
|
||||
while index < len(z):
|
||||
z_sum += z[index]
|
||||
index += 1
|
||||
x = x + Tensor(z_sum)
|
||||
while y < x:
|
||||
y += x
|
||||
return x, y
|
||||
res = control_flow_while_after_while_in_while()
|
||||
assert res == (11, 20)
|
||||
|
||||
|
||||
def test_while_after_while_in_while_numpy_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
|
@ -54,7 +54,7 @@ def test_for_after_if_4():
|
|||
x = x + min(x, y)
|
||||
|
||||
z = (Tensor(1), Tensor(2), Tensor(3))
|
||||
for i in zip(z):
|
||||
for i in z:
|
||||
x = x * i
|
||||
return x
|
||||
|
||||
|
|
|
@ -70,7 +70,6 @@ def test_for_after_if_in_if_tensor_2():
|
|||
assert res == 19
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Failed to find parent context.')
|
||||
def test_for_after_if_in_if_numpy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
|
@ -43,7 +43,6 @@ def test_for_after_for_in_if_1():
|
|||
assert res == 8
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_for_after_for_in_if_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -64,7 +63,7 @@ def test_for_after_for_in_if_2():
|
|||
for i in range(3):
|
||||
y = y + i
|
||||
|
||||
return x, y
|
||||
return Tensor(x), Tensor(y)
|
||||
|
||||
res_x, res_y = func3302()
|
||||
assert res_x == 4
|
||||
|
|
Loading…
Reference in New Issue