forked from mindspore-Ecosystem/mindspore
move call inputs to partial op inputs
This commit is contained in:
parent
289986fdfa
commit
c9c94b7340
|
@ -422,7 +422,7 @@ int LiteSession::IsolateOutputTensor() {
|
|||
Tensor *new_tensor =
|
||||
new Tensor(src_tensor->data_type(), src_tensor->shape(), src_tensor->format(), Tensor::GRAPH_OUTPUT);
|
||||
if (new_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "duplicate new outptu failed.";
|
||||
MS_LOG(ERROR) << "duplicate new output failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
new_tensor->set_allocator(src_tensor->allocator()); /* GPU use opencl allocator */
|
||||
|
|
|
@ -122,7 +122,7 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto
|
|||
std::vector<TensorC *> in_tensors;
|
||||
std::vector<TensorC *> out_tensors;
|
||||
if (parameter->type_ == schema::PrimitiveType_PartialFusion || parameter->type_ == schema::PrimitiveType_Switch ||
|
||||
parameter->type_ == schema::PrimitiveType_Call) {
|
||||
parameter->type_ == schema::PrimitiveType_Call || parameter->type_ == schema::PrimitiveType_SwitchLayer) {
|
||||
MS_LOG(INFO) << "no need infer shape.";
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -392,6 +392,16 @@ bool IsSwitch(const AnfNodePtr &node) {
|
|||
return opt::CheckPrimitiveType(node, prim::kPrimSwitch);
|
||||
}
|
||||
|
||||
bool IsSwitchLayer(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
return false;
|
||||
}
|
||||
return opt::CheckPrimitiveType(node, prim::kPrimSwitchLayer);
|
||||
}
|
||||
|
||||
bool IsMakeTuple(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
|
|
|
@ -419,6 +419,8 @@ bool IsCall(const AnfNodePtr &node);
|
|||
|
||||
bool IsSwitch(const AnfNodePtr &node);
|
||||
|
||||
bool IsSwitchLayer(const AnfNodePtr &node);
|
||||
|
||||
bool IsMakeTuple(const AnfNodePtr &node);
|
||||
|
||||
ValueNodePtr GetPartialFusionPrim();
|
||||
|
|
|
@ -26,13 +26,16 @@
|
|||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
// File-local layout constants for the Switch and SwitchLayer cnodes handled in this file.
namespace {
|
||||
// Index, within a Switch cnode's inputs, of the partial taken on the true branch.
constexpr const int kSwitchTruePartialIndex = 2;
|
||||
// Index, within a Switch cnode's inputs, of the partial taken on the false branch.
constexpr const int kSwitchFalsePartialIndex = 3;
|
||||
// Input count a Switch cnode is checked against before its branch partials are read.
constexpr const int kSwitchInputSize = 4;
|
||||
// Input count a SwitchLayer cnode is checked against before its MakeTuple input is read.
constexpr const int kSwitchLayerInputSize = 3;
|
||||
// Index, within a SwitchLayer cnode's inputs, of the MakeTuple holding the branch partials.
constexpr const int kSwitchLayerMakeTupleIndex = 2;
|
||||
} // namespace
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr const int kSwitchTruePartialIndex = 2;
|
||||
constexpr const int kSwitchFalsePartialIndex = 3;
|
||||
constexpr const int kPartialFgVnodeIndex = 1;
|
||||
|
||||
bool MindIRControlFlowAdjust::HasCallAfter(const FuncGraphPtr &partial_fg) {
|
||||
MS_CHECK_TRUE_MSG(partial_fg != nullptr, false, "partial_fg is nullptr.");
|
||||
auto output_node = partial_fg->output();
|
||||
|
@ -229,6 +232,88 @@ int MindIRControlFlowAdjust::AddAfterFgForInlinedFg(const std::set<FuncGraphPtr>
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int MindIRControlFlowAdjust::MoveCallInputsToPartialFusionInputs(const std::set<FuncGraphPtr> &all_func_graphs) {
|
||||
for (auto &graph : all_func_graphs) {
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!IsCall(node)) {
|
||||
continue;
|
||||
}
|
||||
auto call_cnode = node->cast<CNodePtr>();
|
||||
MS_ASSERT(call_node != nullptr);
|
||||
auto call_cnode_inputs = call_cnode->inputs();
|
||||
if (call_cnode_inputs.size() == 1) {
|
||||
MS_LOG(DEBUG) << "no need move call inputs.";
|
||||
continue;
|
||||
}
|
||||
auto call_first_input = call_cnode->input(0);
|
||||
if (!utils::isa<CNodePtr>(call_first_input)) {
|
||||
// This situation will be handled in the InsertPartialFusionForRawCall function
|
||||
continue;
|
||||
}
|
||||
auto call_first_input_cnode = call_first_input->cast<CNodePtr>();
|
||||
MS_ASSERT(call_first_input_cnode != nullptr);
|
||||
if (IsPartialFusion(call_first_input_cnode)) {
|
||||
auto partial_cnode_inputs = call_first_input_cnode->inputs();
|
||||
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_cnode_inputs));
|
||||
call_first_input_cnode->set_inputs(partial_cnode_inputs);
|
||||
}
|
||||
|
||||
if (IsSwitch(call_first_input_cnode)) {
|
||||
auto switch_cnode_inputs = call_first_input_cnode->inputs();
|
||||
if (switch_cnode_inputs.size() == kSwitchInputSize) {
|
||||
MS_LOG(ERROR) << "switch op inputs size not right.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!IsPartialFusion(switch_cnode_inputs[kSwitchTruePartialIndex]) ||
|
||||
!IsPartialFusion(switch_cnode_inputs[kSwitchFalsePartialIndex])) {
|
||||
MS_LOG(ERROR) << "switch inputs not are partial ops, not support now.";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
||||
auto true_partial_cnode = switch_cnode_inputs.at(kSwitchTruePartialIndex)->cast<CNodePtr>();
|
||||
auto true_partial_cnode_inputs = true_partial_cnode->inputs();
|
||||
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(),
|
||||
std::back_inserter(true_partial_cnode_inputs));
|
||||
true_partial_cnode->set_inputs(true_partial_cnode_inputs);
|
||||
|
||||
auto false_partial_cnode = switch_cnode_inputs.at(kSwitchFalsePartialIndex)->cast<CNodePtr>();
|
||||
auto false_partial_cnode_inputs = false_partial_cnode->inputs();
|
||||
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(),
|
||||
std::back_inserter(false_partial_cnode_inputs));
|
||||
false_partial_cnode->set_inputs(false_partial_cnode_inputs);
|
||||
}
|
||||
|
||||
if (IsSwitchLayer(call_first_input_cnode)) {
|
||||
auto switch_layer_cnode_inputs = call_first_input_cnode->inputs();
|
||||
if (switch_layer_cnode_inputs.size() != kSwitchLayerInputSize) {
|
||||
MS_LOG(ERROR) << "switch layer op inputs size not right.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!IsMakeTuple(switch_layer_cnode_inputs[kSwitchLayerMakeTupleIndex])) {
|
||||
MS_LOG(ERROR) << "SwitchLayer op last input not is MakeTuple ops, not support now.";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
auto make_tuple_op = switch_layer_cnode_inputs[kSwitchLayerMakeTupleIndex]->cast<CNodePtr>();
|
||||
auto make_tuple_op_intpus = make_tuple_op->inputs();
|
||||
for (size_t i = 1; i < make_tuple_op_intpus.size(); i++) {
|
||||
if (!IsPartialFusion(make_tuple_op_intpus[i])) {
|
||||
MS_LOG(ERROR) << "switch layer op make tuple inputs not is partial fusion op, not support now.";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
auto partial_node = make_tuple_op_intpus[i]->cast<CNodePtr>();
|
||||
auto partial_node_inputs = partial_node->inputs();
|
||||
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_node_inputs));
|
||||
partial_node->set_inputs(partial_node_inputs);
|
||||
}
|
||||
}
|
||||
|
||||
call_cnode->set_inputs({call_first_input_cnode});
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MindIRControlFlowAdjust::InsertPartialFusionForRawCall(const std::set<FuncGraphPtr> &all_func_graphs) {
|
||||
for (auto &graph : all_func_graphs) {
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
|
@ -239,14 +324,13 @@ int MindIRControlFlowAdjust::InsertPartialFusionForRawCall(const std::set<FuncGr
|
|||
auto call_cnode = node->cast<CNodePtr>();
|
||||
MS_ASSERT(call_node != nullptr);
|
||||
auto call_cnode_inputs = call_cnode->inputs();
|
||||
auto cnode_first_input = call_cnode->input(0);
|
||||
if (!utils::isa<ValueNodePtr>(cnode_first_input)) {
|
||||
auto call_first_input = call_cnode->input(0);
|
||||
if (!utils::isa<ValueNodePtr>(call_first_input)) {
|
||||
continue;
|
||||
}
|
||||
if (GetValueNode<FuncGraphPtr>(cnode_first_input->cast<ValueNodePtr>()) == nullptr) {
|
||||
if (GetValueNode<FuncGraphPtr>(call_first_input->cast<ValueNodePtr>()) == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> partial_cnode_inputs = {lite::GetPartialFusionPrim()};
|
||||
std::copy(call_cnode_inputs.begin(), call_cnode_inputs.end(), std::back_inserter(partial_cnode_inputs));
|
||||
auto partial_cnode = graph->NewCNode(partial_cnode_inputs);
|
||||
|
@ -285,7 +369,12 @@ bool MindIRControlFlowAdjust::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(INFO) << "Not is control flow model.";
|
||||
return true;
|
||||
}
|
||||
int ret = InsertPartialFusionForRawCall(all_func_graphs);
|
||||
int ret = MoveCallInputsToPartialFusionInputs(all_func_graphs);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "MoveCallInputsToPartialFusionInputs failed.";
|
||||
return false;
|
||||
}
|
||||
ret = InsertPartialFusionForRawCall(all_func_graphs);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertPartialFusionForRawCall failed.";
|
||||
return false;
|
||||
|
|
|
@ -43,6 +43,7 @@ class MindIRControlFlowAdjust {
|
|||
int InsertPartialFusionForRawCall(const std::set<FuncGraphPtr> &all_func_graphs);
|
||||
CNodePtr GetMainFgSwitchNode(const FuncGraphPtr &fg);
|
||||
int ResetFuncGraph(const FuncGraphPtr &fg, std::set<FuncGraphPtr> all_func_graphs);
|
||||
int MoveCallInputsToPartialFusionInputs(const std::set<FuncGraphPtr> &all_func_graphs);
|
||||
|
||||
private:
|
||||
FmkType fmk_type_ = FmkType::kFmkTypeMs;
|
||||
|
|
Loading…
Reference in New Issue