!20882 Fix GradOperation recursive call issues.

Merge pull request !20882 from 张清华/opt
i-robot 2021-07-27 01:00:39 +00:00 committed by Gitee
commit b8901526b9
6 changed files with 120 additions and 9 deletions
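For orientation, the failure mode addressed here is nesting GradOperation so that an inner J CNode ends up used as a free variable of an outer grad graph. Below is a condensed sketch of the regression test added at the end of this diff (graph mode, illustrative values); it is a usage illustration only, not part of the change itself.

import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore import Tensor, ms_function
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

class TargetNet(nn.Cell):
    def __init__(self):
        super(TargetNet, self).__init__()
        self.mul = P.Mul()
    def construct(self, x, y):
        return self.mul(x, y)

grad_op = C.GradOperation()

@ms_function
def second_order(x, y):
    # The inner grad_op(...) emits a J CNode that can be captured as a free
    # variable of the outer grad graph; before this fix that pattern could hit
    # the single-user check in GetJUser ("Wrong J CNode use size ...").
    return grad_op(grad_op(TargetNet()))(x, y)

second_order(Tensor(3, mstype.float32), Tensor(1, mstype.float32))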

View File

@@ -31,6 +31,7 @@
 #include "utils/ms_context.h"
 #include "pipeline/jit/action.h"
 #include "pipeline/jit/parse/resolve.h"
+#include "debug/anf_ir_dump.h"

 namespace mindspore {
 namespace ad {
@@ -453,6 +454,9 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
   }
   for (auto &fv : free_variables_nodes) {
+    if (IsPrimitiveCNode(fv, prim::kPrimJ)) {  // Ignore if FV is a J CNode.
+      continue;
+    }
     auto fv_adjoint = anfnode_to_adjoin_.find(fv);
     if (fv_adjoint == anfnode_to_adjoin_.end()) {
       MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << ".";
@@ -832,7 +836,31 @@ CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int
   auto &j_users = it->second;
   auto size = j_users.size();
   if (size != 1) {
-    MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}";
+    bool has_multiple_j_call_user = false;
+    CNodePtr j_call_user = nullptr;
+    for (auto &user : j_users) {
+      // If the J CNode is used as an FV, j_users may contain more than one user; that is allowed.
+      if (user.second == 0) {
+        // Real J CNode call user.
+        if (j_call_user == nullptr) {  // First user.
+          j_call_user = user.first->cast<CNodePtr>();
+        } else {  // More than one call user is not allowed.
+          has_multiple_j_call_user = true;
+        }
+      }
+    }
+    if (has_multiple_j_call_user) {  // Multiple J CNode call users.
+      std::ostringstream user_info;
+      for (auto &user : j_users) {
+        user_info << " user: " << user.first->DebugString() << ", index: " << user.second << "\n";
+      }
+      DumpIR("J_User_Ex_" + cnode->func_graph()->ToString() + ".ir", cnode->func_graph());
+      MS_LOG(EXCEPTION) << "Incorrect J CNode user size: " << size << ", of {" << cnode->DebugString(2) << "/" << index
+                        << "}\nUser Info:\n"
+                        << user_info.str();
+    } else {
+      return j_call_user;
+    }
   }
   return j_users.begin()->first->cast<CNodePtr>();
 }

View File

@@ -954,17 +954,19 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
     MS_LOG(DEBUG) << fg->ToString() << " had been checked";
     return false;
   }
+
+  // Check J FuncGraph input.
   const auto &j_values = fg->j_value_nodes();
   if (!j_values.empty()) {
     auto contains_j =
       std::find_if(j_values.begin(), j_values.end(), [seen_num](const std::pair<AnfNodePtr, int> &iter) {
-        // check g1->J(fg)->g2->g cycle.
+        // Check g1->J(fg)->g2->g cycle.
         if (IsValueNode<FuncGraph>(iter.first)) {
           auto func_graph = GetValueNode<FuncGraphPtr>(iter.first);
           return func_graph->seen_ != seen_num;
         }
         if (IsValueNode<Primitive>(iter.first)) {
-          // exclude the primitive of J itself.
+          // Exclude the primitive of J itself.
           auto prim = GetValueNode<PrimitivePtr>(iter.first);
           return prim->name() != prim::kPrimJ->name();
         }
@@ -975,9 +977,26 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
       return true;
     }
   }
-  fg->seen_ = seen_num;
-  // check if func graphs used contains J(func_graph) or J(Primitive)
+
+  // Check J CNode as FV.
+  const auto &fv_nodes = fg->free_variables();
+  if (!fv_nodes.empty()) {
+    auto contains_j_cnode =
+      std::find_if(fv_nodes.begin(), fv_nodes.end(), [seen_num](const std::pair<AnfNodePtr, int> &iter) {
+        // Check if the FV is a J call CNode.
+        if (IsPrimitiveCNode(iter.first, prim::kPrimJ)) {
+          return true;
+        }
+        return false;
+      });
+    if (contains_j_cnode != fv_nodes.end()) {
+      MS_LOG(DEBUG) << fg->ToString() << " contains FV J(" << contains_j_cnode->first->DebugString() << ")";
+      return true;
+    }
+  }
+
+  // Check if func graphs used contains J(func_graph) or J(Primitive)
+  fg->seen_ = seen_num;
   for (auto &item : fg->func_graphs_used()) {
     auto used_g = item.first;
     if (SeekJ(used_g, seen_num)) {

View File

@@ -24,7 +24,9 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, Mult
     TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
 from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec, _wrap_func
-from .. import functional as F
+from ..primitive import Primitive
+from ..operations import _grad_ops
+from .. import operations as P
 from .. import signature as sig

 __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
@@ -659,7 +661,12 @@ zip_operation = _ZipOperation('zip_operation')

 env_get = MultitypeFuncGraph("env_get")
+env_getitem = Primitive('env_getitem')
+ref_to_embed = _grad_ops.RefToEmbed()
+zeros_like = P.ZerosLike()


 @env_get.register("EnvType", "Tensor")
 def _tensor_env_get(env, parameter):
     """Used to get env."""
-    return F.env_getitem(env, F.ref_to_embed(parameter), F.zeros_like(parameter))
+    return env_getitem(env, ref_to_embed(parameter), zeros_like(parameter))

View File

@@ -22,6 +22,7 @@ from mindspore.ops import _constants
 from .primitive import Primitive
 from . import operations as P
 from .operations import _grad_ops
+from .composite import GradOperation

 typeof = Primitive('typeof')
 hastype = Primitive('hastype')
@@ -146,6 +147,23 @@ partial = P.Partial()
 depend = P.Depend()
 identity = P.identity()

+grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
+grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)
+
+
+def grad(fn, grad_first_param=False):
+    """
+    A wrapper function to generate the gradient function for the input function.
+
+    Args:
+        fn (Function): Function to apply GradOperation to.
+        grad_first_param (bool): If True, get the gradient with respect to the first input.
+            If False, get the gradients with respect to all inputs. Default: False.
+    """
+    if grad_first_param:
+        return grad_first_parameter(fn)
+    return grad_all_parameters(fn)
+
 tuple_setitem = Primitive('tuple_setitem')
 tuple_getitem = Primitive(_constants.kTupleGetItem)
 list_getitem = Primitive('list_getitem')
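A brief usage sketch for the grad wrapper introduced above, mirroring how the new test near the end of this diff exercises it; the Net and tensor values are illustrative assumptions, not part of the change.

import mindspore.nn as nn
from mindspore import Tensor, ms_function
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.mul = P.Mul()
    def construct(self, x, y):
        return self.mul(x, y)

@ms_function
def all_input_grads(x, y):
    # Default: gradients w.r.t. all inputs, i.e. (y, x) for z = x * y.
    return F.grad(Net())(x, y)

@ms_function
def first_input_grad(x, y):
    # grad_first_param=True: gradient w.r.t. the first input only, i.e. y.
    return F.grad(Net(), grad_first_param=True)(x, y)

x, y = Tensor(3.0, mstype.float32), Tensor(1.0, mstype.float32)
all_input_grads(x, y)    # (1.0, 3.0)
first_input_grad(x, y)   # 1.0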

View File

@@ -22,7 +22,6 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
 from ..._checkparam import Validator as validator, Rel
 from .._utils import get_concat_offset
 from ...common import dtype as mstype
-from .. import functional as F
 from ... import context
@@ -1951,6 +1950,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
     def __init__(self):
         self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
         self.add_prim_attr('primitive_target', 'CPU')
+        self.tuple_setitem = Primitive('tuple_setitem')

     def __infer__(self, dy, split_num):
         """
@@ -1965,7 +1965,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
         dy_shape = tuple(dy['shape'])
         split_num_value = split_num['value']
         validator.check_value_type("split_num_value", split_num_value, [int], self.name)
-        dy_shape_all = F.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
+        dy_shape_all = self.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
         return {'shape': dy_shape_all,
                 'dtype': dy['dtype'],
                 'value': None}

View File

@@ -22,6 +22,7 @@ import mindspore.nn as nn
 import mindspore.ops.composite as C
 from mindspore import Tensor
 from mindspore import ops, Parameter, context
+from mindspore import ms_function
 from mindspore.common import dtype as mstype
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
@@ -42,6 +43,44 @@ from ....ops_common import convert

 grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)

+
+class TargetNet(nn.Cell):
+    def __init__(self):
+        super(TargetNet, self).__init__()
+        self.mul = P.Mul()
+
+    def construct(self, x, y):
+        return self.mul(x, y)
+
+
+# Recursive GradOperation in Cell.
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = C.GradOperation()
+        self.network = network
+
+    def construct(self, x, y):
+        return self.grad(self.network)(x, y)
+
+
+# Recursive GradOperation with a GradOperation object.
+grad1 = C.GradOperation()
+
+
+@ms_function
+def f1(x, y):
+    return grad1(grad1(TargetNet()))(x, y)
+
+
+# Recursive GradOperation with F.grad.
+@ms_function
+def f2(x, y):
+    return F.grad(F.grad(TargetNet()))(x, y)
+
+
+def test_recursive_grad():
+    x = Tensor(3, mstype.float32)
+    y = Tensor(1, mstype.float32)
+
+    Grad(Grad(TargetNet()))(x, y)
+    f1(x, y)
+    f2(x, y)
+

 class InputBackward(nn.Cell):
     def __init__(self, network):
         super(InputBackward, self).__init__()