forked from mindspore-Ecosystem/mindspore
!11667 [MS][LITE]nested loop expand
From: @mengyuanli Reviewed-by: @zhanghaibo5,@hangangqiang Signed-off-by: @hangangqiang
This commit is contained in:
commit
d50bae7a87
|
@ -19,6 +19,7 @@ file(GLOB GRAPH_PASS
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/select_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nested_loop_expand_pass.cc
|
||||
)
|
||||
set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
|
||||
add_library(graph_pass_mid OBJECT ${GRAPH_PASS})
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
bool NestedLoopExpandPass::IsNestedPartial(const std::unique_ptr<CNodeT> &node) {
|
||||
if (node->primitive->value.type != PrimitiveType_Partial) {
|
||||
return false;
|
||||
}
|
||||
auto subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex;
|
||||
auto &this_subgraph = graph_->subGraph.at(subgraph_idx);
|
||||
|
||||
for (auto &node_idx : this_subgraph->nodeIndices) {
|
||||
auto &cnode = graph_->nodes.at(node_idx);
|
||||
if (cnode->primitive->value.type == PrimitiveType_Partial) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void NestedLoopExpandPass::ReplacePartialNodeWithSubgraph(const std::unique_ptr<SubGraphT> &main_graph) {
|
||||
bool is_changed = false;
|
||||
for (auto &node_idx : main_graph->nodeIndices) {
|
||||
auto &node = graph_->nodes.at(node_idx);
|
||||
if (!IsNestedPartial(node)) {
|
||||
continue;
|
||||
}
|
||||
is_changed = true;
|
||||
auto subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex;
|
||||
auto &this_subgraph = graph_->subGraph.at(subgraph_idx);
|
||||
subgraph_to_drop_.push_back(subgraph_idx);
|
||||
auto partial_pos = std::find(main_graph->nodeIndices.begin(), main_graph->nodeIndices.end(), node_idx);
|
||||
std::vector<uint32_t> tmp;
|
||||
tmp.assign(main_graph->nodeIndices.begin(), partial_pos);
|
||||
tmp.insert(tmp.end(), this_subgraph->nodeIndices.begin(), this_subgraph->nodeIndices.end());
|
||||
tmp.insert(tmp.end(), partial_pos + 1, main_graph->nodeIndices.end());
|
||||
main_graph->nodeIndices.assign(tmp.begin(), tmp.end());
|
||||
}
|
||||
|
||||
if (is_changed) {
|
||||
ReplacePartialNodeWithSubgraph(main_graph);
|
||||
}
|
||||
}
|
||||
|
||||
STATUS NestedLoopExpandPass::Run(schema::MetaGraphT *graph) {
|
||||
graph_ = graph;
|
||||
auto &main_graph = graph_->subGraph[0];
|
||||
|
||||
ReplacePartialNodeWithSubgraph(main_graph);
|
||||
|
||||
for (auto idx : subgraph_to_drop_) {
|
||||
graph_->subGraph.at(idx) = nullptr;
|
||||
}
|
||||
|
||||
for (auto it = graph_->subGraph.begin(); it != graph_->subGraph.end();) {
|
||||
if ((*it) == nullptr) {
|
||||
it = graph_->subGraph.erase(it);
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &node : graph_->nodes) {
|
||||
if (node->primitive->value.type == PrimitiveType_Partial) {
|
||||
((schema::PartialT *)(node->primitive->value.value))->subGraphIndex -= subgraph_to_drop_.size();
|
||||
}
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NESTED_LOOP_EXPAND_PASS_H
|
||||
#define MINDSPORE_LITE_NESTED_LOOP_EXPAND_PASS_H
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include "tools/converter/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class NestedLoopExpandPass : public GraphPass {
|
||||
public:
|
||||
NestedLoopExpandPass() = default;
|
||||
|
||||
~NestedLoopExpandPass() override = default;
|
||||
|
||||
STATUS Run(schema::MetaGraphT *graph) override;
|
||||
|
||||
private:
|
||||
bool IsNestedPartial(const std::unique_ptr<CNodeT> &node);
|
||||
|
||||
void ReplacePartialNodeWithSubgraph(const std::unique_ptr<SubGraphT> &main_graph);
|
||||
|
||||
schema::MetaGraphT *graph_ = nullptr;
|
||||
|
||||
std::vector<int> subgraph_to_drop_{};
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H
|
|
@ -35,7 +35,7 @@ STATUS TensorNamePass::Run(schema::MetaGraphT *graph) {
|
|||
auto tensor_id = node->inputIndex.at(i);
|
||||
auto &tensor = graph->allTensors.at(tensor_id);
|
||||
if (tensor->name.empty()) {
|
||||
MS_LOG(WARNING) << "input tensor (id = " << tensor_id << ") name is null";
|
||||
MS_LOG(DEBUG) << "input tensor (id = " << tensor_id << ") name is null";
|
||||
tensor->name = node->name + "/input-" + std::to_string(i);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -57,27 +57,27 @@ bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) {
|
|||
auto conv2d_cnode = node->cast<CNodePtr>();
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv2d_cnode->input(0));
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "Conv2D node has no primitiveC.";
|
||||
MS_LOG(DEBUG) << "Conv2D node has no primitiveC.";
|
||||
continue;
|
||||
}
|
||||
auto primT = primitive_c->primitiveT();
|
||||
if (primT == nullptr) {
|
||||
MS_LOG(ERROR) << "Conv2D node has no primitiveT.";
|
||||
MS_LOG(DEBUG) << "Conv2D node has no primitiveT.";
|
||||
continue;
|
||||
}
|
||||
auto conv2d_primt = primT->value.AsConv2D();
|
||||
auto weight_node = conv2d_cnode->input(lite::kAnfPopulaterInputNumTwo);
|
||||
if (weight_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Conv2D weight node is nullptr.";
|
||||
MS_LOG(DEBUG) << "Conv2D weight node is nullptr.";
|
||||
continue;
|
||||
}
|
||||
if (!weight_node->isa<Parameter>()) {
|
||||
MS_LOG(ERROR) << "Conv2D weight node is not parameter.";
|
||||
MS_LOG(DEBUG) << "Conv2D weight node is not parameter.";
|
||||
continue;
|
||||
}
|
||||
auto weight_param = weight_node->cast<ParameterPtr>();
|
||||
if (!weight_param->has_default()) {
|
||||
MS_LOG(ERROR) << "Conv2D weight node is not parameter.";
|
||||
MS_LOG(DEBUG) << "Conv2D weight node is not parameter.";
|
||||
continue;
|
||||
}
|
||||
auto default_param = weight_param->default_param();
|
||||
|
|
|
@ -44,29 +44,11 @@ ValueNodePtr WhilePass::GetSwitchAnfPrim() {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto partial_prim = std::make_shared<lite::Partial>(switch_primitiveT);
|
||||
auto partial_prim = std::make_shared<lite::Switch>(switch_primitiveT);
|
||||
ValueNodePtr partial_anf_prim = NewValueNode(partial_prim);
|
||||
return partial_anf_prim;
|
||||
}
|
||||
|
||||
void WhilePass::ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode,
|
||||
std::string para_name) {
|
||||
for (auto &node : node_list) {
|
||||
if (utils::isa<CNodePtr>(node)) {
|
||||
auto cnode = utils::cast<CNodePtr>(node);
|
||||
for (size_t k = 0; k < cnode->inputs().size(); k++) {
|
||||
if (!utils::isa<ParameterPtr>(cnode->input(k))) {
|
||||
continue;
|
||||
}
|
||||
auto para_input = utils::cast<ParameterPtr>(cnode->input(k));
|
||||
if (para_input->name() == para_name) {
|
||||
cnode->set_input(k, new_input_cnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool WhilePass::Run(const FuncGraphPtr &graph) {
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
static int count = 0;
|
||||
|
@ -87,34 +69,23 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
|
|||
// the order is fixed.
|
||||
auto cond_vnode = while_cnode->input(kWhileCondIndex);
|
||||
auto body_vnode = while_cnode->input(kWhileBodyIndex);
|
||||
|
||||
// body_vnode->cast<ValueNodePtr>()->set_value()
|
||||
auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode);
|
||||
auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
|
||||
|
||||
if (cond_fg == nullptr || body_fg == nullptr) {
|
||||
MS_LOG(ERROR) << "Get value as func_graph failed.";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED);
|
||||
return false;
|
||||
}
|
||||
|
||||
// create cond partial cnode
|
||||
std::vector<AnfNodePtr> cond_partial_op_inputs{cond_vnode};
|
||||
|
||||
// create body partial cnode
|
||||
std::vector<AnfNodePtr> body_partial_op_inputs{body_vnode};
|
||||
|
||||
// add while op input to cond_cnode and body_cnode
|
||||
cond_partial_op_inputs.insert(cond_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
|
||||
while_cnode->inputs().end());
|
||||
body_partial_op_inputs.insert(body_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
|
||||
while_cnode->inputs().end());
|
||||
|
||||
static int idx = 0;
|
||||
auto cond_partial_node = graph->NewCNode(cond_partial_op_inputs);
|
||||
cond_partial_node->set_fullname_with_scope("Partial-while-cond-" + std::to_string(idx));
|
||||
cond_partial_node->set_abstract(cond_fg->output()->abstract());
|
||||
|
||||
auto body_partial_node = graph->NewCNode(body_partial_op_inputs);
|
||||
body_partial_node->set_fullname_with_scope("Partial-while-body-" + std::to_string(idx));
|
||||
idx++;
|
||||
|
@ -166,7 +137,6 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
|
|||
}
|
||||
abstract_list.push_back(cnode->abstract());
|
||||
}
|
||||
|
||||
switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||
|
||||
// create cond partial cnode
|
||||
|
@ -176,7 +146,6 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
|
|||
manager->SetEdge(node_user.first, node_user.second, switch_cnode);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore::opt
|
||||
|
|
|
@ -32,7 +32,6 @@ class WhilePass : public Pass {
|
|||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
void ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode, std::string para_name);
|
||||
ValueNodePtr GetSwitchAnfPrim();
|
||||
|
||||
const size_t kWhileMinInputSize = 3;
|
||||
|
|
Loading…
Reference in New Issue