|
|
|
@ -26,13 +26,16 @@
|
|
|
|
|
#include "tools/converter/parser/parser_utils.h"
|
|
|
|
|
#include "tools/optimizer/common/gllo_utils.h"
|
|
|
|
|
#include "nnacl/op_base.h"
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr const int kSwitchTruePartialIndex = 2;
|
|
|
|
|
constexpr const int kSwitchFalsePartialIndex = 3;
|
|
|
|
|
constexpr const int kSwitchInputSize = 4;
|
|
|
|
|
constexpr const int kSwitchLayerInputSize = 3;
|
|
|
|
|
constexpr const int kSwitchLayerMakeTupleIndex = 2;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
|
constexpr const int kSwitchTruePartialIndex = 2;
|
|
|
|
|
constexpr const int kSwitchFalsePartialIndex = 3;
|
|
|
|
|
constexpr const int kPartialFgVnodeIndex = 1;
|
|
|
|
|
|
|
|
|
|
bool MindIRControlFlowAdjust::HasCallAfter(const FuncGraphPtr &partial_fg) {
|
|
|
|
|
MS_CHECK_TRUE_MSG(partial_fg != nullptr, false, "partial_fg is nullptr.");
|
|
|
|
|
auto output_node = partial_fg->output();
|
|
|
|
@ -229,6 +232,88 @@ int MindIRControlFlowAdjust::AddAfterFgForInlinedFg(const std::set<FuncGraphPtr>
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MindIRControlFlowAdjust::MoveCallInputsToPartialFusionInputs(const std::set<FuncGraphPtr> &all_func_graphs) {
|
|
|
|
|
for (auto &graph : all_func_graphs) {
|
|
|
|
|
auto node_list = TopoSort(graph->get_return());
|
|
|
|
|
for (auto &node : node_list) {
|
|
|
|
|
if (!IsCall(node)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto call_cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_ASSERT(call_node != nullptr);
|
|
|
|
|
auto call_cnode_inputs = call_cnode->inputs();
|
|
|
|
|
if (call_cnode_inputs.size() == 1) {
|
|
|
|
|
MS_LOG(DEBUG) << "no need move call inputs.";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto call_first_input = call_cnode->input(0);
|
|
|
|
|
if (!utils::isa<CNodePtr>(call_first_input)) {
|
|
|
|
|
// This situation will be handled in the InsertPartialFusionForRawCall function
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto call_first_input_cnode = call_first_input->cast<CNodePtr>();
|
|
|
|
|
MS_ASSERT(call_first_input_cnode != nullptr);
|
|
|
|
|
if (IsPartialFusion(call_first_input_cnode)) {
|
|
|
|
|
auto partial_cnode_inputs = call_first_input_cnode->inputs();
|
|
|
|
|
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_cnode_inputs));
|
|
|
|
|
call_first_input_cnode->set_inputs(partial_cnode_inputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (IsSwitch(call_first_input_cnode)) {
|
|
|
|
|
auto switch_cnode_inputs = call_first_input_cnode->inputs();
|
|
|
|
|
if (switch_cnode_inputs.size() == kSwitchInputSize) {
|
|
|
|
|
MS_LOG(ERROR) << "switch op inputs size not right.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (!IsPartialFusion(switch_cnode_inputs[kSwitchTruePartialIndex]) ||
|
|
|
|
|
!IsPartialFusion(switch_cnode_inputs[kSwitchFalsePartialIndex])) {
|
|
|
|
|
MS_LOG(ERROR) << "switch inputs not are partial ops, not support now.";
|
|
|
|
|
return RET_NOT_SUPPORT;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto true_partial_cnode = switch_cnode_inputs.at(kSwitchTruePartialIndex)->cast<CNodePtr>();
|
|
|
|
|
auto true_partial_cnode_inputs = true_partial_cnode->inputs();
|
|
|
|
|
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(),
|
|
|
|
|
std::back_inserter(true_partial_cnode_inputs));
|
|
|
|
|
true_partial_cnode->set_inputs(true_partial_cnode_inputs);
|
|
|
|
|
|
|
|
|
|
auto false_partial_cnode = switch_cnode_inputs.at(kSwitchFalsePartialIndex)->cast<CNodePtr>();
|
|
|
|
|
auto false_partial_cnode_inputs = false_partial_cnode->inputs();
|
|
|
|
|
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(),
|
|
|
|
|
std::back_inserter(false_partial_cnode_inputs));
|
|
|
|
|
false_partial_cnode->set_inputs(false_partial_cnode_inputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (IsSwitchLayer(call_first_input_cnode)) {
|
|
|
|
|
auto switch_layer_cnode_inputs = call_first_input_cnode->inputs();
|
|
|
|
|
if (switch_layer_cnode_inputs.size() != kSwitchLayerInputSize) {
|
|
|
|
|
MS_LOG(ERROR) << "switch layer op inputs size not right.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (!IsMakeTuple(switch_layer_cnode_inputs[kSwitchLayerMakeTupleIndex])) {
|
|
|
|
|
MS_LOG(ERROR) << "SwitchLayer op last input not is MakeTuple ops, not support now.";
|
|
|
|
|
return RET_NOT_SUPPORT;
|
|
|
|
|
}
|
|
|
|
|
auto make_tuple_op = switch_layer_cnode_inputs[kSwitchLayerMakeTupleIndex]->cast<CNodePtr>();
|
|
|
|
|
auto make_tuple_op_intpus = make_tuple_op->inputs();
|
|
|
|
|
for (size_t i = 1; i < make_tuple_op_intpus.size(); i++) {
|
|
|
|
|
if (!IsPartialFusion(make_tuple_op_intpus[i])) {
|
|
|
|
|
MS_LOG(ERROR) << "switch layer op make tuple inputs not is partial fusion op, not support now.";
|
|
|
|
|
return RET_NOT_SUPPORT;
|
|
|
|
|
}
|
|
|
|
|
auto partial_node = make_tuple_op_intpus[i]->cast<CNodePtr>();
|
|
|
|
|
auto partial_node_inputs = partial_node->inputs();
|
|
|
|
|
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_node_inputs));
|
|
|
|
|
partial_node->set_inputs(partial_node_inputs);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
call_cnode->set_inputs({call_first_input_cnode});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MindIRControlFlowAdjust::InsertPartialFusionForRawCall(const std::set<FuncGraphPtr> &all_func_graphs) {
|
|
|
|
|
for (auto &graph : all_func_graphs) {
|
|
|
|
|
auto node_list = TopoSort(graph->get_return());
|
|
|
|
@ -239,14 +324,13 @@ int MindIRControlFlowAdjust::InsertPartialFusionForRawCall(const std::set<FuncGr
|
|
|
|
|
auto call_cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_ASSERT(call_node != nullptr);
|
|
|
|
|
auto call_cnode_inputs = call_cnode->inputs();
|
|
|
|
|
auto cnode_first_input = call_cnode->input(0);
|
|
|
|
|
if (!utils::isa<ValueNodePtr>(cnode_first_input)) {
|
|
|
|
|
auto call_first_input = call_cnode->input(0);
|
|
|
|
|
if (!utils::isa<ValueNodePtr>(call_first_input)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (GetValueNode<FuncGraphPtr>(cnode_first_input->cast<ValueNodePtr>()) == nullptr) {
|
|
|
|
|
if (GetValueNode<FuncGraphPtr>(call_first_input->cast<ValueNodePtr>()) == nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> partial_cnode_inputs = {lite::GetPartialFusionPrim()};
|
|
|
|
|
std::copy(call_cnode_inputs.begin(), call_cnode_inputs.end(), std::back_inserter(partial_cnode_inputs));
|
|
|
|
|
auto partial_cnode = graph->NewCNode(partial_cnode_inputs);
|
|
|
|
@ -285,7 +369,12 @@ bool MindIRControlFlowAdjust::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_LOG(INFO) << "Not is control flow model.";
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
int ret = InsertPartialFusionForRawCall(all_func_graphs);
|
|
|
|
|
int ret = MoveCallInputsToPartialFusionInputs(all_func_graphs);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "MoveCallInputsToPartialFusionInputs failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
ret = InsertPartialFusionForRawCall(all_func_graphs);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "InsertPartialFusionForRawCall failed.";
|
|
|
|
|
return false;
|
|
|
|
|