!49360 支持raise变量控制流场景

Merge pull request !49360 from 李良灿/fixraise
This commit is contained in:
i-robot 2023-02-26 14:56:54 +00:00 committed by Gitee
commit 311d1be98d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 169 additions and 114 deletions

View File

@ -2064,6 +2064,14 @@ EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abst
// Call python script string.
MS_LOG(DEBUG) << "Call script: " << script << ", args: " << args_abs_list;
// when return value should be none
if (current_interpret_node->has_user_data("__py_execute_no_return_type__")) {
AbstractBasePtr res = std::make_shared<abstract::AbstractNone>();
res->set_value(kAnyValue);
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
}
TypePtr type = kFloat64;
if (current_interpret_node->has_user_data("__py_execute_tensor_type__")) {
type = current_interpret_node->user_data<Type>("__py_execute_tensor_type__");
@ -2739,12 +2747,6 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
MS_EXCEPTION_IF_NULL(node);
auto cur_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(cur_graph);
if (cur_graph->is_tensor_condition_branch()) {
MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios. "
<< "Tensor type data cannot exist in the conditional statement. "
<< "Please check your conditions which raise node is located at: "
<< trace::GetDebugInfo(node->debug_info());
}
if (args_abs_list.empty()) {
// process raise
MS_LOG(EXCEPTION) << "No active exception to reraise.";
@ -2772,7 +2774,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
const auto input = inputs[index];
auto input_abs = args_abs_list[index - 1];
MS_EXCEPTION_IF_NULL(input_abs);
bool need_symbol = CheckNeedSymbol(input, input_abs);
const bool need_symbol = CheckNeedSymbol(input, input_abs);
if (need_symbol) {
exception_string += "'";
}
@ -2807,6 +2809,8 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
// Build the PyExecute node for raise error.
const auto raise_error_node = cur_graph->NewCNodeInOrder(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
auto none_type = std::make_shared<TypeNone>();
raise_error_node->set_user_data<Type>("__py_execute_no_return_type__", none_type);
cur_graph->ReplaceInOrder(node, raise_error_node);
AnalysisEnginePtr eng = out_conf->engine();
MS_EXCEPTION_IF_NULL(eng);
@ -2849,7 +2853,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
return need_symbol;
}
std::string GetExceptionString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node,
const bool need_comma = false, bool need_symbol = false) {
const bool need_comma = false, const bool need_symbol = false) {
std::string exception_str;
MS_EXCEPTION_IF_NULL(arg);
if (arg->isa<abstract::AbstractSequence>()) {
@ -2874,21 +2878,25 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
}
std::string GetTupleOrListString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node,
const bool need_comma, bool need_symbol = false) {
const bool need_comma, const bool need_symbol = false) {
MS_EXCEPTION_IF_NULL(arg);
std::string exception_str;
bool is_tuple = arg->isa<abstract::AbstractTuple>();
// Process raise ValueError("str")
auto arg_tuple = arg->cast_ptr<abstract::AbstractSequence>();
MS_EXCEPTION_IF_NULL(arg_tuple);
auto const &arg_tuple_elements = arg_tuple->elements();
const auto &arg_tuple_elements = arg_tuple->elements();
if (arg_tuple_elements.size() == 0) {
MS_LOG(EXCEPTION) << "The arg_tuple_elements can't be empty.";
}
if (!input->isa<CNode>()) {
std::string key = "__internal_error_value" + std::to_string(num_str_) + "__";
num_str_ += 1;
exception_str = exception_str + "{" + key + "}";
if (need_symbol) {
exception_str = exception_str + "'+f'{" + key + "}'+'";
} else {
exception_str = exception_str + key;
}
(void)keys_.emplace_back(NewValueNode(std::make_shared<StringImm>(key)));
(void)values_.emplace_back(input);
return exception_str;

View File

@ -101,6 +101,24 @@ bool CheckAbstractScalar(const AnfNodePtr &node) {
return false;
}
bool CheckIfRaise(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
auto first = inputs[1];
auto script_node = first->cast<ValueNodePtr>();
if (script_node->value()->isa<StringImm>()) {
auto script = GetValueNode<StringImmPtr>(script_node)->value();
std::string raise_script = "raise_func";
auto idx = script.find(raise_script);
if (idx != string::npos) {
return true;
}
}
}
return false;
}
void ValidateAbstract(const AnfNodePtr &node) {
if (node == nullptr) {
MS_LOG(DEBUG) << "Node to validate is invalid";
@ -123,6 +141,11 @@ void ValidateAbstract(const AnfNodePtr &node) {
MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString();
return;
}
if (CheckIfRaise(node)) {
ShapeVector shp{abstract::Shape::kShapeRankAny};
auto abs = std::make_shared<abstract::AbstractTensor>(kFloat64, std::make_shared<abstract::Shape>(shp));
node->set_abstract(abs);
}
bool is_legal_abstract = abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() ||
abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
""" test graph fallback control flow."""
import pytest
import numpy as np
import mindspore
from mindspore import Tensor, jit, context
@ -50,43 +49,6 @@ def test_for_after_while_in_if_1():
assert res == 8
@case_register.level1
@case_register.target_gpu
@case_register.target_ascend
def test_for_after_while_in_if_2():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@jit
def func3202():
x = Tensor([2])
y = Tensor([2])
if x == y and x > Tensor([0]).astype("int32"):
y = y + Tensor([3]).astype("int32")
z = Tensor([5], dtype=mindspore.int32)
while y == z and z == Tensor([5], dtype=mindspore.int32):
y = y * 1
assert y == z, "y must equal z"
z = z + 1
else:
raise ValueError("Enter this branch, not expect!")
for i in range(3):
z = Tensor([i]).astype("int32")
x = x + z
return x, y
with pytest.raises(RuntimeError, match="Currently only supports raise in constant scenarios."):
res_x, res_y = func3202()
assert res_x == 5
assert res_y == 5
@case_register.level1
@case_register.target_gpu
@case_register.target_ascend

View File

@ -168,3 +168,128 @@ def test_raise_with_variable_joinedstr_tensor():
print("res:", res)
assert "The input should not be 1" in str(
raise_info_joinedstr_tensor.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_raise_with_variable_dic():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = Tensor(1)
y = Tensor(2)
z = {"x": x, "y": y}
raise ValueError(z)
with pytest.raises(RuntimeError) as raise_info_list:
net = RaiseNet()
res = net()
print("res:", res)
assert "Dictionary type is currently not supporting" in str(
raise_info_list.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_raise_with_variable_control_flow1():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self, x, y):
if x == y:
raise RuntimeError(f"The input should not be {x}.")
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
net = RaiseNet()
x = Tensor(1)
y = Tensor(1)
res = net(x, y)
print("res:", res)
assert "The input should not be 1" in str(
raise_info_joinedstr_tensor.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_raise_with_variable_control_flow2():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self, x, y):
if x == y:
raise RuntimeError(f"The input should not be {x}.")
return x
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
net = RaiseNet()
x = Tensor(1)
y = Tensor(1)
res = net(x, y)
print("res:", res)
assert "The input should not be 1" in str(
raise_info_joinedstr_tensor.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_raise_with_variable_control_flow3():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self, x, y, z):
if x == y:
raise RuntimeError(f"The input should not be {x}.")
return z
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
net = RaiseNet()
x = Tensor(1)
y = Tensor(1)
z = (x, y)
res = net(x, y, z)
print("res:", res)
assert "The input should not be 1" in str(
raise_info_joinedstr_tensor.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_raise_with_variable_control_flow4():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self, x, y, z):
if x == y:
raise RuntimeError(f"The input should not be {x}.")
return z
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
net = RaiseNet()
x = Tensor(1)
y = Tensor(1)
z = [x, y]
res = net(x, y, z)
print("res:", res)
assert "The input should not be 1" in str(
raise_info_joinedstr_tensor.value)

View File

@ -14,7 +14,7 @@
# ============================================================================
""" test_assert """
import pytest
from mindspore import nn, context, Tensor
from mindspore import nn, context
context.set_context(mode=context.GRAPH_MODE)
@ -153,22 +153,3 @@ def test_assert7():
net()
assert "1 not in [2, 3, 4]" in str(excinfo.value)
assert "assert x in [2, 3, 4]" in str(excinfo.value)
def test_assert8():
"""
Feature: support assert
Description: test assert with variable in condition
Expectation: no error
"""
class Net(nn.Cell):
def construct(self, x):
assert x == 1
return x
net = Net()
a = Tensor(1)
with pytest.raises(RuntimeError) as excinfo:
net(a)
assert "Currently only supports raise in constant scenarios." in str(excinfo.value)

View File

@ -17,7 +17,6 @@ import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
import mindspore.common.dtype as mstype
from mindspore.common.api import _cell_graph_executor
context.set_context(mode=context.GRAPH_MODE)
@ -77,25 +76,6 @@ def test_raise_3():
print("res:", res)
def test_raise_4():
"""
Feature: graph raise.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self, x):
if x == 1:
raise ValueError(f"The input should not be 1.")
return x
with pytest.raises(RuntimeError, match="Currently only supports raise in constant scenarios."):
net = RaiseNet()
x = Tensor(9, mstype.int32)
res = net(x)
assert res == 9
def test_raise_5():
"""
Feature: graph raise.

View File

@ -92,6 +92,7 @@ def test_nest_function_missing_return():
assert "For 'make_range', the 0th input should be a int64 scalar" in str(er.value)
@pytest.mark.skip(reason='Case will not appear for now, but may appear in the future')
def test_raise_in_method():
class NetRaiseInMethod(nn.Cell):
def construct(self, x, y, z):
@ -112,31 +113,6 @@ def test_raise_in_method():
assert "Currently only supports raise in constant scenarios." in str(er.value)
def test_raise_in_nested_function():
class NetNestRaise(nn.Cell):
def nest_fn(self, u):
if u > 0:
# add not support grammar 'raise' here
raise ValueError('Illegal case')
return u + z + 1
def construct(self, x, y, z):
if x == 1:
return Tensor(10, mstype.int32)
elif x == 20:
return self.nest_fn(y)
else:
return y + z
net = NetNestRaise()
x = Tensor(0, mstype.int32)
y = Tensor(5, mstype.int32)
z = Tensor(2, mstype.int32)
with pytest.raises(RuntimeError) as er:
net(x, y, z)
assert "Currently only supports raise in constant scenarios." in str(er.value)
def test_nest_branch_with_return():
class NetBranchWithReturn(nn.Cell):
def construct(self, x, y, z):