forked from mindspore-Ecosystem/mindspore
Handle string which is not JoinedStr in the raise statement.
This commit is contained in:
parent
95f02f851e
commit
fb4fe5813e
|
@ -774,13 +774,10 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n
|
|||
auto new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
|
||||
UpdateInterpretForUserNode(new_node, {left_node, right_node});
|
||||
// Handling % symbol in formatted string values by JIT Fallback.
|
||||
// For example, string % var, "The string is: %s." % str or "The number is: %d." % num
|
||||
if (IsPrimitiveCNode(left_node, prim::kPrimMakeTuple)) {
|
||||
auto inputs = left_node->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() <= 1) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected maketuple node:" << left_node->DebugString();
|
||||
}
|
||||
auto str_node = inputs[1];
|
||||
// The string AnfNode may be created by ParseJoinedStr or ParseStr.
|
||||
// For example, string % var, f"The string is: %s." % str or "The number is: %d." % num
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
if (use_fallback) {
|
||||
auto op_cnode = op_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(op_cnode);
|
||||
const size_t symbol_index = 2;
|
||||
|
@ -788,12 +785,28 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n
|
|||
MS_LOG(EXCEPTION) << "Unexpected symbol node:" << op_node->DebugString();
|
||||
}
|
||||
auto mod_node = op_cnode->input(symbol_index);
|
||||
if (IsValueNode<StringImm>(str_node) && IsValueNode<Symbol>(mod_node)) {
|
||||
new_node->set_interpret(true);
|
||||
auto new_interpret_node = HandleInterpret(block, new_node, node);
|
||||
return new_interpret_node;
|
||||
if (IsValueNode<Symbol>(mod_node)) {
|
||||
if (IsPrimitiveCNode(left_node, prim::kPrimMakeTuple)) {
|
||||
// left_node created by ParseJoinedStr
|
||||
auto inputs = left_node->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() <= 1) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected maketuple node:" << left_node->DebugString();
|
||||
}
|
||||
auto str_node = inputs[1];
|
||||
if (IsValueNode<StringImm>(str_node)) {
|
||||
new_node->set_interpret(true);
|
||||
auto new_interpret_node = HandleInterpret(block, new_node, node);
|
||||
return new_interpret_node;
|
||||
}
|
||||
} else if (IsValueNode<StringImm>(left_node)) {
|
||||
// left_node created by ParseStr
|
||||
new_node->set_interpret(true);
|
||||
auto new_interpret_node = HandleInterpret(block, new_node, node);
|
||||
return new_interpret_node;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new_node;
|
||||
}
|
||||
|
||||
|
|
|
@ -185,11 +185,11 @@ def test_raise_7():
|
|||
x = [1, 3, 5, 7, 9]
|
||||
raise ValueError("Not expected value, x is {}".format(x))
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
with pytest.raises(ValueError) as raise_info_7:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "Not expected value, x is [1, 3, 5, 7, 9]" in str(info.value)
|
||||
assert "Not expected value, x is [1, 3, 5, 7, 9]" in str(raise_info_7.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -215,11 +215,11 @@ def test_raise_8():
|
|||
return 3
|
||||
raise ValueError("Not expected value, x is {}".format(self.x))
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
with pytest.raises(ValueError) as raise_info_8:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "Not expected value, x is [1, 3, 5, 7]" in str(info.value)
|
||||
assert "Not expected value, x is [1, 3, 5, 7]" in str(raise_info_8.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -238,11 +238,11 @@ def test_raise_9():
|
|||
x = 11
|
||||
raise ValueError(f"The input can not be {x}.")
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
with pytest.raises(ValueError) as raise_info_9:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "The input can not be 11." in str(info.value)
|
||||
assert "The input can not be 11." in str(raise_info_9.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -260,11 +260,11 @@ def test_raise_10():
|
|||
def construct(self, x):
|
||||
raise ValueError(f"The input can not be %s." % x)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
with pytest.raises(ValueError) as raise_info_10:
|
||||
net = RaiseNet()
|
||||
res = net(11)
|
||||
print("res:", res)
|
||||
assert "The input can not be 11." in str(info.value)
|
||||
assert "The input can not be 11." in str(raise_info_10.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -282,8 +282,33 @@ def test_raise_11():
|
|||
def construct(self, x):
|
||||
raise ValueError(f"The input can not be ", x, ".")
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
with pytest.raises(ValueError) as raise_info_11:
|
||||
net = RaiseNet()
|
||||
res = net(11)
|
||||
print("res:", res)
|
||||
assert "The input can not be 11." in str(info.value)
|
||||
assert "The input can not be 11." in str(raise_info_11.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_12():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise(string % var).
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = 1
|
||||
if x == 1:
|
||||
raise ValueError("The var name is %s, it can not be %d." % ("x", x))
|
||||
return x
|
||||
|
||||
with pytest.raises(ValueError) as raise_info_12:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "The var name is x, it can not be 1." in str(raise_info_12.value)
|
||||
|
|
Loading…
Reference in New Issue