!13462 [lite]fix bug for train model

From: @xu_anyue
Reviewed-by: @hangangqiang,@HilbertDavid
Signed-off-by: @hangangqiang
This commit is contained in:
mindspore-ci-bot 2021-03-17 15:49:41 +08:00 committed by Gitee
commit dd607fbb65
4 changed files with 57 additions and 11 deletions

View File

@ -356,18 +356,16 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
} }
} }
RemoveIfMakeTuple(cnode);
RemoveIfDepend(cnode); RemoveIfDepend(cnode);
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend) { if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend ||
prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) {
continue; continue;
} }
if (prim->name() == "make_tuple") { if (prim->name() == "make_tuple") {
continue; continue;
} }
RemoveIfMakeTuple(cnode);
if (prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) {
continue;
}
auto node = std::make_unique<schema::CNodeT>(); auto node = std::make_unique<schema::CNodeT>();
if (node == nullptr) { if (node == nullptr) {
MS_LOG(ERROR) << "object failed to be constructed"; MS_LOG(ERROR) << "object failed to be constructed";

View File

@ -139,8 +139,9 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT
static const std::vector<schema::PrimitiveType> needInsertOpList = { static const std::vector<schema::PrimitiveType> needInsertOpList = {
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
schema::PrimitiveType_PowFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_AddFusion, schema::PrimitiveType_PowFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_AddFusion,
schema::PrimitiveType_Split, schema::PrimitiveType_SliceFusion, schema::PrimitiveType_Crop, schema::PrimitiveType_AddN, schema::PrimitiveType_Split, schema::PrimitiveType_SliceFusion,
schema::PrimitiveType_MulFusion, schema::PrimitiveType_Maximum, schema::PrimitiveType_ActivationGrad}; schema::PrimitiveType_Crop, schema::PrimitiveType_MulFusion, schema::PrimitiveType_Maximum,
schema::PrimitiveType_ActivationGrad};
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}}; static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};

View File

@ -15,12 +15,28 @@
*/ */
#include "tools/optimizer/graph/redundant_op_remove_pass.h" #include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include "mindspore/lite/include/errorcode.h" #include <memory>
#include <vector>
#include "include/errorcode.h"
#include "ops/make_tuple.h"
namespace mindspore::opt { namespace mindspore::opt {
namespace { namespace {
constexpr size_t InputDoubleNum = 2; constexpr size_t InputDoubleNum = 2;
constexpr size_t InputTripleNum = 3; constexpr size_t InputTripleNum = 3;
void FetchCNodeFromMakeTuple(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *inputs) {
MS_ASSERT(anf_node != nullptr);
MS_ASSERT(inputs != nullptr);
auto cnode = anf_node->cast<CNodePtr>();
if (cnode == nullptr) {
return;
}
for (size_t i = 1; i < cnode->size(); ++i) {
if (cnode->input(i)->isa<CNode>()) {
inputs->push_back(cnode->input(i));
}
}
}
} // namespace } // namespace
int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
if (!utils::isa<CNodePtr>(anf_node)) { if (!utils::isa<CNodePtr>(anf_node)) {
@ -58,6 +74,36 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph
return RET_OK; return RET_OK;
} }
int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node) {
if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(DEBUG) << "anf node is node a cnode.";
return lite::RET_NO_CHANGE;
}
auto cnode = anf_node->cast<CNodePtr>();
auto inputs = cnode->inputs();
std::vector<AnfNodePtr> new_inputs;
for (size_t i = 1; i < inputs.size(); ++i) {
if (!inputs[i]->isa<CNode>()) {
continue;
}
if (CheckPrimitiveType(inputs[i], prim::kPrimMakeTuple)) {
FetchCNodeFromMakeTuple(inputs[i], &new_inputs);
continue;
}
new_inputs.push_back(inputs[i]);
}
for (auto &node : new_inputs) {
func_graph->get_return()->add_input(node);
}
auto value = std::make_shared<UMonad>();
bool replace_succ = func_graph->manager()->Replace(anf_node, NewValueNode(value));
if (!replace_succ) {
MS_LOG(ERROR) << "replace redundant op failed.";
return lite::RET_ERROR;
}
return RET_OK;
}
int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
if (!utils::isa<CNodePtr>(anf_node)) { if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(DEBUG) << "anf node is node a cnode."; MS_LOG(DEBUG) << "anf node is node a cnode.";
@ -111,7 +157,7 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
status = ReplaceOp(node, manager); status = ReplaceOp(node, manager);
} }
if (CheckPrimitiveType(node, prim::kPrimUpdateState)) { if (CheckPrimitiveType(node, prim::kPrimUpdateState)) {
status = ReplaceOp(node, manager); status = ReplaceUpdateStateOp(func_graph, node);
} }
if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
status = ReplaceTupleGetItem(node, manager); status = ReplaceTupleGetItem(node, manager);

View File

@ -29,6 +29,7 @@ class RemoveRedundantOpPass : public Pass {
RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {} RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {}
~RemoveRedundantOpPass() override = default; ~RemoveRedundantOpPass() override = default;
int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
int ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node);
int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
bool Run(const FuncGraphPtr &graph) override; bool Run(const FuncGraphPtr &graph) override;