forked from mindspore-Ecosystem/mindspore
!13462 [lite]fix bug for train model
From: @xu_anyue Reviewed-by: @hangangqiang,@HilbertDavid Signed-off-by: @hangangqiang
This commit is contained in:
commit
dd607fbb65
|
@ -356,18 +356,16 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
|||
}
|
||||
}
|
||||
|
||||
RemoveIfMakeTuple(cnode);
|
||||
RemoveIfDepend(cnode);
|
||||
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend) {
|
||||
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend ||
|
||||
prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) {
|
||||
continue;
|
||||
}
|
||||
if (prim->name() == "make_tuple") {
|
||||
continue;
|
||||
}
|
||||
RemoveIfMakeTuple(cnode);
|
||||
|
||||
if (prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) {
|
||||
continue;
|
||||
}
|
||||
auto node = std::make_unique<schema::CNodeT>();
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "object failed to be constructed";
|
||||
|
|
|
@ -137,10 +137,11 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT
|
|||
schema::PrimitiveType_L2NormalizeFusion};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> needInsertOpList = {
|
||||
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
|
||||
schema::PrimitiveType_PowFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_AddFusion,
|
||||
schema::PrimitiveType_Split, schema::PrimitiveType_SliceFusion, schema::PrimitiveType_Crop,
|
||||
schema::PrimitiveType_MulFusion, schema::PrimitiveType_Maximum, schema::PrimitiveType_ActivationGrad};
|
||||
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
|
||||
schema::PrimitiveType_PowFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_AddFusion,
|
||||
schema::PrimitiveType_AddN, schema::PrimitiveType_Split, schema::PrimitiveType_SliceFusion,
|
||||
schema::PrimitiveType_Crop, schema::PrimitiveType_MulFusion, schema::PrimitiveType_Maximum,
|
||||
schema::PrimitiveType_ActivationGrad};
|
||||
|
||||
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
|
||||
|
||||
|
|
|
@ -15,12 +15,28 @@
|
|||
*/
|
||||
|
||||
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
|
||||
#include "mindspore/lite/include/errorcode.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "include/errorcode.h"
|
||||
#include "ops/make_tuple.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
constexpr size_t InputDoubleNum = 2;
|
||||
constexpr size_t InputTripleNum = 3;
|
||||
void FetchCNodeFromMakeTuple(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *inputs) {
|
||||
MS_ASSERT(anf_node != nullptr);
|
||||
MS_ASSERT(inputs != nullptr);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
return;
|
||||
}
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
if (cnode->input(i)->isa<CNode>()) {
|
||||
inputs->push_back(cnode->input(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
||||
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||
|
@ -58,6 +74,36 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node) {
|
||||
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||
MS_LOG(DEBUG) << "anf node is node a cnode.";
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
if (!inputs[i]->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
if (CheckPrimitiveType(inputs[i], prim::kPrimMakeTuple)) {
|
||||
FetchCNodeFromMakeTuple(inputs[i], &new_inputs);
|
||||
continue;
|
||||
}
|
||||
new_inputs.push_back(inputs[i]);
|
||||
}
|
||||
for (auto &node : new_inputs) {
|
||||
func_graph->get_return()->add_input(node);
|
||||
}
|
||||
auto value = std::make_shared<UMonad>();
|
||||
bool replace_succ = func_graph->manager()->Replace(anf_node, NewValueNode(value));
|
||||
if (!replace_succ) {
|
||||
MS_LOG(ERROR) << "replace redundant op failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
||||
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||
MS_LOG(DEBUG) << "anf node is node a cnode.";
|
||||
|
@ -111,7 +157,7 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
|
|||
status = ReplaceOp(node, manager);
|
||||
}
|
||||
if (CheckPrimitiveType(node, prim::kPrimUpdateState)) {
|
||||
status = ReplaceOp(node, manager);
|
||||
status = ReplaceUpdateStateOp(func_graph, node);
|
||||
}
|
||||
if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||
status = ReplaceTupleGetItem(node, manager);
|
||||
|
|
|
@ -29,6 +29,7 @@ class RemoveRedundantOpPass : public Pass {
|
|||
RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {}
|
||||
~RemoveRedundantOpPass() override = default;
|
||||
int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
||||
int ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node);
|
||||
int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
|
|
Loading…
Reference in New Issue