forked from mindspore-Ecosystem/mindspore
graph fallback control flow ForceToBoolNode change
This commit is contained in:
parent
517e7ec1fe
commit
2fc92bba08
|
@ -140,6 +140,7 @@
|
|||
"mindspore/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice_grad.py" "redefined-outer-name"
|
||||
"mindspore/tests/st/pynative/parser/test_parser_construct.py" "bad-super-call"
|
||||
"mindspore/tests/ut/python/optimizer/test_auto_grad.py" "broad-except"
|
||||
"mindspore/tests/st/fallback/control_flow/test_fallback_100_if_after_if.py" "unused-variable"
|
||||
|
||||
#MindSpore Lite
|
||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "redefined-builtin"
|
||||
|
|
|
@ -185,8 +185,6 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
|
|||
}
|
||||
}
|
||||
}
|
||||
// If information transform by phi, need remove the var in interpret dict in fallback feature.
|
||||
EraseLocalPyParam(var_name);
|
||||
}
|
||||
|
||||
func_graph()->add_parameter(phi_param);
|
||||
|
|
|
@ -124,30 +124,25 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
}
|
||||
}
|
||||
|
||||
void EraseLocalPyParam(const std::string &name) {
|
||||
auto key_iter = local_py_params_keys_.find(name);
|
||||
auto value_iter = local_py_params_values_.find(name);
|
||||
if (key_iter != local_py_params_keys_.end() && value_iter != local_py_params_values_.end()) {
|
||||
MS_LOG(DEBUG) << "Erase '" << name << "' from local_py_params, the key node:" << key_iter->second->DebugString()
|
||||
<< ", the value node:" << value_iter->second->DebugString();
|
||||
local_py_params_keys_.erase(key_iter);
|
||||
local_py_params_values_.erase(value_iter);
|
||||
}
|
||||
}
|
||||
|
||||
// Update local parameters from previous block.
|
||||
void UpdateLocalPyParam(const std::map<std::string, AnfNodePtr> &keys, std::map<std::string, AnfNodePtr> values) {
|
||||
if (keys.size() != values.size()) {
|
||||
MS_LOG(EXCEPTION) << "keys size should be equal to values size.";
|
||||
}
|
||||
for (auto iter = keys.begin(); iter != keys.end(); ++iter) {
|
||||
const std::string &cur_key_name = iter->first;
|
||||
if (local_py_params_keys_.find(cur_key_name) == local_py_params_keys_.end()) {
|
||||
auto key_iter = local_py_params_keys_.find(cur_key_name);
|
||||
if (key_iter == local_py_params_keys_.end()) {
|
||||
(void)local_py_params_keys_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, iter->second));
|
||||
(void)local_py_params_values_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, values[cur_key_name]));
|
||||
MS_LOG(DEBUG) << "Add '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Update '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
|
||||
local_py_params_values_[cur_key_name] = values[cur_key_name];
|
||||
// The local variable is already in the current block. This means the current block has multiples previous
|
||||
// blocks. If this local variable is used in the current block, it should be converted to phi node. So we erase
|
||||
// it from local_py_params.
|
||||
(void)local_py_params_keys_.erase(key_iter);
|
||||
(void)local_py_params_values_.erase(cur_key_name);
|
||||
MS_LOG(DEBUG) << "Erase '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
|
||||
}
|
||||
}
|
||||
if (local_py_params_keys_.size() != local_py_params_values_.size()) {
|
||||
|
|
|
@ -263,3 +263,26 @@ def test_single_if_change_variable_value():
|
|||
return Tensor(0)
|
||||
res = control_flow_if()
|
||||
assert np.all(res.asnumpy() == np.array([4, 5, 6, 7]))
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_if_np_all():
    """
    Feature: JIT Fallback
    Description: Test fallback with control flow.
    Expectation: No exception.
    """
    @ms_function
    def control_flow_if():
        a = np.array([1, 2, 3, 4])
        b = np.array([4, 5, 6])
        # Both interpreted conditions hold: `a` matches element-wise and `b`
        # contains a 4, so the branch that adds 3 to every element is taken.
        if np.all(a == np.array([1, 2, 3, 4])) and np.any(b == np.array([4, 4, 4])):
            a += 3
            return Tensor(a)
        return Tensor(0)
    res = control_flow_if()
    assert np.all(res.asnumpy() == np.array([4, 5, 6, 7]))
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback control flow if after if scenario"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_tensor():
    """
    Feature: JIT Fallback
    Description: Test fallback with control flow.
    Expectation: No exception.
    """
    @ms_function
    def control_flow_if_after_if():
        first = Tensor(1)
        second = Tensor(0)
        # 1 < 5, so `second` becomes 4.
        if first < Tensor(5):
            second += Tensor(4)
        # 4 > 3, so `first` becomes 4; the result is 4 + 4 = 8.
        if second > Tensor(3):
            first += Tensor(3)
        return first + second
    res = control_flow_if_after_if()
    assert res == 8
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_tensor_2():
    """
    Feature: JIT Fallback
    Description: Test fallback with control flow.
    Expectation: No exception.
    """
    @ms_function
    def control_flow_if_after_if():
        a = Tensor(1)
        b = Tensor(0)
        # Neither 1 > 5 nor 1 > 2 holds, so the else branch sets b to 4.
        if a > Tensor(5):
            b -= Tensor(4)
        elif a > Tensor(2):
            b -= Tensor(1)
        else:
            b += Tensor(4)
        # 4 < 3 is false, so a becomes 1 + 4 = 5; the result is 5 + 4 = 9.
        if b < Tensor(3):
            a += Tensor(3)
        else:
            a += b
        return a + b
    res = control_flow_if_after_if()
    assert res == 9
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_tensor_3():
    """
    Feature: JIT Fallback
    Description: Test fallback with control flow.
    Expectation: No exception.
    """
    @ms_function
    def control_flow_if_after_if(x):
        # `y` is deliberately assigned but never used: the test exercises a
        # branch whose result is dead after the second if.
        if x > 15:
            y = Tensor(1)
        else:
            y = Tensor(2)
        # The input equals Tensor(10), so it is replaced by Tensor(11).
        if x == Tensor(10):
            x = Tensor(11)
        return x
    res = control_flow_if_after_if(Tensor(10))
    assert res == 11
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_numpy():
    """
    Feature: JIT Fallback
    Description: Test fallback with control flow.
    Expectation: No exception.
    """
    @ms_function
    def control_flow_if_after_if():
        data = np.array([1, 2, 3, 4])
        total = sum(data)
        # total is 10, so the else branch selects the 3-element array.
        if total > 15:
            selected = np.array([1, 2, 3, 4])
        else:
            selected = np.array([4, 5, 6])
        # `selected` has a different shape from the compared array, so the
        # all-equal check cannot hold and Tensor(2) is returned.
        if np.all(selected == np.array([1, 2, 3, 4])):
            ret = Tensor(1)
        else:
            ret = Tensor(2)
        return ret
    res = control_flow_if_after_if()
    assert res == 2
|
Loading…
Reference in New Issue