!32081 [ME][Fallback] Support format string('%' symbol).
Merge pull request !32081 from Margaret_wangrui/raise_print_format
This commit is contained in:
commit
fd5c2c0503
|
@ -761,6 +761,27 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n
|
|||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
||||
auto new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
|
||||
UpdateInterpretForUserNode(new_node, {left_node, right_node});
|
||||
// Handling % symbol in formatted string values by JIT Fallback.
|
||||
// For example, string % var, "The string is: %s." % str or "The number is: %d." % num
|
||||
if (IsPrimitiveCNode(left_node, prim::kPrimMakeTuple)) {
|
||||
auto inputs = left_node->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() <= 1) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected maketuple node:" << left_node->DebugString();
|
||||
}
|
||||
auto str_node = inputs[1];
|
||||
auto op_cnode = op_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(op_cnode);
|
||||
const size_t symbol_index = 2;
|
||||
if (op_cnode->inputs().size() <= symbol_index) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected symbol node:" << op_node->DebugString();
|
||||
}
|
||||
auto mod_node = op_cnode->input(symbol_index);
|
||||
if (IsValueNode<StringImm>(str_node) && IsValueNode<Symbol>(mod_node)) {
|
||||
new_node->set_interpret(true);
|
||||
auto new_interpret_node = HandleInterpret(block, new_node, node);
|
||||
return new_interpret_node;
|
||||
}
|
||||
}
|
||||
return new_node;
|
||||
}
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ def test_np_print_2():
|
|||
net = PrintNet()
|
||||
res = net()
|
||||
print("res: ", res)
|
||||
assert (res.asnumpy() == [1, 2, 3, 4, 5]).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -81,28 +82,6 @@ def test_tensor_print_1():
|
|||
assert np.all(np_print().asnumpy() == np.array([1, 2, 3, 4, 5]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_tensor_print_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Support print.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class PrintNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3, 4, 5])
|
||||
print("Tensor(x): ", Tensor(x))
|
||||
return Tensor(x)
|
||||
|
||||
net = PrintNet()
|
||||
res = net()
|
||||
print("res: ", res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -124,6 +103,7 @@ def test_print_cnode_1():
|
|||
y = Tensor(np.array([1, 2, 3, 4, 5]))
|
||||
res = print_func(x, y)
|
||||
print("res: ", res)
|
||||
assert (res.asnumpy() == [2, 4, 6, 8, 10]).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -147,6 +127,7 @@ def test_print_cnode_2():
|
|||
|
||||
res = print_func()
|
||||
print("res: ", res)
|
||||
assert (res.asnumpy() == [2, 4, 6, 8, 10]).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -170,6 +151,7 @@ def test_print_cnode_3():
|
|||
|
||||
res = print_func()
|
||||
print("res: ", res)
|
||||
assert (res.asnumpy() == [2, 4, 6, 8, 10]).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -225,3 +207,71 @@ def test_print_validate():
|
|||
res = print_func()
|
||||
print("res: ", res)
|
||||
assert "Should not use Python object in runtime" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_format_np():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Support print.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def print_func():
|
||||
np_x = np.array([1, 2, 3, 4, 5])
|
||||
np_y = np.array([1, 2, 3, 4, 5])
|
||||
np_sum = np_x + np_y
|
||||
print("np_sum: {}".format(np_sum))
|
||||
return Tensor(np_sum)
|
||||
|
||||
res = print_func()
|
||||
print("res: ", res)
|
||||
assert (res.asnumpy() == [2, 4, 6, 8, 10]).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_format_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Support print.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def print_func():
|
||||
x = Tensor(np.array([1, 2, 3, 4, 5]))
|
||||
y = Tensor(np.array([1, 2, 3, 4, 5]))
|
||||
tensor_sum = x + y
|
||||
print("tensor_sum: {}".format(tensor_sum))
|
||||
return tensor_sum
|
||||
|
||||
res = print_func()
|
||||
print("res: ", res)
|
||||
assert (res.asnumpy() == [2, 4, 6, 8, 10]).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_string_format():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Support print(string % var).
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def print_func():
|
||||
print("I'm %s. I'm %d years old." % ('MindSpore', 3))
|
||||
return 0
|
||||
|
||||
res = print_func()
|
||||
print("res: ", res)
|
||||
|
|
|
@ -169,7 +169,11 @@ def test_raise_6():
|
|||
print("res:", res)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph raise feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_7():
|
||||
"""
|
||||
Feature: graph raise.
|
||||
|
@ -188,7 +192,11 @@ def test_raise_7():
|
|||
assert "Not expected value, x is [1, 3, 5, 7, 9]" in str(info.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph raise feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_8():
|
||||
"""
|
||||
Feature: graph raise.
|
||||
|
@ -237,11 +245,15 @@ def test_raise_9():
|
|||
assert "The input can not be 11." in str(info.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph raise feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_10():
|
||||
"""
|
||||
Feature: graph raise.
|
||||
Description: Test raise.
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise(string % var).
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
|
|
Loading…
Reference in New Issue