forked from mindspore-Ecosystem/mindspore
Fix the print issue with the node which has not abstract.
This commit is contained in:
parent
3b75995d22
commit
f11071db76
|
@ -60,7 +60,9 @@ class PrintTupleWrapper : public AnfVisitor {
|
|||
|
||||
class PrintConstStringWrapper : public AnfVisitor {
|
||||
bool CheckNeedConvert(const AbstractBasePtr &abs) const {
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
if (abs == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (abs->isa<abstract::AbstractSequence>()) {
|
||||
auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
|
||||
const auto &elements = sequence_abs->elements();
|
||||
|
|
|
@ -305,7 +305,9 @@ class InlinerBase : public AnfVisitor {
|
|||
node_inputs.push_back(NewValueNode(new_fg));
|
||||
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
|
||||
[&args](size_t i) { return args[i]; });
|
||||
return node->func_graph()->NewCNode(node_inputs);
|
||||
auto ret_node = node->func_graph()->NewCNode(node_inputs);
|
||||
ret_node->set_abstract(node->abstract());
|
||||
return ret_node;
|
||||
}
|
||||
|
||||
bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) {
|
||||
|
@ -366,7 +368,6 @@ class InlinerBase : public AnfVisitor {
|
|||
return has_branch;
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_checked_{false};
|
||||
bool is_recursive_{false};
|
||||
// If the user guarantee that fg has no recursive.
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import Tensor, jit
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.context as context
|
||||
|
@ -224,3 +224,32 @@ def test_print_tensor_dtype_in_nested_tuple(mode):
|
|||
y = Tensor([1, 2], dtype=ms.int32)
|
||||
net = PrintDtypeNet()
|
||||
net(x, y)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
def test_print_abs():
|
||||
"""
|
||||
Feature: Print op.
|
||||
Description: Print the result of max.
|
||||
Expectation: success.
|
||||
"""
|
||||
@jit
|
||||
def function():
|
||||
tuple_x = (Tensor(10).astype("float32"), Tensor(30).astype("float32"), Tensor(50).astype("float32"))
|
||||
sum_x = Tensor(0).astype("float32")
|
||||
max_x = Tensor(0).astype("float32")
|
||||
for i in range(3):
|
||||
max_x = max(tuple_x)
|
||||
sum_x += max_x
|
||||
print(max_x)
|
||||
print(i)
|
||||
for x in zip(tuple_x):
|
||||
sum_x = sum(x, sum_x)
|
||||
print(sum_x)
|
||||
return sum_x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
out = function()
|
||||
print("out:", out)
|
||||
|
|
Loading…
Reference in New Issue