forked from mindspore-Ecosystem/mindspore
fix circle bug of opt depend && merge cast
This commit is contained in:
parent
e3899c552c
commit
e95b42496c
|
@ -71,7 +71,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
} else if (is_insert_input) {
|
||||
// if need padding & is input need insert a transdata
|
||||
// reshape[padding shape] -> transdata[padding shape] -> node
|
||||
auto padding_shape = trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, 0));
|
||||
auto padding_shape =
|
||||
trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index));
|
||||
auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
|
||||
trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
|
||||
trans_node = trans_data;
|
||||
|
|
|
@ -553,6 +553,30 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
|||
return output_node_list->size() > 1;
|
||||
}
|
||||
|
||||
bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto output_node_list = GetRealNodeUsedList(graph, node);
|
||||
MS_EXCEPTION_IF_NULL(output_node_list);
|
||||
if (output_node_list->empty()) {
|
||||
return true;
|
||||
}
|
||||
for (const auto &output : *output_node_list) {
|
||||
auto out_node = output.first;
|
||||
auto name = AnfAlgo::GetCNodeName(out_node);
|
||||
if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
|
||||
name == prim::kPrimTupleGetItem->name()) {
|
||||
auto result = IsNotRealUsedByOthers(graph, out_node);
|
||||
if (!result) {
|
||||
return result;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
|
||||
auto idx = NewValueNode(SizeToInt(output_idx));
|
||||
MS_EXCEPTION_IF_NULL(idx);
|
||||
|
|
|
@ -175,6 +175,7 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
|
|||
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
|
||||
const AnfNodePtr &node,
|
||||
size_t output_index);
|
||||
bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
|
||||
|
||||
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
constexpr auto kSingleInputIndex = 1;
|
||||
namespace {
|
||||
AnfNodePtr GetReplaceNode(const AnfNodePtr &node) {
|
||||
AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
|
@ -40,6 +40,9 @@ AnfNodePtr GetReplaceNode(const AnfNodePtr &node) {
|
|||
if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!IsNotRealUsedByOthers(func_graph, cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
CheckCNodeInputSize(cnode, kSingleInputIndex + 1);
|
||||
return cnode->input(kSingleInputIndex);
|
||||
}
|
||||
|
@ -50,10 +53,11 @@ AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnod
|
|||
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> new_make_tuple_inputs;
|
||||
std::vector<AnfNodePtr> new_make_tuple_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
|
||||
bool need_update = false;
|
||||
for (const auto &input : cnode->inputs()) {
|
||||
AnfNodePtr replace_input = GetReplaceNode(input);
|
||||
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
|
||||
auto input = AnfAlgo::GetInputNode(cnode, index);
|
||||
AnfNodePtr replace_input = GetReplaceNode(func_graph, input);
|
||||
// If replace input is not null, it will be the input of the TransData or Cast.
|
||||
if (replace_input == nullptr) {
|
||||
new_make_tuple_inputs.push_back(input);
|
||||
|
@ -149,7 +153,7 @@ const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, c
|
|||
if (make_tuple_replace_node != nullptr) {
|
||||
return make_tuple_replace_node;
|
||||
}
|
||||
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
|
||||
AnfNodePtr replace_node = GetReplaceNode(graph, replacing_cnode);
|
||||
if (replace_node == nullptr) {
|
||||
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
|
||||
return replacing_node;
|
||||
|
|
|
@ -392,8 +392,8 @@ std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
|
|||
} else {
|
||||
host_shape = AnfAlgo::GetOutputInferShape(node, index);
|
||||
}
|
||||
if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
|
||||
host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
|
||||
if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, index), host_shape.size())) {
|
||||
host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, index));
|
||||
}
|
||||
std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
|
||||
return shape;
|
||||
|
|
Loading…
Reference in New Issue