!25604 [MS][LITE] move call inputs to partial inputs

Merge pull request !25604 from mengyuanli/support_switch_layer
This commit is contained in:
i-robot 2021-10-30 06:07:38 +00:00 committed by Gitee
commit 6f50a72575
6 changed files with 113 additions and 11 deletions
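
In short, this change adds MindIRControlFlowAdjust::MoveCallInputsToPartialFusionInputs, which moves a Call node's arguments onto the PartialFusion node(s) feeding it, whether the callee is a plain PartialFusion, a Switch over two partials, or a SwitchLayer over a MakeTuple of partials, and it registers SwitchLayer alongside PartialFusion/Switch/Call as ops that skip shape inference. A minimal sketch of the plain PartialFusion case, assuming the same CNode API used in the diff below (the helper name MoveCallArgsToPartial is illustrative only, not part of the commit):

#include <algorithm>
#include <iterator>
#include "ir/anf.h"  // assumed header providing CNodePtr in this repo layout
// Rewrite  call(partial(fn, a, b), x, y)  ->  call(partial(fn, a, b, x, y)).
void MoveCallArgsToPartial(const CNodePtr &call_cnode) {  // hypothetical helper
  auto call_inputs = call_cnode->inputs();                // {partial, x, y, ...}
  auto partial_cnode = call_cnode->input(0)->cast<CNodePtr>();
  auto partial_inputs = partial_cnode->inputs();          // {prim::PartialFusion, fn, a, b}
  // Append every call argument (inputs after index 0) to the partial node.
  std::copy(call_inputs.begin() + 1, call_inputs.end(), std::back_inserter(partial_inputs));
  partial_cnode->set_inputs(partial_inputs);
  // The call now carries only its callee; the arguments live on the partial node.
  call_cnode->set_inputs({partial_cnode});
}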

View File

@@ -422,7 +422,7 @@ int LiteSession::IsolateOutputTensor() {
Tensor *new_tensor =
new Tensor(src_tensor->data_type(), src_tensor->shape(), src_tensor->format(), Tensor::GRAPH_OUTPUT);
if (new_tensor == nullptr) {
MS_LOG(ERROR) << "duplicate new outptu failed.";
MS_LOG(ERROR) << "duplicate new output failed.";
return RET_NULL_PTR;
}
new_tensor->set_allocator(src_tensor->allocator()); /* GPU use opencl allocator */

View File

@@ -122,7 +122,7 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto
std::vector<TensorC *> in_tensors;
std::vector<TensorC *> out_tensors;
if (parameter->type_ == schema::PrimitiveType_PartialFusion || parameter->type_ == schema::PrimitiveType_Switch ||
parameter->type_ == schema::PrimitiveType_Call) {
parameter->type_ == schema::PrimitiveType_Call || parameter->type_ == schema::PrimitiveType_SwitchLayer) {
MS_LOG(INFO) << "no need infer shape.";
return RET_OK;
}

View File

@@ -392,6 +392,16 @@ bool IsSwitch(const AnfNodePtr &node) {
return opt::CheckPrimitiveType(node, prim::kPrimSwitch);
}
bool IsSwitchLayer(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
if (!utils::isa<CNodePtr>(node)) {
return false;
}
return opt::CheckPrimitiveType(node, prim::kPrimSwitchLayer);
}
bool IsMakeTuple(const AnfNodePtr &node) {
if (node == nullptr) {
return false;

View File

@@ -419,6 +419,8 @@ bool IsCall(const AnfNodePtr &node);
bool IsSwitch(const AnfNodePtr &node);
bool IsSwitchLayer(const AnfNodePtr &node);
bool IsMakeTuple(const AnfNodePtr &node);
ValueNodePtr GetPartialFusionPrim();

View File

@@ -26,13 +26,16 @@
#include "tools/converter/parser/parser_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
namespace {
constexpr const int kSwitchTruePartialIndex = 2;
constexpr const int kSwitchFalsePartialIndex = 3;
constexpr const int kSwitchInputSize = 4;
constexpr const int kSwitchLayerInputSize = 3;
constexpr const int kSwitchLayerMakeTupleIndex = 2;
} // namespace
namespace mindspore {
namespace lite {
constexpr const int kSwitchTruePartialIndex = 2;
constexpr const int kSwitchFalsePartialIndex = 3;
constexpr const int kPartialFgVnodeIndex = 1;
bool MindIRControlFlowAdjust::HasCallAfter(const FuncGraphPtr &partial_fg) {
MS_CHECK_TRUE_MSG(partial_fg != nullptr, false, "partial_fg is nullptr.");
auto output_node = partial_fg->output();
@@ -229,6 +232,88 @@ int MindIRControlFlowAdjust::AddAfterFgForInlinedFg(const std::set<FuncGraphPtr>
return RET_OK;
}
int MindIRControlFlowAdjust::MoveCallInputsToPartialFusionInputs(const std::set<FuncGraphPtr> &all_func_graphs) {
for (auto &graph : all_func_graphs) {
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!IsCall(node)) {
continue;
}
auto call_cnode = node->cast<CNodePtr>();
MS_ASSERT(call_cnode != nullptr);
auto call_cnode_inputs = call_cnode->inputs();
if (call_cnode_inputs.size() == 1) {
MS_LOG(DEBUG) << "no need to move call inputs.";
continue;
}
auto call_first_input = call_cnode->input(0);
if (!utils::isa<CNodePtr>(call_first_input)) {
// This situation will be handled in the InsertPartialFusionForRawCall function
continue;
}
auto call_first_input_cnode = call_first_input->cast<CNodePtr>();
MS_ASSERT(call_first_input_cnode != nullptr);
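// Case 1: the callee is already a PartialFusion node; append the call's arguments to it.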
if (IsPartialFusion(call_first_input_cnode)) {
auto partial_cnode_inputs = call_first_input_cnode->inputs();
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_cnode_inputs));
call_first_input_cnode->set_inputs(partial_cnode_inputs);
}
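// Case 2: the callee is a Switch node; append the call's arguments to both branch partials.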
if (IsSwitch(call_first_input_cnode)) {
auto switch_cnode_inputs = call_first_input_cnode->inputs();
if (switch_cnode_inputs.size() != kSwitchInputSize) {
MS_LOG(ERROR) << "switch op inputs size is not right.";
return RET_ERROR;
}
if (!IsPartialFusion(switch_cnode_inputs[kSwitchTruePartialIndex]) ||
!IsPartialFusion(switch_cnode_inputs[kSwitchFalsePartialIndex])) {
MS_LOG(ERROR) << "switch inputs not are partial ops, not support now.";
return RET_NOT_SUPPORT;
}
auto true_partial_cnode = switch_cnode_inputs.at(kSwitchTruePartialIndex)->cast<CNodePtr>();
auto true_partial_cnode_inputs = true_partial_cnode->inputs();
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(),
std::back_inserter(true_partial_cnode_inputs));
true_partial_cnode->set_inputs(true_partial_cnode_inputs);
auto false_partial_cnode = switch_cnode_inputs.at(kSwitchFalsePartialIndex)->cast<CNodePtr>();
auto false_partial_cnode_inputs = false_partial_cnode->inputs();
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(),
std::back_inserter(false_partial_cnode_inputs));
false_partial_cnode->set_inputs(false_partial_cnode_inputs);
}
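// Case 3: the callee is a SwitchLayer node; append the call's arguments to every partial in its MakeTuple.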
if (IsSwitchLayer(call_first_input_cnode)) {
auto switch_layer_cnode_inputs = call_first_input_cnode->inputs();
if (switch_layer_cnode_inputs.size() != kSwitchLayerInputSize) {
MS_LOG(ERROR) << "switch layer op inputs size not right.";
return RET_ERROR;
}
if (!IsMakeTuple(switch_layer_cnode_inputs[kSwitchLayerMakeTupleIndex])) {
MS_LOG(ERROR) << "SwitchLayer op last input not is MakeTuple ops, not support now.";
return RET_NOT_SUPPORT;
}
auto make_tuple_op = switch_layer_cnode_inputs[kSwitchLayerMakeTupleIndex]->cast<CNodePtr>();
auto make_tuple_op_inputs = make_tuple_op->inputs();
for (size_t i = 1; i < make_tuple_op_inputs.size(); i++) {
if (!IsPartialFusion(make_tuple_op_inputs[i])) {
MS_LOG(ERROR) << "switch layer op make tuple inputs are not partial fusion ops, not supported now.";
return RET_NOT_SUPPORT;
}
auto partial_node = make_tuple_op_inputs[i]->cast<CNodePtr>();
auto partial_node_inputs = partial_node->inputs();
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_node_inputs));
partial_node->set_inputs(partial_node_inputs);
}
}
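// All arguments have been moved onto the partial node(s); keep only the callee on the call node.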
call_cnode->set_inputs({call_first_input_cnode});
}
}
return RET_OK;
}
int MindIRControlFlowAdjust::InsertPartialFusionForRawCall(const std::set<FuncGraphPtr> &all_func_graphs) {
for (auto &graph : all_func_graphs) {
auto node_list = TopoSort(graph->get_return());
@@ -239,14 +324,13 @@ int MindIRControlFlowAdjust::InsertPartialFusionForRawCall(const std::set<FuncGr
auto call_cnode = node->cast<CNodePtr>();
MS_ASSERT(call_cnode != nullptr);
auto call_cnode_inputs = call_cnode->inputs();
auto cnode_first_input = call_cnode->input(0);
if (!utils::isa<ValueNodePtr>(cnode_first_input)) {
auto call_first_input = call_cnode->input(0);
if (!utils::isa<ValueNodePtr>(call_first_input)) {
continue;
}
if (GetValueNode<FuncGraphPtr>(cnode_first_input->cast<ValueNodePtr>()) == nullptr) {
if (GetValueNode<FuncGraphPtr>(call_first_input->cast<ValueNodePtr>()) == nullptr) {
continue;
}
std::vector<AnfNodePtr> partial_cnode_inputs = {lite::GetPartialFusionPrim()};
std::copy(call_cnode_inputs.begin(), call_cnode_inputs.end(), std::back_inserter(partial_cnode_inputs));
auto partial_cnode = graph->NewCNode(partial_cnode_inputs);
@@ -285,7 +369,12 @@ bool MindIRControlFlowAdjust::Run(const FuncGraphPtr &func_graph) {
MS_LOG(INFO) << "Not is control flow model.";
return true;
}
int ret = InsertPartialFusionForRawCall(all_func_graphs);
int ret = MoveCallInputsToPartialFusionInputs(all_func_graphs);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MoveCallInputsToPartialFusionInputs failed.";
return false;
}
ret = InsertPartialFusionForRawCall(all_func_graphs);
if (ret != RET_OK) {
MS_LOG(ERROR) << "InsertPartialFusionForRawCall failed.";
return false;

View File

@@ -43,6 +43,7 @@ class MindIRControlFlowAdjust {
int InsertPartialFusionForRawCall(const std::set<FuncGraphPtr> &all_func_graphs);
CNodePtr GetMainFgSwitchNode(const FuncGraphPtr &fg);
int ResetFuncGraph(const FuncGraphPtr &fg, std::set<FuncGraphPtr> all_func_graphs);
int MoveCallInputsToPartialFusionInputs(const std::set<FuncGraphPtr> &all_func_graphs);
private:
FmkType fmk_type_ = FmkType::kFmkTypeMs;