Fix remove internal output for unique device target

This commit is contained in:
yujianfeng 2020-08-26 16:02:45 +08:00
parent 8eff6c96b4
commit e688e1df32
2 changed files with 4 additions and 4 deletions
mindspore/ccsrc/backend/optimizer/ascend/format_type
tests/ut/cpp/pre_activate/ascend/format_type

View File

@ -58,7 +58,7 @@ const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, c
if (kernel_graph == nullptr) { if (kernel_graph == nullptr) {
return nullptr; return nullptr;
} }
if (!kernel_graph->IsInternalOutput(node, 0)) { if (!kernel_graph->IsUniqueTargetInternalOutput(node, 0)) {
return nullptr; return nullptr;
} }
if (!UsedForOutputOnly(func_graph, node)) { if (!UsedForOutputOnly(func_graph, node)) {

View File

@ -49,7 +49,7 @@ class TestHWRemoveInternalOutput : public BackendCommon {
auto make_tuple = GetMakeTuple(kg); auto make_tuple = GetMakeTuple(kg);
auto add = make_tuple->cast<CNodePtr>()->input(1); auto add = make_tuple->cast<CNodePtr>()->input(1);
MS_EXCEPTION_IF_NULL(add); MS_EXCEPTION_IF_NULL(add);
kg->AddInternalOutput(add, add); kg->AddInternalOutput(add, add, 0, true);
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()}); builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
@ -77,8 +77,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
MS_EXCEPTION_IF_NULL(tuple_getitem2); MS_EXCEPTION_IF_NULL(tuple_getitem2);
auto max_pool = tuple_getitem1->cast<CNodePtr>()->input(1); auto max_pool = tuple_getitem1->cast<CNodePtr>()->input(1);
MS_EXCEPTION_IF_NULL(max_pool); MS_EXCEPTION_IF_NULL(max_pool);
kg->AddInternalOutput(tuple_getitem1, max_pool); kg->AddInternalOutput(tuple_getitem1, max_pool, 0, true);
kg->AddInternalOutput(tuple_getitem2, max_pool); kg->AddInternalOutput(tuple_getitem2, max_pool, 1, true);
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}}); builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}, {}}); builder.SetOutputsReshapeType({{}, {}});