forked from mindspore-Ecosystem/mindspore
!15625 [lite]copy subgraph input and add op unsupport to infer
From: @xu_anyue Reviewed-by: @hangangqiang,@HilbertDavid Signed-off-by: @hangangqiang
This commit is contained in:
commit
e9221f158b
|
@ -109,6 +109,20 @@ tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
bool NodeInferShape::JudgeOpSupportInfer(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto prim_t = lite::GetPrimitiveT(cnode->input(0));
|
||||
if (prim_t == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim_t->value.type, lite::SCHEMA_CUR);
|
||||
if (parameter_gen == nullptr) {
|
||||
delete prim_t;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto anf_prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
|
||||
|
|
|
@ -38,6 +38,7 @@ class NodeInferShape {
|
|||
train_flag_ = train_flag;
|
||||
}
|
||||
STATUS InferShape(const CNodePtr &cnode);
|
||||
bool JudgeOpSupportInfer(const CNodePtr &cnode);
|
||||
std::vector<int> GetInputShape(const CNodePtr &cnode, size_t index);
|
||||
std::vector<int> GetIntVecInput(const CNodePtr &cnode, size_t index);
|
||||
|
||||
|
|
|
@ -539,10 +539,7 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN
|
|||
if (!need_reset_ && TransTransFusion(func_graph, cnode)) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> match;
|
||||
PreProcessFowardInsert(func_graph, cnode, &match);
|
||||
auto status = node_infer_shape_.InferShape(cnode);
|
||||
PostProcessFowardInsert(func_graph, cnode, match);
|
||||
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "infer shape failed: " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
|
@ -551,8 +548,6 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN
|
|||
}
|
||||
auto before_perm = trans_info.pre_ == kNHWC2NCHW ? NH2NC : NC2NH;
|
||||
auto after_perm = trans_info.post_ == kNCHW2NHWC ? NC2NH : NH2NC;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> match;
|
||||
PreProcessFowardInsert(func_graph, cnode, &match);
|
||||
if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
|
@ -562,7 +557,6 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN
|
|||
MS_LOG(ERROR) << "infer shape failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
PostProcessFowardInsert(func_graph, cnode, match);
|
||||
if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
|
@ -629,57 +623,9 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void UnifyFormatPass::PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *match) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
|
||||
if (sub_inputs_map_.find(graph_name) == sub_inputs_map_.end()) {
|
||||
return;
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
auto tr = manager->Transact();
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
if (sub_inputs_map_[graph_name].find(cnode->input(i)) == sub_inputs_map_[graph_name].end()) {
|
||||
continue;
|
||||
}
|
||||
match->insert(std::make_pair(sub_inputs_map_[graph_name][cnode->input(i)], cnode->input(i)));
|
||||
tr.SetEdge(cnode, i, sub_inputs_map_[graph_name][cnode->input(i)]);
|
||||
tr.Commit();
|
||||
}
|
||||
}
|
||||
|
||||
void UnifyFormatPass::PostProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::unordered_map<AnfNodePtr, AnfNodePtr> &match) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
if (match.empty()) {
|
||||
return;
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
auto tr = manager->Transact();
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
if (match.find(cnode->input(i)) != match.end()) {
|
||||
tr.SetEdge(cnode, i, match.at(cnode->input(i)));
|
||||
tr.Commit();
|
||||
}
|
||||
if (CheckPrimitiveType(cnode->input(i), prim::kPrimTranspose)) {
|
||||
auto trans_cnode = cnode->input(i)->cast<CNodePtr>();
|
||||
for (size_t j = 1; j < trans_cnode->size(); ++j) {
|
||||
if (match.find(trans_cnode->input(j)) == match.end()) {
|
||||
continue;
|
||||
}
|
||||
tr.SetEdge(trans_cnode, j, match.at(trans_cnode->input(j)));
|
||||
tr.Commit();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||
auto subgraph_name = GetValue<std::string>(sub_graph->get_attr("graph_name"));
|
||||
sub_inputs_map_[subgraph_name] = {};
|
||||
sub_inputs_map_[sub_graph] = {};
|
||||
auto sub_inputs = sub_graph->get_inputs();
|
||||
for (auto &node : sub_inputs) {
|
||||
auto param_node = node->cast<ParameterPtr>();
|
||||
|
@ -689,19 +635,52 @@ void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr
|
|||
node_name = node_name.substr(0, last_underline);
|
||||
last_underline = node_name.find_last_of("_");
|
||||
auto index = std::stoi(node_name.substr(last_underline + 1)) + 3;
|
||||
if (utils::isa<CNodePtr>(cnode->input(index)) && CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) {
|
||||
std::vector<int> shape = {-1};
|
||||
auto trans_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||
param_node->set_abstract(GetCNodeInputAbstract(cnode, index)->Clone());
|
||||
if (utils::isa<CNodePtr>(cnode->input(index))) {
|
||||
ShapeVector shape_vec = {-1};
|
||||
auto out_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||
MS_ASSERT(trans_cnode != nullptr);
|
||||
auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
|
||||
if (trans_prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(trans_prim->GetAttr(kInferDone))) {
|
||||
shape = node_infer_shape_.GetInputShape(cnode, index);
|
||||
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
|
||||
if (out_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(kInferDone))) {
|
||||
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
|
||||
}
|
||||
auto type = trans_cnode->abstract()->cast<abstract::AbstractTensorPtr>()->element()->GetTypeTrack();
|
||||
std::vector<int64_t> shape_vec(shape.begin(), shape.end());
|
||||
param_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vec));
|
||||
} else {
|
||||
sub_inputs_map_[subgraph_name][node] = cnode->input(index);
|
||||
lite::DataInfo data_info;
|
||||
if (utils::isa<ParameterPtr>(cnode->input(index))) {
|
||||
if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
|
||||
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
|
||||
sub_inputs_map_[sub_graph].push_back(param_node);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
|
||||
if (status != lite::RET_OK) {
|
||||
continue;
|
||||
}
|
||||
ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
|
||||
if (data_info.data_.empty()) {
|
||||
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec));
|
||||
} else {
|
||||
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
|
||||
data_info.data_.data(), data_info.data_.size()));
|
||||
}
|
||||
sub_inputs_map_[sub_graph].push_back(param_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void UnifyFormatPass::ResetSubGraphInput() {
|
||||
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
|
||||
auto &sub_graph = iter->first;
|
||||
auto &sub_inputs = iter->second;
|
||||
auto manager = sub_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
for (auto &sub_input : sub_inputs) {
|
||||
auto param_node = sub_graph->add_parameter();
|
||||
MS_ASSERT(param_node != nullptr);
|
||||
param_node->set_abstract(sub_input->abstract()->Clone());
|
||||
param_node->set_name(sub_input->name());
|
||||
manager->Replace(sub_input, param_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -804,13 +783,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra
|
|||
}
|
||||
}
|
||||
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||
auto origin_inputs = cnode->inputs();
|
||||
for (size_t i = 3; i < cnode->size(); ++i) {
|
||||
if (sub_inputs_map_.find(graph_name) != sub_inputs_map_.end() &&
|
||||
sub_inputs_map_[graph_name].find(cnode->input(i)) != sub_inputs_map_[graph_name].end()) {
|
||||
cnode->set_input(i, sub_inputs_map_[graph_name][cnode->input(i)]);
|
||||
}
|
||||
}
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
|
@ -828,7 +800,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra
|
|||
(void)BasicProcess(sub_func_graph, false);
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
SetSubGraphAbstract(cnode, sub_func_graph);
|
||||
cnode->set_inputs(origin_inputs);
|
||||
continue;
|
||||
}
|
||||
status = HandleGraphNode(func_graph, cnode);
|
||||
|
@ -836,6 +807,7 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra
|
|||
return false;
|
||||
}
|
||||
}
|
||||
ResetSubGraphInput();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -858,13 +830,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
|
|||
continue;
|
||||
}
|
||||
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||
auto origin_inputs = cnode->inputs();
|
||||
for (size_t i = 3; i < cnode->size(); ++i) {
|
||||
if (sub_inputs_map_.find(graph_name) != sub_inputs_map_.end() &&
|
||||
sub_inputs_map_[graph_name].find(cnode->input(i)) != sub_inputs_map_[graph_name].end()) {
|
||||
cnode->set_input(i, sub_inputs_map_[graph_name][cnode->input(i)]);
|
||||
}
|
||||
}
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
|
@ -882,7 +847,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
|
|||
(void)DecreaseTransposeForSingleOp(sub_func_graph);
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
SetSubGraphAbstract(cnode, sub_func_graph);
|
||||
cnode->set_inputs(origin_inputs);
|
||||
continue;
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
|
@ -904,6 +868,7 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
|
|||
return false;
|
||||
}
|
||||
}
|
||||
ResetSubGraphInput();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1010,8 +975,53 @@ bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool UnifyFormatPass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
bool all_op_can_infer = true;
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsSpecialType(cnode)) {
|
||||
continue;
|
||||
}
|
||||
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
all_op_can_infer = false;
|
||||
} else {
|
||||
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
|
||||
}
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
all_op_can_infer = false;
|
||||
} else {
|
||||
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
auto cur_op_can_infer = node_infer_shape_.JudgeOpSupportInfer(cnode);
|
||||
if (!cur_op_can_infer) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_ASSERT(prim != nullptr);
|
||||
lite::NotSupportOp::GetInstance()->InsertOp(prim->name());
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_SUPPORT);
|
||||
all_op_can_infer = false;
|
||||
}
|
||||
}
|
||||
return all_op_can_infer;
|
||||
}
|
||||
|
||||
bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
if (!JudgeAllOpsCanInfer(func_graph)) {
|
||||
MS_LOG(ERROR) << "exist op cannot support infer shape.";
|
||||
return false;
|
||||
}
|
||||
need_reset_ = true;
|
||||
// insert transpose for some ops whose format must be NHWC, which is depend on framework.
|
||||
// In this process, transpose op cannot be fused to restore the original graph.
|
||||
|
@ -1039,6 +1049,10 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
|
|||
return true;
|
||||
}
|
||||
}
|
||||
if (!JudgeAllOpsCanInfer(func_graph)) {
|
||||
MS_LOG(ERROR) << "exist op cannot support infer shape.";
|
||||
return false;
|
||||
}
|
||||
// insert transpose for some ops whose format must be NHWC, which is depend on framework.
|
||||
// In this process, tranpose can be fused, which the original graph may not be able to restored.
|
||||
if (!BasicProcess(func_graph, true)) {
|
||||
|
|
|
@ -45,6 +45,7 @@ class UnifyFormatPass : public Pass {
|
|||
bool RunOnlyForShape(const FuncGraphPtr &func_graph);
|
||||
|
||||
private:
|
||||
bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
|
||||
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
|
||||
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
|
||||
bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
|
||||
|
@ -61,11 +62,8 @@ class UnifyFormatPass : public Pass {
|
|||
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
||||
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
|
||||
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
||||
void PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *match);
|
||||
void PostProcessFowardInsert(const FuncGraphPtr &funcgraph, const CNodePtr &cnode,
|
||||
const std::unordered_map<AnfNodePtr, AnfNodePtr> &match);
|
||||
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void ResetSubGraphInput();
|
||||
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
FmkType fmk_type_{lite::converter::FmkType_MS};
|
||||
|
@ -75,7 +73,7 @@ class UnifyFormatPass : public Pass {
|
|||
TransposeStrategy transpose_strategy_;
|
||||
std::set<AnfNodePtr> pre_insert_trans_;
|
||||
std::set<AnfNodePtr> post_insert_trans_;
|
||||
std::unordered_map<std::string, std::unordered_map<AnfNodePtr, AnfNodePtr>> sub_inputs_map_;
|
||||
std::unordered_map<FuncGraphPtr, std::vector<ParameterPtr>> sub_inputs_map_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue