!32081 [ME][Fallback] Support format string('%' symbol).

Merge pull request !32081 from Margaret_wangrui/raise_print_format
This commit is contained in:
i-robot 2022-03-29 01:27:16 +00:00 committed by Gitee
commit fd5c2c0503
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 110 additions and 27 deletions

View File

@ -761,6 +761,27 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n
MS_EXCEPTION_IF_NULL(block->func_graph());
auto new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
UpdateInterpretForUserNode(new_node, {left_node, right_node});
// Handling % symbol in formatted string values by JIT Fallback.
// For example, string % var, "The string is: %s." % str or "The number is: %d." % num
if (IsPrimitiveCNode(left_node, prim::kPrimMakeTuple)) {
auto inputs = left_node->cast<CNodePtr>()->inputs();
if (inputs.size() <= 1) {
MS_LOG(EXCEPTION) << "Unexpected maketuple node:" << left_node->DebugString();
}
auto str_node = inputs[1];
auto op_cnode = op_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(op_cnode);
const size_t symbol_index = 2;
if (op_cnode->inputs().size() <= symbol_index) {
MS_LOG(EXCEPTION) << "Unexpected symbol node:" << op_node->DebugString();
}
auto mod_node = op_cnode->input(symbol_index);
if (IsValueNode<StringImm>(str_node) && IsValueNode<Symbol>(mod_node)) {
new_node->set_interpret(true);
auto new_interpret_node = HandleInterpret(block, new_node, node);
return new_interpret_node;
}
}
return new_node;
}

View File

@ -60,6 +60,7 @@ def test_np_print_2():
net = PrintNet()
res = net()
print("res: ", res)
assert (res.asnumpy() == [1, 2, 3, 4, 5]).all()
@pytest.mark.level0
@ -81,28 +82,6 @@ def test_tensor_print_1():
assert np.all(np_print().asnumpy() == np.array([1, 2, 3, 4, 5]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tensor_print_2():
"""
Feature: JIT Fallback
Description: Support print.
Expectation: No exception.
"""
class PrintNet(nn.Cell):
def construct(self):
x = np.array([1, 2, 3, 4, 5])
print("Tensor(x): ", Tensor(x))
return Tensor(x)
net = PrintNet()
res = net()
print("res: ", res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@ -124,6 +103,7 @@ def test_print_cnode_1():
y = Tensor(np.array([1, 2, 3, 4, 5]))
res = print_func(x, y)
print("res: ", res)
assert (res.asnumpy() == [2, 4, 6, 8, 10]).all()
@pytest.mark.level0
@ -147,6 +127,7 @@ def test_print_cnode_2():
res = print_func()
print("res: ", res)
assert (res.asnumpy() == [2, 4, 6, 8, 10]).all()
@pytest.mark.level0
@ -170,6 +151,7 @@ def test_print_cnode_3():
res = print_func()
print("res: ", res)
assert (res.asnumpy() == [2, 4, 6, 8, 10]).all()
@pytest.mark.level0
@ -225,3 +207,71 @@ def test_print_validate():
res = print_func()
print("res: ", res)
assert "Should not use Python object in runtime" in str(err.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_print_format_np():
"""
Feature: JIT Fallback
Description: Support print.
Expectation: No exception.
"""
@ms_function
def print_func():
np_x = np.array([1, 2, 3, 4, 5])
np_y = np.array([1, 2, 3, 4, 5])
np_sum = np_x + np_y
print("np_sum: {}".format(np_sum))
return Tensor(np_sum)
res = print_func()
print("res: ", res)
assert (res.asnumpy() == [2, 4, 6, 8, 10]).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_print_format_tensor():
"""
Feature: JIT Fallback
Description: Support print.
Expectation: No exception.
"""
@ms_function
def print_func():
x = Tensor(np.array([1, 2, 3, 4, 5]))
y = Tensor(np.array([1, 2, 3, 4, 5]))
tensor_sum = x + y
print("tensor_sum: {}".format(tensor_sum))
return tensor_sum
res = print_func()
print("res: ", res)
assert (res.asnumpy() == [2, 4, 6, 8, 10]).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_print_string_format():
"""
Feature: JIT Fallback
Description: Support print(string % var).
Expectation: No exception.
"""
@ms_function
def print_func():
print("I'm %s. I'm %d years old." % ('MindSpore', 3))
return 0
res = print_func()
print("res: ", res)

View File

@ -169,7 +169,11 @@ def test_raise_6():
print("res:", res)
@pytest.mark.skip(reason='Not support graph raise feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_7():
"""
Feature: graph raise.
@ -188,7 +192,11 @@ def test_raise_7():
assert "Not expected value, x is [1, 3, 5, 7, 9]" in str(info.value)
@pytest.mark.skip(reason='Not support graph raise feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_8():
"""
Feature: graph raise.
@ -237,11 +245,15 @@ def test_raise_9():
assert "The input can not be 11." in str(info.value)
@pytest.mark.skip(reason='Not support graph raise feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_10():
"""
Feature: graph raise.
Description: Test raise.
Feature: graph raise by JIT Fallback.
Description: Test raise(string % var).
Expectation: No exception.
"""
class RaiseNet(nn.Cell):