Fix the print issue with the node which has not abstract.

This commit is contained in:
Margaret_wangrui 2022-12-22 19:09:53 +08:00
parent 3b75995d22
commit f11071db76
3 changed files with 36 additions and 4 deletions

View File

@ -60,7 +60,9 @@ class PrintTupleWrapper : public AnfVisitor {
class PrintConstStringWrapper : public AnfVisitor {
bool CheckNeedConvert(const AbstractBasePtr &abs) const {
MS_EXCEPTION_IF_NULL(abs);
if (abs == nullptr) {
return false;
}
if (abs->isa<abstract::AbstractSequence>()) {
auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
const auto &elements = sequence_abs->elements();

View File

@ -305,7 +305,9 @@ class InlinerBase : public AnfVisitor {
node_inputs.push_back(NewValueNode(new_fg));
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
[&args](size_t i) { return args[i]; });
return node->func_graph()->NewCNode(node_inputs);
auto ret_node = node->func_graph()->NewCNode(node_inputs);
ret_node->set_abstract(node->abstract());
return ret_node;
}
bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) {
@ -366,7 +368,6 @@ class InlinerBase : public AnfVisitor {
return has_branch;
}
private:
bool is_checked_{false};
bool is_recursive_{false};
// If the user guarantee that fg has no recursive.

View File

@ -17,7 +17,7 @@
import numpy as np
import pytest
from mindspore import Tensor
from mindspore import Tensor, jit
import mindspore.nn as nn
from mindspore.ops import operations as P
import mindspore.context as context
@ -224,3 +224,32 @@ def test_print_tensor_dtype_in_nested_tuple(mode):
y = Tensor([1, 2], dtype=ms.int32)
net = PrintDtypeNet()
net(x, y)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
def test_print_abs():
"""
Feature: Print op.
Description: Print the result of max.
Expectation: success.
"""
@jit
def function():
tuple_x = (Tensor(10).astype("float32"), Tensor(30).astype("float32"), Tensor(50).astype("float32"))
sum_x = Tensor(0).astype("float32")
max_x = Tensor(0).astype("float32")
for i in range(3):
max_x = max(tuple_x)
sum_x += max_x
print(max_x)
print(i)
for x in zip(tuple_x):
sum_x = sum(x, sum_x)
print(sum_x)
return sum_x
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
out = function()
print("out:", out)