!15625 [lite]copy subgraph input and add op unsupport to infer

From: @xu_anyue
Reviewed-by: @hangangqiang,@HilbertDavid
Signed-off-by: @hangangqiang
This commit is contained in:
mindspore-ci-bot 2021-04-26 21:03:26 +08:00 committed by Gitee
commit e9221f158b
4 changed files with 113 additions and 86 deletions

View File

@ -109,6 +109,20 @@ tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) {
}
} // namespace
bool NodeInferShape::JudgeOpSupportInfer(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
auto prim_t = lite::GetPrimitiveT(cnode->input(0));
if (prim_t == nullptr) {
return false;
}
auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim_t->value.type, lite::SCHEMA_CUR);
if (parameter_gen == nullptr) {
delete prim_t;
return false;
}
return true;
}
STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
auto anf_prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));

View File

@ -38,6 +38,7 @@ class NodeInferShape {
train_flag_ = train_flag;
}
STATUS InferShape(const CNodePtr &cnode);
bool JudgeOpSupportInfer(const CNodePtr &cnode);
std::vector<int> GetInputShape(const CNodePtr &cnode, size_t index);
std::vector<int> GetIntVecInput(const CNodePtr &cnode, size_t index);

View File

@ -539,10 +539,7 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN
if (!need_reset_ && TransTransFusion(func_graph, cnode)) {
return lite::RET_OK;
}
std::unordered_map<AnfNodePtr, AnfNodePtr> match;
PreProcessFowardInsert(func_graph, cnode, &match);
auto status = node_infer_shape_.InferShape(cnode);
PostProcessFowardInsert(func_graph, cnode, match);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed: " << cnode->fullname_with_scope();
return lite::RET_ERROR;
@ -551,8 +548,6 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN
}
auto before_perm = trans_info.pre_ == kNHWC2NCHW ? NH2NC : NC2NH;
auto after_perm = trans_info.post_ == kNCHW2NHWC ? NC2NH : NH2NC;
std::unordered_map<AnfNodePtr, AnfNodePtr> match;
PreProcessFowardInsert(func_graph, cnode, &match);
if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
return lite::RET_ERROR;
@ -562,7 +557,6 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN
MS_LOG(ERROR) << "infer shape failed.";
return lite::RET_ERROR;
}
PostProcessFowardInsert(func_graph, cnode, match);
if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
return lite::RET_ERROR;
@ -629,57 +623,9 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con
return lite::RET_OK;
}
void UnifyFormatPass::PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::unordered_map<AnfNodePtr, AnfNodePtr> *match) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
if (sub_inputs_map_.find(graph_name) == sub_inputs_map_.end()) {
return;
}
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto tr = manager->Transact();
for (size_t i = 1; i < cnode->size(); ++i) {
if (sub_inputs_map_[graph_name].find(cnode->input(i)) == sub_inputs_map_[graph_name].end()) {
continue;
}
match->insert(std::make_pair(sub_inputs_map_[graph_name][cnode->input(i)], cnode->input(i)));
tr.SetEdge(cnode, i, sub_inputs_map_[graph_name][cnode->input(i)]);
tr.Commit();
}
}
void UnifyFormatPass::PostProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::unordered_map<AnfNodePtr, AnfNodePtr> &match) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (match.empty()) {
return;
}
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto tr = manager->Transact();
for (size_t i = 1; i < cnode->size(); ++i) {
if (match.find(cnode->input(i)) != match.end()) {
tr.SetEdge(cnode, i, match.at(cnode->input(i)));
tr.Commit();
}
if (CheckPrimitiveType(cnode->input(i), prim::kPrimTranspose)) {
auto trans_cnode = cnode->input(i)->cast<CNodePtr>();
for (size_t j = 1; j < trans_cnode->size(); ++j) {
if (match.find(trans_cnode->input(j)) == match.end()) {
continue;
}
tr.SetEdge(trans_cnode, j, match.at(trans_cnode->input(j)));
tr.Commit();
}
}
}
}
void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto subgraph_name = GetValue<std::string>(sub_graph->get_attr("graph_name"));
sub_inputs_map_[subgraph_name] = {};
sub_inputs_map_[sub_graph] = {};
auto sub_inputs = sub_graph->get_inputs();
for (auto &node : sub_inputs) {
auto param_node = node->cast<ParameterPtr>();
@ -689,19 +635,52 @@ void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr
node_name = node_name.substr(0, last_underline);
last_underline = node_name.find_last_of("_");
auto index = std::stoi(node_name.substr(last_underline + 1)) + 3;
if (utils::isa<CNodePtr>(cnode->input(index)) && CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) {
std::vector<int> shape = {-1};
auto trans_cnode = cnode->input(index)->cast<CNodePtr>();
param_node->set_abstract(GetCNodeInputAbstract(cnode, index)->Clone());
if (utils::isa<CNodePtr>(cnode->input(index))) {
ShapeVector shape_vec = {-1};
auto out_cnode = cnode->input(index)->cast<CNodePtr>();
MS_ASSERT(trans_cnode != nullptr);
auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
if (trans_prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(trans_prim->GetAttr(kInferDone))) {
shape = node_infer_shape_.GetInputShape(cnode, index);
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
if (out_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(kInferDone))) {
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
}
auto type = trans_cnode->abstract()->cast<abstract::AbstractTensorPtr>()->element()->GetTypeTrack();
std::vector<int64_t> shape_vec(shape.begin(), shape.end());
param_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vec));
} else {
sub_inputs_map_[subgraph_name][node] = cnode->input(index);
lite::DataInfo data_info;
if (utils::isa<ParameterPtr>(cnode->input(index))) {
if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
sub_inputs_map_[sub_graph].push_back(param_node);
}
continue;
}
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
if (status != lite::RET_OK) {
continue;
}
ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
if (data_info.data_.empty()) {
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec));
} else {
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
data_info.data_.data(), data_info.data_.size()));
}
sub_inputs_map_[sub_graph].push_back(param_node);
}
}
}
void UnifyFormatPass::ResetSubGraphInput() {
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
auto &sub_graph = iter->first;
auto &sub_inputs = iter->second;
auto manager = sub_graph->manager();
MS_ASSERT(manager != nullptr);
for (auto &sub_input : sub_inputs) {
auto param_node = sub_graph->add_parameter();
MS_ASSERT(param_node != nullptr);
param_node->set_abstract(sub_input->abstract()->Clone());
param_node->set_name(sub_input->name());
manager->Replace(sub_input, param_node);
}
}
}
@ -804,13 +783,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra
}
}
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
auto origin_inputs = cnode->inputs();
for (size_t i = 3; i < cnode->size(); ++i) {
if (sub_inputs_map_.find(graph_name) != sub_inputs_map_.end() &&
sub_inputs_map_[graph_name].find(cnode->input(i)) != sub_inputs_map_[graph_name].end()) {
cnode->set_input(i, sub_inputs_map_[graph_name][cnode->input(i)]);
}
}
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
@ -828,7 +800,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra
(void)BasicProcess(sub_func_graph, false);
SetSubGraphOutput(cnode, sub_func_graph);
SetSubGraphAbstract(cnode, sub_func_graph);
cnode->set_inputs(origin_inputs);
continue;
}
status = HandleGraphNode(func_graph, cnode);
@ -836,6 +807,7 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra
return false;
}
}
ResetSubGraphInput();
return true;
}
@ -858,13 +830,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
continue;
}
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
auto origin_inputs = cnode->inputs();
for (size_t i = 3; i < cnode->size(); ++i) {
if (sub_inputs_map_.find(graph_name) != sub_inputs_map_.end() &&
sub_inputs_map_[graph_name].find(cnode->input(i)) != sub_inputs_map_[graph_name].end()) {
cnode->set_input(i, sub_inputs_map_[graph_name][cnode->input(i)]);
}
}
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
@ -882,7 +847,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
(void)DecreaseTransposeForSingleOp(sub_func_graph);
SetSubGraphOutput(cnode, sub_func_graph);
SetSubGraphAbstract(cnode, sub_func_graph);
cnode->set_inputs(origin_inputs);
continue;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
@ -904,6 +868,7 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
return false;
}
}
ResetSubGraphInput();
return true;
}
@ -1010,8 +975,53 @@ bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) {
return true;
}
bool UnifyFormatPass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto node_list = TopoSort(func_graph->get_return());
bool all_op_can_infer = true;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (IsSpecialType(cnode)) {
continue;
}
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
all_op_can_infer = false;
} else {
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
}
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
all_op_can_infer = false;
} else {
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
}
continue;
}
auto cur_op_can_infer = node_infer_shape_.JudgeOpSupportInfer(cnode);
if (!cur_op_can_infer) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(prim != nullptr);
lite::NotSupportOp::GetInstance()->InsertOp(prim->name());
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_SUPPORT);
all_op_can_infer = false;
}
}
return all_op_can_infer;
}
bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
if (!JudgeAllOpsCanInfer(func_graph)) {
MS_LOG(ERROR) << "exist op cannot support infer shape.";
return false;
}
need_reset_ = true;
// insert transpose for some ops whose format must be NHWC, which is depend on framework.
// In this process, transpose op cannot be fused to restore the original graph.
@ -1039,6 +1049,10 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
return true;
}
}
if (!JudgeAllOpsCanInfer(func_graph)) {
MS_LOG(ERROR) << "exist op cannot support infer shape.";
return false;
}
// insert transpose for some ops whose format must be NHWC, which is depend on framework.
// In this process, tranpose can be fused, which the original graph may not be able to restored.
if (!BasicProcess(func_graph, true)) {

View File

@ -45,6 +45,7 @@ class UnifyFormatPass : public Pass {
bool RunOnlyForShape(const FuncGraphPtr &func_graph);
private:
bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
@ -61,11 +62,8 @@ class UnifyFormatPass : public Pass {
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
void PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::unordered_map<AnfNodePtr, AnfNodePtr> *match);
void PostProcessFowardInsert(const FuncGraphPtr &funcgraph, const CNodePtr &cnode,
const std::unordered_map<AnfNodePtr, AnfNodePtr> &match);
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void ResetSubGraphInput();
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
FmkType fmk_type_{lite::converter::FmkType_MS};
@ -75,7 +73,7 @@ class UnifyFormatPass : public Pass {
TransposeStrategy transpose_strategy_;
std::set<AnfNodePtr> pre_insert_trans_;
std::set<AnfNodePtr> post_insert_trans_;
std::unordered_map<std::string, std::unordered_map<AnfNodePtr, AnfNodePtr>> sub_inputs_map_;
std::unordered_map<FuncGraphPtr, std::vector<ParameterPtr>> sub_inputs_map_;
};
} // namespace opt
} // namespace mindspore