!11667 [MS][LITE]nested loop expand

From: @mengyuanli
Reviewed-by: @zhanghaibo5,@hangangqiang
Signed-off-by: @hangangqiang
mindspore-ci-bot 2021-01-26 21:53:04 +08:00 committed by Gitee
commit d50bae7a87
7 changed files with 153 additions and 39 deletions

View File

@@ -19,6 +19,7 @@ file(GLOB GRAPH_PASS
${CMAKE_CURRENT_SOURCE_DIR}/select_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/nested_loop_expand_pass.cc
)
set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(graph_pass_mid OBJECT ${GRAPH_PASS})

View File

@@ -0,0 +1,98 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <set>
#include <algorithm>
#include <memory>
#include "tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
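// Check whether this Partial node calls a subgraph that itself contains another Partial
// node, i.e. whether the subgraph call is nested.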
bool NestedLoopExpandPass::IsNestedPartial(const std::unique_ptr<CNodeT> &node) {
if (node->primitive->value.type != PrimitiveType_Partial) {
return false;
}
auto subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex;
auto &this_subgraph = graph_->subGraph.at(subgraph_idx);
for (auto &node_idx : this_subgraph->nodeIndices) {
auto &cnode = graph_->nodes.at(node_idx);
if (cnode->primitive->value.type == PrimitiveType_Partial) {
return true;
}
}
return false;
}
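// Inline every nested Partial found in main_graph: the Partial node's entry in
// main_graph->nodeIndices is replaced by the node indices of the subgraph it calls, and
// that subgraph is recorded in subgraph_to_drop_. Recurses until no nested Partial remains.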
void NestedLoopExpandPass::ReplacePartialNodeWithSubgraph(const std::unique_ptr<SubGraphT> &main_graph) {
bool is_changed = false;
for (auto &node_idx : main_graph->nodeIndices) {
auto &node = graph_->nodes.at(node_idx);
if (!IsNestedPartial(node)) {
continue;
}
is_changed = true;
auto subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex;
auto &this_subgraph = graph_->subGraph.at(subgraph_idx);
subgraph_to_drop_.push_back(subgraph_idx);
auto partial_pos = std::find(main_graph->nodeIndices.begin(), main_graph->nodeIndices.end(), node_idx);
std::vector<uint32_t> tmp;
tmp.assign(main_graph->nodeIndices.begin(), partial_pos);
tmp.insert(tmp.end(), this_subgraph->nodeIndices.begin(), this_subgraph->nodeIndices.end());
tmp.insert(tmp.end(), partial_pos + 1, main_graph->nodeIndices.end());
main_graph->nodeIndices.assign(tmp.begin(), tmp.end());
}
if (is_changed) {
ReplacePartialNodeWithSubgraph(main_graph);
}
}
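// Expand nested partials in the main graph (subGraph[0]), erase the subgraphs that were
// inlined, and shift the subGraphIndex of the remaining Partial nodes to account for the
// erased entries.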
STATUS NestedLoopExpandPass::Run(schema::MetaGraphT *graph) {
graph_ = graph;
auto &main_graph = graph_->subGraph[0];
ReplacePartialNodeWithSubgraph(main_graph);
for (auto idx : subgraph_to_drop_) {
graph_->subGraph.at(idx) = nullptr;
}
for (auto it = graph_->subGraph.begin(); it != graph_->subGraph.end();) {
if ((*it) == nullptr) {
it = graph_->subGraph.erase(it);
} else {
it++;
}
}
for (auto &node : graph_->nodes) {
if (node->primitive->value.type == PrimitiveType_Partial) {
((schema::PartialT *)(node->primitive->value.value))->subGraphIndex -= subgraph_to_drop_.size();
}
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
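For reference, a minimal self-contained sketch (not part of this patch; the index values below are made up) of the nodeIndices splicing that ReplacePartialNodeWithSubgraph performs when it inlines a nested Partial:

// Illustration only: replace a Partial node's index in the caller's node list with the
// node indices of the subgraph it calls, mirroring the splice done by the pass above.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<uint32_t> main_nodes = {0, 1, 2, 3};  // index 2 is the nested Partial node
  std::vector<uint32_t> sub_nodes = {10, 11, 12};   // nodes of the subgraph it calls
  const uint32_t partial_node_idx = 2;

  auto partial_pos = std::find(main_nodes.begin(), main_nodes.end(), partial_node_idx);
  std::vector<uint32_t> expanded(main_nodes.begin(), partial_pos);      // nodes before the Partial
  expanded.insert(expanded.end(), sub_nodes.begin(), sub_nodes.end());  // inlined subgraph body
  expanded.insert(expanded.end(), partial_pos + 1, main_nodes.end());   // nodes after the Partial

  for (auto n : expanded) {
    std::cout << n << ' ';  // prints: 0 1 10 11 12 3
  }
  std::cout << std::endl;
  return 0;
}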

View File

@@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NESTED_LOOP_EXPAND_PASS_H
#define MINDSPORE_LITE_NESTED_LOOP_EXPAND_PASS_H
#include <vector>
#include <utility>
#include <set>
#include <memory>
#include "tools/converter/optimizer.h"
namespace mindspore {
namespace lite {
class NestedLoopExpandPass : public GraphPass {
public:
NestedLoopExpandPass() = default;
~NestedLoopExpandPass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
private:
bool IsNestedPartial(const std::unique_ptr<CNodeT> &node);
void ReplacePartialNodeWithSubgraph(const std::unique_ptr<SubGraphT> &main_graph);
schema::MetaGraphT *graph_ = nullptr;
std::vector<int> subgraph_to_drop_{};
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_LITE_NESTED_LOOP_EXPAND_PASS_H
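For context, a hedged sketch of how a GraphPass such as this one is typically wired into the converter: graph passes are registered on a lite::Optimizer and run over the schema::MetaGraphT. The helper function name below is hypothetical and the exact call site is an assumption, not part of this diff.

// Assumed usage pattern, following how other converter graph passes are registered;
// RunNestedLoopExpand is a hypothetical wrapper, not code from this change.
#include "tools/converter/optimizer.h"
#include "tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h"

namespace mindspore {
namespace lite {
STATUS RunNestedLoopExpand(schema::MetaGraphT *meta_graph) {
  Optimizer graph_optimizer;
  graph_optimizer.AddPass(new NestedLoopExpandPass());  // the optimizer is assumed to own the pass
  return graph_optimizer.Run(meta_graph);               // RET_OK on success
}
}  // namespace lite
}  // namespace mindspore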

View File

@@ -35,7 +35,7 @@ STATUS TensorNamePass::Run(schema::MetaGraphT *graph) {
auto tensor_id = node->inputIndex.at(i);
auto &tensor = graph->allTensors.at(tensor_id);
if (tensor->name.empty()) {
-MS_LOG(WARNING) << "input tensor (id = " << tensor_id << ") name is null";
+MS_LOG(DEBUG) << "input tensor (id = " << tensor_id << ") name is null";
tensor->name = node->name + "/input-" + std::to_string(i);
}
}

View File

@@ -57,27 +57,27 @@ bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) {
auto conv2d_cnode = node->cast<CNodePtr>();
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv2d_cnode->input(0));
if (primitive_c == nullptr) {
-MS_LOG(ERROR) << "Conv2D node has no primitiveC.";
+MS_LOG(DEBUG) << "Conv2D node has no primitiveC.";
continue;
}
auto primT = primitive_c->primitiveT();
if (primT == nullptr) {
-MS_LOG(ERROR) << "Conv2D node has no primitiveT.";
+MS_LOG(DEBUG) << "Conv2D node has no primitiveT.";
continue;
}
auto conv2d_primt = primT->value.AsConv2D();
auto weight_node = conv2d_cnode->input(lite::kAnfPopulaterInputNumTwo);
if (weight_node == nullptr) {
-MS_LOG(ERROR) << "Conv2D weight node is nullptr.";
+MS_LOG(DEBUG) << "Conv2D weight node is nullptr.";
continue;
}
if (!weight_node->isa<Parameter>()) {
-MS_LOG(ERROR) << "Conv2D weight node is not parameter.";
+MS_LOG(DEBUG) << "Conv2D weight node is not parameter.";
continue;
}
auto weight_param = weight_node->cast<ParameterPtr>();
if (!weight_param->has_default()) {
-MS_LOG(ERROR) << "Conv2D weight node is not parameter.";
+MS_LOG(DEBUG) << "Conv2D weight node is not parameter.";
continue;
}
auto default_param = weight_param->default_param();

View File

@@ -44,29 +44,11 @@ ValueNodePtr WhilePass::GetSwitchAnfPrim() {
return nullptr;
}
-auto partial_prim = std::make_shared<lite::Partial>(switch_primitiveT);
+auto partial_prim = std::make_shared<lite::Switch>(switch_primitiveT);
ValueNodePtr partial_anf_prim = NewValueNode(partial_prim);
return partial_anf_prim;
}
-void WhilePass::ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode,
-std::string para_name) {
-for (auto &node : node_list) {
-if (utils::isa<CNodePtr>(node)) {
-auto cnode = utils::cast<CNodePtr>(node);
-for (size_t k = 0; k < cnode->inputs().size(); k++) {
-if (!utils::isa<ParameterPtr>(cnode->input(k))) {
-continue;
-}
-auto para_input = utils::cast<ParameterPtr>(cnode->input(k));
-if (para_input->name() == para_name) {
-cnode->set_input(k, new_input_cnode);
-}
-}
-}
-}
-}
bool WhilePass::Run(const FuncGraphPtr &graph) {
auto node_list = TopoSort(graph->get_return());
static int count = 0;
@@ -87,34 +69,23 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
// the order is fixed.
auto cond_vnode = while_cnode->input(kWhileCondIndex);
auto body_vnode = while_cnode->input(kWhileBodyIndex);
-// body_vnode->cast<ValueNodePtr>()->set_value()
auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode);
auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
if (cond_fg == nullptr || body_fg == nullptr) {
MS_LOG(ERROR) << "Get value as func_graph failed.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED);
return false;
}
// create cond partial cnode
std::vector<AnfNodePtr> cond_partial_op_inputs{cond_vnode};
// create body partial cnode
std::vector<AnfNodePtr> body_partial_op_inputs{body_vnode};
// add while op input to cond_cnode and body_cnode
cond_partial_op_inputs.insert(cond_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
while_cnode->inputs().end());
body_partial_op_inputs.insert(body_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
while_cnode->inputs().end());
static int idx = 0;
auto cond_partial_node = graph->NewCNode(cond_partial_op_inputs);
cond_partial_node->set_fullname_with_scope("Partial-while-cond-" + std::to_string(idx));
cond_partial_node->set_abstract(cond_fg->output()->abstract());
auto body_partial_node = graph->NewCNode(body_partial_op_inputs);
body_partial_node->set_fullname_with_scope("Partial-while-body-" + std::to_string(idx));
idx++;
@@ -166,7 +137,6 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
}
abstract_list.push_back(cnode->abstract());
}
switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
// create cond partial cnode
@@ -176,7 +146,6 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
manager->SetEdge(node_user.first, node_user.second, switch_cnode);
}
}
return true;
}
} // namespace mindspore::opt

View File

@@ -32,7 +32,6 @@ class WhilePass : public Pass {
bool Run(const FuncGraphPtr &graph) override;
private:
-void ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode, std::string para_name);
ValueNodePtr GetSwitchAnfPrim();
const size_t kWhileMinInputSize = 3;