forked from mindspore-Ecosystem/mindspore
!42248 [lite]optimize SplitReduceConcatFusion
Merge pull request !42248 from 徐安越/r1.8_2
commit fbfbbc6630
@@ -32,14 +32,12 @@ int CalShape(const int *data, const TensorC *const *inputs, int *out_shape, size
    }
    ShapePush(out_shape, out_shape_size, data[i]);
  }
  if (size == 0) {
    return NNACL_ERR;
  }

  if ((int)(data[index]) == -1) {
    if (index >= MAX_SHAPE_SIZE) {
      return NNACL_ERR;
    }
    out_shape[index] = input_count / size;
    out_shape[index] = size == 0 ? 0 : input_count / size;
  }
  return NNACL_OK;
}
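The hunk above trades the early NNACL_ERR return for a guarded division when inferring the dimension that was given as -1. A minimal sketch of that guarded inference, with hypothetical names (not the NNACL API):

// Infer the output dimension that was requested as -1 in a reshape.
// 'element_count' is the input's total element count and 'known_product'
// the product of the already resolved output dimensions; a zero product
// (empty tensor) yields 0 instead of dividing by zero.
inline int InferReshapeDim(int element_count, int known_product) {
  return known_product == 0 ? 0 : element_count / known_product;
}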
@@ -1207,7 +1207,8 @@ std::vector<std::vector<uint32_t>> SearchSubGraph::GetNextNodeIndex(LiteGraph::N
  return next_node;
}

bool SearchSubGraph::SatifyReduceReshapeConcatParse(Subgraph *subgraph, uint32_t in_node, int split_concat_axis) {
bool SearchSubGraph::SatifyReduceReshapeConcatParse(Subgraph *subgraph, uint32_t in_node, int split_concat_axis,
                                                    std::vector<uint32_t> *positions) {
  // reduce op
  uint32_t reduce_node1_index = in_node;
  auto &reduce_node1 = node_list_.at(reduce_node1_index);
@@ -1255,11 +1256,17 @@ bool SearchSubGraph::SatifyReduceReshapeConcatParse(Subgraph *subgraph, uint32_t
  } else if (subgraph->ends_.at(0) != concat_node1_index) {
    return false;
  }

  auto reduce_out = reduce_node1->output_indices_.front();
  const auto &concat_in = concat_node1->input_indices_;
  for (size_t i = 0; i < concat_in.size(); ++i) {
    if (concat_in[i] == reduce_out) {
      positions->push_back(i);
    }
  }
  return true;
}

void SearchSubGraph::DeleteOriginNode(Subgraph *subgraph) {
void SearchSubGraph::DeleteOriginNode(Subgraph *subgraph, const std::vector<uint32_t> &positions) {
  auto sub_graph = model_->graph_.sub_graphs_.at(0);
  auto &subgraph_node_indices = sub_graph->node_indices_;
  for (auto &node_index : subgraph->nodes_) {
@@ -1278,21 +1285,26 @@ void SearchSubGraph::DeleteOriginNode(Subgraph *subgraph) {
  for (auto &node_index : subgraph->ends_) {
    // delete tensors_ tensor info
    auto &input_indices = node_list_.at(node_index)->input_indices_;
    for (auto input_indice : input_indices) {
      tensors_.at(input_indice).in_nodes_.clear();
      tensors_.at(input_indice).out_nodes_.clear();
    bool reserve_concat = input_indices.size() != positions.size();
    for (size_t i = (reserve_concat ? 1 : 0); i < positions.size(); ++i) {
      tensors_.at(input_indices[positions[i]]).in_nodes_.clear();
      tensors_.at(input_indices[positions[i]]).out_nodes_.clear();
    }
    auto &output_indices = node_list_.at(node_index)->output_indices_;
    auto &output_indices = node_list_.at(subgraph->heads_.front())->output_indices_;
    for (auto output_indice : output_indices) {
      tensors_.at(output_indice).out_nodes_.clear();
      tensors_.at(output_indice).out_nodes_.emplace_back(subgraph->heads_.front());
    }

    node_list_.at(node_index)->input_indices_.clear();
    node_list_.at(node_index)->output_indices_.clear();

    auto indice_itr = std::find(subgraph_node_indices.begin(), subgraph_node_indices.end(), node_index);
    subgraph_node_indices.erase(indice_itr);
    if (reserve_concat) {
      for (size_t i = positions.size() - 1; i > 0; --i) {
        input_indices.erase(input_indices.begin() + positions[i]);
      }
    } else {
      node_list_.at(node_index)->input_indices_.clear();
      node_list_.at(node_index)->output_indices_.clear();
      auto indice_itr = std::find(subgraph_node_indices.begin(), subgraph_node_indices.end(), node_index);
      subgraph_node_indices.erase(indice_itr);
    }
  }
}
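In the updated DeleteOriginNode, the concat node is kept (reserve_concat) whenever it still has inputs that do not come from the fused branches; only the fused slots are erased, from back to front so the earlier positions stay valid. A small sketch of that erase order on a plain vector, with illustrative names only:

#include <cstdint>
#include <vector>

// Erase the input slots listed in 'positions' (sorted ascending), keeping
// the first listed slot as the surviving input of the preserved concat.
// Walking backwards means each erase leaves the smaller indices untouched.
void EraseFusedSlots(std::vector<uint32_t> *inputs, const std::vector<uint32_t> &positions) {
  for (size_t i = positions.size(); i > 1; --i) {
    inputs->erase(inputs->begin() + positions[i - 1]);
  }
}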
@@ -1308,62 +1320,50 @@ bool SearchSubGraph::DoSplitReduceConcatFusion(uint32_t node_id) {
  if (split_param == nullptr) {
    return false;
  }

  auto ReleaseSplitParameter = [split_param]() {
    if (split_param->op_parameter_.destroy_func_ != nullptr) {
      split_param->op_parameter_.destroy_func_(&split_param->op_parameter_);
    }
    free(split_param);
  };
  auto split_concat_axis = split_param->split_dim_;
  if (split_concat_axis < 0) {
    if (split_param_base->destroy_func_ != nullptr) {
      split_param_base->destroy_func_(split_param_base);
      free(split_param);
      split_param_base = nullptr;
      split_param = nullptr;
    } else {
      free(split_param);
      split_param_base = nullptr;
      split_param = nullptr;
    }
    ReleaseSplitParameter();
    return false;
  }
  auto output_indices = node->output_indices_;
  Subgraph subgraph;
  subgraph.heads_.emplace_back(node_id);
  std::vector<uint32_t> positions;
  for (auto &output_indice : output_indices) {
    auto &tensor = tensors_.at(output_indice);
    auto &in_nodes = tensor.in_nodes_;
    for (auto &in_node : in_nodes) {
      // satisfy reduce + reshape + concat struct
      if (!SatifyReduceReshapeConcatParse(&subgraph, in_node, split_concat_axis)) {
        if (split_param_base->destroy_func_ != nullptr) {
          split_param_base->destroy_func_(split_param_base);
          free(split_param);
          split_param_base = nullptr;
          split_param = nullptr;
        } else {
          free(split_param);
          split_param_base = nullptr;
          split_param = nullptr;
        }
      // satisfy reduce + concat struct
      if (!SatifyReduceReshapeConcatParse(&subgraph, in_node, split_concat_axis, &positions)) {
        ReleaseSplitParameter();
        return false;
      }
    }
  }

  // do fusion
  auto ret = CreateCustomNode(node, &subgraph, split_param);
  if (split_param_base->destroy_func_ != nullptr) {
    split_param_base->destroy_func_(split_param_base);
    free(split_param);
    split_param_base = nullptr;
    split_param = nullptr;
  } else {
    free(split_param);
    split_param_base = nullptr;
    split_param = nullptr;
  if (positions.size() != output_indices.size() || positions.empty()) {
    ReleaseSplitParameter();
    return false;
  }
  for (size_t i = 1; i < positions.size(); ++i) {
    if (positions[i] - positions[i - 1] != 1) {
      ReleaseSplitParameter();
      return false;
    }
  }
  // do fusion
  auto ret = CreateCustomNode(node, &subgraph, split_param, positions);
  ReleaseSplitParameter();
  if (ret != RET_OK) {
    return false;
  }

  DeleteOriginNode(&subgraph);
  DeleteOriginNode(&subgraph, positions);

  return true;
}
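DoSplitReduceConcatFusion now rejects the pattern unless every split output was matched and the matched concat slots are consecutive; a gap would reorder the concatenated data after fusion. The check in isolation might look like this (hypothetical helper, not the class method):

#include <cstdint>
#include <vector>

// 'positions' holds, per reduce branch, the index of the concat input it
// feeds. Fusion is only safe when the list is non-empty, covers exactly
// 'expected' branches, and the indices increase by one with no gaps.
bool FusionPositionsValid(const std::vector<uint32_t> &positions, size_t expected) {
  if (positions.empty() || positions.size() != expected) {
    return false;
  }
  for (size_t i = 1; i < positions.size(); ++i) {
    if (positions[i] - positions[i - 1] != 1) {
      return false;
    }
  }
  return true;
}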
@@ -1378,7 +1378,8 @@ flatbuffers::Offset<mindspore::schema::Attribute> SetDataToUint8Vector(void *src
  return attr;
}

int SearchSubGraph::CreateCustomNode(LiteGraph::Node *node, Subgraph *subgraph, SplitParameter *split_param) {
int SearchSubGraph::CreateCustomNode(LiteGraph::Node *node, Subgraph *subgraph, SplitParameter *split_param,
                                     const std::vector<uint32_t> &positions) {
  MS_ASSERT(node != nullptr);
  flatbuffers::FlatBufferBuilder fbb(kInitialSize);

@@ -1409,7 +1410,12 @@ int SearchSubGraph::CreateCustomNode(LiteGraph::Node *node, Subgraph *subgraph,
  node->primitive_ = online_fusion_prim;
  node->node_type_ = PrimType::PrimType_Inner_SplitReduceConcatFusion;
  node->input_indices_ = model_->graph_.all_nodes_.at(subgraph->heads_.front())->input_indices_;
  node->output_indices_ = model_->graph_.all_nodes_.at(subgraph->ends_.front())->output_indices_;
  if (positions.size() == model_->graph_.all_nodes_.at(subgraph->ends_.front())->input_indices_.size()) {
    node->output_indices_ = model_->graph_.all_nodes_.at(subgraph->ends_.front())->output_indices_;
  } else {
    node->output_indices_ = {
      model_->graph_.all_nodes_.at(subgraph->ends_.front())->input_indices_.at(positions.front())};
  }
  return RET_OK;
}
}  // namespace mindspore::lite
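CreateCustomNode now picks the fused node's outputs depending on whether the concat was absorbed: if the matched positions cover every concat input, the concat's own output tensor is reused; otherwise the fused node writes into the first matched input slot of the concat that stays in the graph. A condensed sketch of that selection, with illustrative parameter names:

#include <cstdint>
#include <vector>

// Choose output tensor indices for the fused custom node.
std::vector<uint32_t> PickFusedOutputs(const std::vector<uint32_t> &concat_inputs,
                                       const std::vector<uint32_t> &concat_outputs,
                                       const std::vector<uint32_t> &positions) {
  if (positions.size() == concat_inputs.size()) {
    return concat_outputs;  // concat is fully absorbed, reuse its outputs
  }
  return {concat_inputs.at(positions.front())};  // feed the kept concat's first matched slot
}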
@@ -93,7 +93,8 @@ class SearchSubGraph {
  void InsertNodeBegin(uint32_t index, Subgraph *subgraph, std::vector<size_t> *outputs);
  void DoOnlineFusion();
  void DoSplitReduceConcatFusionPass();
  int CreateCustomNode(LiteGraph::Node *node, Subgraph *subgraph, SplitParameter *split_param);
  int CreateCustomNode(LiteGraph::Node *node, Subgraph *subgraph, SplitParameter *split_param,
                       const std::vector<uint32_t> &positions);
  int ParseNodePrimitive(Subgraph *subgraph);
  OpParameter *GetNodeOpParameter(LiteGraph::Node *node);
  std::vector<std::vector<uint32_t>> GetNextNodeIndex(LiteGraph::Node *cur_node);
@@ -103,8 +104,9 @@ class SearchSubGraph {
  void InitSearchSubGraphByOutput();
  void InsertNode(uint32_t index, Subgraph *subgraph, uint32_t last_index);
  bool DoSplitReduceConcatFusion(uint32_t node_id);
  bool SatifyReduceReshapeConcatParse(Subgraph *subgraph, uint32_t node_id, int split_concat_axis);
  void DeleteOriginNode(Subgraph *subgraph);
  bool SatifyReduceReshapeConcatParse(Subgraph *subgraph, uint32_t node_id, int split_concat_axis,
                                      std::vector<uint32_t> *positions);
  void DeleteOriginNode(Subgraph *subgraph, const std::vector<uint32_t> &positions);

 private: /* split by middle */
  void SubGraphSplitByMiddle();
@@ -49,6 +49,10 @@ const AnfNodePtr MulActivationFusion::Process(const FuncGraphPtr &func_graph, co
  MS_CHECK_TRUE_RET(act_cnode != nullptr, nullptr);
  auto act_prim = ops::GetOperator<ops::Activation>(act_cnode->input(0));
  MS_CHECK_TRUE_RET(act_prim != nullptr, nullptr);
  if (IsQuantParameterNode(act_prim->GetPrim())) {
    MS_LOG(INFO) << "node is a quant-node";
    return nullptr;
  }
  if (act_prim->get_activation_type() != ActivationType::RELU && act_prim->get_activation_type() != RELU6) {
    MS_LOG(INFO) << "activation is not relu or relu6";
    return nullptr;
@@ -57,8 +61,16 @@ const AnfNodePtr MulActivationFusion::Process(const FuncGraphPtr &func_graph, co
  MS_CHECK_TRUE_RET(mul_node != nullptr, nullptr);
  auto mul_cnode = mul_node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(mul_cnode != nullptr, nullptr);
  if (IsMultiOutputTensors(func_graph, mul_cnode)) {
    MS_LOG(INFO) << "mul has multiple out-nodes";
    return nullptr;
  }
  auto mul_prim = ops::GetOperator<ops::MulFusion>(mul_cnode->input(0));
  MS_CHECK_TRUE_RET(mul_prim != nullptr, nullptr);
  if (IsQuantParameterNode(mul_prim->GetPrim())) {
    MS_LOG(INFO) << "node is a quant-node";
    return nullptr;
  }
  if (mul_prim->get_activation_type() != NO_ACTIVATION) {
    MS_LOG(INFO) << "Mul already has activaton fusion, fusion type: " << mul_prim->get_activation_type();
    return nullptr;