forked from mindspore-Ecosystem/mindspore
!13462 [lite]fix bug for train model
From: @xu_anyue Reviewed-by: @hangangqiang,@HilbertDavid Signed-off-by: @hangangqiang
This commit is contained in:
commit
dd607fbb65
|
@ -356,18 +356,16 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
RemoveIfMakeTuple(cnode);
|
|
||||||
RemoveIfDepend(cnode);
|
RemoveIfDepend(cnode);
|
||||||
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend) {
|
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend ||
|
||||||
|
prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (prim->name() == "make_tuple") {
|
if (prim->name() == "make_tuple") {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
RemoveIfMakeTuple(cnode);
|
||||||
|
|
||||||
if (prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto node = std::make_unique<schema::CNodeT>();
|
auto node = std::make_unique<schema::CNodeT>();
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
MS_LOG(ERROR) << "object failed to be constructed";
|
MS_LOG(ERROR) << "object failed to be constructed";
|
||||||
|
|
|
@ -137,10 +137,11 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT
|
||||||
schema::PrimitiveType_L2NormalizeFusion};
|
schema::PrimitiveType_L2NormalizeFusion};
|
||||||
|
|
||||||
static const std::vector<schema::PrimitiveType> needInsertOpList = {
|
static const std::vector<schema::PrimitiveType> needInsertOpList = {
|
||||||
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
|
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
|
||||||
schema::PrimitiveType_PowFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_AddFusion,
|
schema::PrimitiveType_PowFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_AddFusion,
|
||||||
schema::PrimitiveType_Split, schema::PrimitiveType_SliceFusion, schema::PrimitiveType_Crop,
|
schema::PrimitiveType_AddN, schema::PrimitiveType_Split, schema::PrimitiveType_SliceFusion,
|
||||||
schema::PrimitiveType_MulFusion, schema::PrimitiveType_Maximum, schema::PrimitiveType_ActivationGrad};
|
schema::PrimitiveType_Crop, schema::PrimitiveType_MulFusion, schema::PrimitiveType_Maximum,
|
||||||
|
schema::PrimitiveType_ActivationGrad};
|
||||||
|
|
||||||
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
|
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
|
||||||
|
|
||||||
|
|
|
@ -15,12 +15,28 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
|
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
|
||||||
#include "mindspore/lite/include/errorcode.h"
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "include/errorcode.h"
|
||||||
|
#include "ops/make_tuple.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr size_t InputDoubleNum = 2;
|
constexpr size_t InputDoubleNum = 2;
|
||||||
constexpr size_t InputTripleNum = 3;
|
constexpr size_t InputTripleNum = 3;
|
||||||
|
void FetchCNodeFromMakeTuple(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *inputs) {
|
||||||
|
MS_ASSERT(anf_node != nullptr);
|
||||||
|
MS_ASSERT(inputs != nullptr);
|
||||||
|
auto cnode = anf_node->cast<CNodePtr>();
|
||||||
|
if (cnode == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||||
|
if (cnode->input(i)->isa<CNode>()) {
|
||||||
|
inputs->push_back(cnode->input(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
||||||
if (!utils::isa<CNodePtr>(anf_node)) {
|
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||||
|
@ -58,6 +74,36 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node) {
|
||||||
|
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||||
|
MS_LOG(DEBUG) << "anf node is node a cnode.";
|
||||||
|
return lite::RET_NO_CHANGE;
|
||||||
|
}
|
||||||
|
auto cnode = anf_node->cast<CNodePtr>();
|
||||||
|
auto inputs = cnode->inputs();
|
||||||
|
std::vector<AnfNodePtr> new_inputs;
|
||||||
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||||
|
if (!inputs[i]->isa<CNode>()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (CheckPrimitiveType(inputs[i], prim::kPrimMakeTuple)) {
|
||||||
|
FetchCNodeFromMakeTuple(inputs[i], &new_inputs);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
new_inputs.push_back(inputs[i]);
|
||||||
|
}
|
||||||
|
for (auto &node : new_inputs) {
|
||||||
|
func_graph->get_return()->add_input(node);
|
||||||
|
}
|
||||||
|
auto value = std::make_shared<UMonad>();
|
||||||
|
bool replace_succ = func_graph->manager()->Replace(anf_node, NewValueNode(value));
|
||||||
|
if (!replace_succ) {
|
||||||
|
MS_LOG(ERROR) << "replace redundant op failed.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
||||||
if (!utils::isa<CNodePtr>(anf_node)) {
|
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||||
MS_LOG(DEBUG) << "anf node is node a cnode.";
|
MS_LOG(DEBUG) << "anf node is node a cnode.";
|
||||||
|
@ -111,7 +157,7 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
|
||||||
status = ReplaceOp(node, manager);
|
status = ReplaceOp(node, manager);
|
||||||
}
|
}
|
||||||
if (CheckPrimitiveType(node, prim::kPrimUpdateState)) {
|
if (CheckPrimitiveType(node, prim::kPrimUpdateState)) {
|
||||||
status = ReplaceOp(node, manager);
|
status = ReplaceUpdateStateOp(func_graph, node);
|
||||||
}
|
}
|
||||||
if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||||
status = ReplaceTupleGetItem(node, manager);
|
status = ReplaceTupleGetItem(node, manager);
|
||||||
|
|
|
@ -29,6 +29,7 @@ class RemoveRedundantOpPass : public Pass {
|
||||||
RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {}
|
RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {}
|
||||||
~RemoveRedundantOpPass() override = default;
|
~RemoveRedundantOpPass() override = default;
|
||||||
int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
||||||
|
int ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node);
|
||||||
int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
||||||
bool Run(const FuncGraphPtr &graph) override;
|
bool Run(const FuncGraphPtr &graph) override;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue