fix circle bug of opt depend && merge cast

This commit is contained in:
WilliamLian 2020-08-24 16:40:52 +08:00
parent e3899c552c
commit e95b42496c
5 changed files with 38 additions and 8 deletions

View File

@ -71,7 +71,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
} else if (is_insert_input) {
// if need padding & is input need insert a transdata
// reshape[padding shape] -> transdata[padding shape] -> node
auto padding_shape = trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, 0));
auto padding_shape =
trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index));
auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
trans_node = trans_data;

View File

@ -553,6 +553,30 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
return output_node_list->size() > 1;
}
bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto output_node_list = GetRealNodeUsedList(graph, node);
MS_EXCEPTION_IF_NULL(output_node_list);
if (output_node_list->empty()) {
return true;
}
for (const auto &output : *output_node_list) {
auto out_node = output.first;
auto name = AnfAlgo::GetCNodeName(out_node);
if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
name == prim::kPrimTupleGetItem->name()) {
auto result = IsNotRealUsedByOthers(graph, out_node);
if (!result) {
return result;
}
continue;
}
return false;
}
return true;
}
AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
auto idx = NewValueNode(SizeToInt(output_idx));
MS_EXCEPTION_IF_NULL(idx);

View File

@ -175,6 +175,7 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
const AnfNodePtr &node,
size_t output_index);
bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);

View File

@ -28,7 +28,7 @@ namespace mindspore {
namespace opt {
constexpr auto kSingleInputIndex = 1;
namespace {
AnfNodePtr GetReplaceNode(const AnfNodePtr &node) {
AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return nullptr;
@ -40,6 +40,9 @@ AnfNodePtr GetReplaceNode(const AnfNodePtr &node) {
if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
return nullptr;
}
if (!IsNotRealUsedByOthers(func_graph, cnode)) {
return nullptr;
}
CheckCNodeInputSize(cnode, kSingleInputIndex + 1);
return cnode->input(kSingleInputIndex);
}
@ -50,10 +53,11 @@ AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnod
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
return nullptr;
}
std::vector<AnfNodePtr> new_make_tuple_inputs;
std::vector<AnfNodePtr> new_make_tuple_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
bool need_update = false;
for (const auto &input : cnode->inputs()) {
AnfNodePtr replace_input = GetReplaceNode(input);
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
auto input = AnfAlgo::GetInputNode(cnode, index);
AnfNodePtr replace_input = GetReplaceNode(func_graph, input);
// If replace input is not null, it will be the input of the TransData or Cast.
if (replace_input == nullptr) {
new_make_tuple_inputs.push_back(input);
@ -149,7 +153,7 @@ const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, c
if (make_tuple_replace_node != nullptr) {
return make_tuple_replace_node;
}
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
AnfNodePtr replace_node = GetReplaceNode(graph, replacing_cnode);
if (replace_node == nullptr) {
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
return replacing_node;

View File

@ -392,8 +392,8 @@ std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
} else {
host_shape = AnfAlgo::GetOutputInferShape(node, index);
}
if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, index), host_shape.size())) {
host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, index));
}
std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
return shape;