graph fallback control flow ForceToBoolNode change

This commit is contained in:
liangzhibo 2022-04-26 19:10:21 +08:00
parent 517e7ec1fe
commit 2fc92bba08
5 changed files with 159 additions and 16 deletions

View File

@ -140,6 +140,7 @@
"mindspore/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice_grad.py" "redefined-outer-name"
"mindspore/tests/st/pynative/parser/test_parser_construct.py" "bad-super-call"
"mindspore/tests/ut/python/optimizer/test_auto_grad.py" "broad-except"
"mindspore/tests/st/fallback/control_flow/test_fallback_100_if_after_if.py" "unused-variable"
#MindSpore Lite
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "redefined-builtin"

View File

@ -185,8 +185,6 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
}
}
}
// If information transform by phi, need remove the var in interpret dict in fallback feature.
EraseLocalPyParam(var_name);
}
func_graph()->add_parameter(phi_param);

View File

@ -124,30 +124,25 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
}
}
void EraseLocalPyParam(const std::string &name) {
auto key_iter = local_py_params_keys_.find(name);
auto value_iter = local_py_params_values_.find(name);
if (key_iter != local_py_params_keys_.end() && value_iter != local_py_params_values_.end()) {
MS_LOG(DEBUG) << "Erase '" << name << "' from local_py_params, the key node:" << key_iter->second->DebugString()
<< ", the value node:" << value_iter->second->DebugString();
local_py_params_keys_.erase(key_iter);
local_py_params_values_.erase(value_iter);
}
}
// Update local parameters from previous block.
void UpdateLocalPyParam(const std::map<std::string, AnfNodePtr> &keys, std::map<std::string, AnfNodePtr> values) {
if (keys.size() != values.size()) {
MS_LOG(EXCEPTION) << "keys size should be equal to values size.";
}
for (auto iter = keys.begin(); iter != keys.end(); ++iter) {
const std::string &cur_key_name = iter->first;
if (local_py_params_keys_.find(cur_key_name) == local_py_params_keys_.end()) {
auto key_iter = local_py_params_keys_.find(cur_key_name);
if (key_iter == local_py_params_keys_.end()) {
(void)local_py_params_keys_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, iter->second));
(void)local_py_params_values_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, values[cur_key_name]));
MS_LOG(DEBUG) << "Add '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
} else {
MS_LOG(DEBUG) << "Update '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
local_py_params_values_[cur_key_name] = values[cur_key_name];
// The local variable is already in the current block. This means the current block has multiples previous
// blocks. If this local variable is used in the current block, it should be converted to phi node. So we erase
// it from local_py_params.
(void)local_py_params_keys_.erase(key_iter);
(void)local_py_params_values_.erase(cur_key_name);
MS_LOG(DEBUG) << "Erase '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
}
}
if (local_py_params_keys_.size() != local_py_params_values_.size()) {

View File

@ -263,3 +263,26 @@ def test_single_if_change_variable_value():
return Tensor(0)
res = control_flow_if()
assert np.all(res.asnumpy() == np.array([4, 5, 6, 7]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_if_np_all():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_if():
x = np.array([1, 2, 3, 4])
y = np.array([4, 5, 6])
if np.all(x == np.array([1, 2, 3, 4])) and np.any(y == np.array([4, 4, 4])):
x += 3
return Tensor(x)
return Tensor(0)
res = control_flow_if()
assert np.all(res.asnumpy() == np.array([4, 5, 6, 7]))

View File

@ -0,0 +1,126 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback control flow if after if scenario"""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_tensor():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_if_in_while():
x = Tensor(1)
y = Tensor(0)
if x < Tensor(5):
y += Tensor(4)
if y > Tensor(3):
x += Tensor(3)
return x + y
res = control_flow_if_in_while()
assert res == 8
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_tensor_2():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_if_in_while():
x = Tensor(1)
y = Tensor(0)
if x > Tensor(5):
y -= Tensor(4)
elif x > Tensor(2):
y -= Tensor(1)
else:
y += Tensor(4)
if y < Tensor(3):
x += Tensor(3)
else:
x += y
return x + y
res = control_flow_if_in_while()
assert res == 9
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_tensor_3():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_if_in_while(a):
if a > 15:
y = Tensor(1)
else:
y = Tensor(2)
if a == Tensor(10):
a = Tensor(11)
return a
res = control_flow_if_in_while(Tensor(10))
assert res == 11
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_numpy():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_if_in_while():
x = np.array([1, 2, 3, 4])
a = sum(x)
if a > 15:
y = np.array([1, 2, 3, 4])
else:
y = np.array([4, 5, 6])
if np.all(y == np.array([1, 2, 3, 4])):
ret = Tensor(1)
else:
ret = Tensor(2)
return ret
res = control_flow_if_in_while()
assert res == 2