From db385e9b05b6e70fd43ecf5c770596773ea38257 Mon Sep 17 00:00:00 2001 From: linqingke Date: Tue, 23 Mar 2021 11:06:31 +0800 Subject: [PATCH] remove uselesee updatestate pattern and fix loss and J order --- .../optimizer/auto_monad_eliminate.cc | 25 +++---------------- mindspore/nn/wrap/cell_wrapper.py | 1 + 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc index be31d880e99..1102d3cca33 100644 --- a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc @@ -107,20 +107,6 @@ std::vector> SplitGroup(const std::vector &topos // a = Load(para1, u1) // ... // b = Load(para1, u2) -// u3 = UpdateState(u2, b) -//==> -// delete the UpdateState -void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user, - const AnfNodePtr &load) { - const auto &load_cnode = load->cast(); - const auto &u = load_cnode->input(2); - manager->Replace(load_user, u); -} - -// Pattern2====================================== -// a = Load(para1, u1) -// ... -// b = Load(para1, u2) // t = make_tuple(x, b) // u3 = UpdateState(u2, t) //==> @@ -141,7 +127,7 @@ void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr manager->Replace(make_tuple, other_input); } -// Pattern3====================================== +// Pattern2====================================== // a = Load(para1, u1) // ... // b = Load(para1, u2) @@ -167,11 +153,6 @@ void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGrap void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) { auto load_users = manager->node_users()[load]; for (const auto &load_user : load_users) { - // Pattern1 - if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) { - DeleteLoadUserUpdateState(manager, load_user.first, load); - continue; - } if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) { const auto &make_tuple = load_user.first->cast(); auto &maketuple_users = manager->node_users()[make_tuple]; @@ -180,12 +161,12 @@ void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, if (!maketuple_as_input_of_update) { continue; } - // Pattern2 + // Pattern1 if (make_tuple->size() == 3) { DeleteLoadUserMakeTuple(manager, make_tuple, load); continue; } - // Pattern3 + // Pattern2 if (make_tuple->size() > 3) { ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load); } diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 003e257149a..b143a736bc4 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -352,6 +352,7 @@ class TrainOneStepCell(Cell): weights = self.weights loss = self.network(*inputs) sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + sens = F.depend(sens, loss) grads = self.grad(self.network, weights)(*inputs, sens) grads = self.grad_reducer(grads) loss = F.depend(loss, self.optimizer(grads))