Handle string which is not JoinedStr in the raise statement.

This commit is contained in:
Margaret_wangrui 2022-03-30 15:36:28 +08:00
parent 95f02f851e
commit fb4fe5813e
2 changed files with 59 additions and 21 deletions

View File

@ -774,13 +774,10 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n
auto new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
UpdateInterpretForUserNode(new_node, {left_node, right_node});
// Handling % symbol in formatted string values by JIT Fallback.
// For example, string % var, "The string is: %s." % str or "The number is: %d." % num
if (IsPrimitiveCNode(left_node, prim::kPrimMakeTuple)) {
auto inputs = left_node->cast<CNodePtr>()->inputs();
if (inputs.size() <= 1) {
MS_LOG(EXCEPTION) << "Unexpected maketuple node:" << left_node->DebugString();
}
auto str_node = inputs[1];
// The string AnfNode may be created by ParseJoinedStr or ParseStr.
// For example, string % var, f"The string is: %s." % str or "The number is: %d." % num
static const auto use_fallback = (support_fallback() != "0");
if (use_fallback) {
auto op_cnode = op_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(op_cnode);
const size_t symbol_index = 2;
@ -788,12 +785,28 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n
MS_LOG(EXCEPTION) << "Unexpected symbol node:" << op_node->DebugString();
}
auto mod_node = op_cnode->input(symbol_index);
if (IsValueNode<StringImm>(str_node) && IsValueNode<Symbol>(mod_node)) {
new_node->set_interpret(true);
auto new_interpret_node = HandleInterpret(block, new_node, node);
return new_interpret_node;
if (IsValueNode<Symbol>(mod_node)) {
if (IsPrimitiveCNode(left_node, prim::kPrimMakeTuple)) {
// left_node created by ParseJoinedStr
auto inputs = left_node->cast<CNodePtr>()->inputs();
if (inputs.size() <= 1) {
MS_LOG(EXCEPTION) << "Unexpected maketuple node:" << left_node->DebugString();
}
auto str_node = inputs[1];
if (IsValueNode<StringImm>(str_node)) {
new_node->set_interpret(true);
auto new_interpret_node = HandleInterpret(block, new_node, node);
return new_interpret_node;
}
} else if (IsValueNode<StringImm>(left_node)) {
// left_node created by ParseStr
new_node->set_interpret(true);
auto new_interpret_node = HandleInterpret(block, new_node, node);
return new_interpret_node;
}
}
}
return new_node;
}

View File

@ -185,11 +185,11 @@ def test_raise_7():
x = [1, 3, 5, 7, 9]
raise ValueError("Not expected value, x is {}".format(x))
with pytest.raises(ValueError) as info:
with pytest.raises(ValueError) as raise_info_7:
net = RaiseNet()
res = net()
print("res:", res)
assert "Not expected value, x is [1, 3, 5, 7, 9]" in str(info.value)
assert "Not expected value, x is [1, 3, 5, 7, 9]" in str(raise_info_7.value)
@pytest.mark.level0
@ -215,11 +215,11 @@ def test_raise_8():
return 3
raise ValueError("Not expected value, x is {}".format(self.x))
with pytest.raises(ValueError) as info:
with pytest.raises(ValueError) as raise_info_8:
net = RaiseNet()
res = net()
print("res:", res)
assert "Not expected value, x is [1, 3, 5, 7]" in str(info.value)
assert "Not expected value, x is [1, 3, 5, 7]" in str(raise_info_8.value)
@pytest.mark.level0
@ -238,11 +238,11 @@ def test_raise_9():
x = 11
raise ValueError(f"The input can not be {x}.")
with pytest.raises(ValueError) as info:
with pytest.raises(ValueError) as raise_info_9:
net = RaiseNet()
res = net()
print("res:", res)
assert "The input can not be 11." in str(info.value)
assert "The input can not be 11." in str(raise_info_9.value)
@pytest.mark.level0
@ -260,11 +260,11 @@ def test_raise_10():
def construct(self, x):
raise ValueError(f"The input can not be %s." % x)
with pytest.raises(ValueError) as info:
with pytest.raises(ValueError) as raise_info_10:
net = RaiseNet()
res = net(11)
print("res:", res)
assert "The input can not be 11." in str(info.value)
assert "The input can not be 11." in str(raise_info_10.value)
@pytest.mark.level0
@ -282,8 +282,33 @@ def test_raise_11():
def construct(self, x):
raise ValueError(f"The input can not be ", x, ".")
with pytest.raises(ValueError) as info:
with pytest.raises(ValueError) as raise_info_11:
net = RaiseNet()
res = net(11)
print("res:", res)
assert "The input can not be 11." in str(info.value)
assert "The input can not be 11." in str(raise_info_11.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_12():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise(string % var).
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = 1
if x == 1:
raise ValueError("The var name is %s, it can not be %d." % ("x", x))
return x
with pytest.raises(ValueError) as raise_info_12:
net = RaiseNet()
res = net()
print("res:", res)
assert "The var name is x, it can not be 1." in str(raise_info_12.value)