!13808 [auto_monad]Remove uselesee updatestate pattern and fix loss and J order.

From: @linqingke
Reviewed-by: @hwhewei,@zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-03-23 20:20:02 +08:00 committed by Gitee
commit ae3127f12f
2 changed files with 4 additions and 22 deletions

View File

@ -107,20 +107,6 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos
// a = Load(para1, u1)
// ...
// b = Load(para1, u2)
// u3 = UpdateState(u2, b)
//==>
// delete the UpdateState
void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user,
const AnfNodePtr &load) {
const auto &load_cnode = load->cast<CNodePtr>();
const auto &u = load_cnode->input(2);
manager->Replace(load_user, u);
}
// Pattern2======================================
// a = Load(para1, u1)
// ...
// b = Load(para1, u2)
// t = make_tuple(x, b)
// u3 = UpdateState(u2, t)
//==>
@ -141,7 +127,7 @@ void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr
manager->Replace(make_tuple, other_input);
}
// Pattern3======================================
// Pattern2======================================
// a = Load(para1, u1)
// ...
// b = Load(para1, u2)
@ -167,11 +153,6 @@ void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGrap
void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
auto load_users = manager->node_users()[load];
for (const auto &load_user : load_users) {
// Pattern1
if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
DeleteLoadUserUpdateState(manager, load_user.first, load);
continue;
}
if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
const auto &make_tuple = load_user.first->cast<CNodePtr>();
auto &maketuple_users = manager->node_users()[make_tuple];
@ -180,12 +161,12 @@ void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
if (!maketuple_as_input_of_update) {
continue;
}
// Pattern2
// Pattern1
if (make_tuple->size() == 3) {
DeleteLoadUserMakeTuple(manager, make_tuple, load);
continue;
}
// Pattern3
// Pattern2
if (make_tuple->size() > 3) {
ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load);
}

View File

@ -352,6 +352,7 @@ class TrainOneStepCell(Cell):
weights = self.weights
loss = self.network(*inputs)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
sens = F.depend(sens, loss)
grads = self.grad(self.network, weights)(*inputs, sens)
grads = self.grad_reducer(grads)
loss = F.depend(loss, self.optimizer(grads))