!20882 Fix GradOperation recursive call issues.
Merge pull request !20882 from 张清华/opt
commit b8901526b9
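For context, a condensed, runnable sketch of the recursive GradOperation call this change targets (adapted from the test added at the end of this diff; the net and the numeric values are illustrative):

import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore import Tensor, ms_function, context
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)

class TargetNet(nn.Cell):
    """f(x, y) = x * y"""
    def __init__(self):
        super(TargetNet, self).__init__()
        self.mul = P.Mul()

    def construct(self, x, y):
        return self.mul(x, y)

grad_op = C.GradOperation()  # default: gradient w.r.t. the first input

@ms_function
def second_order(x, y):
    # GradOperation applied to a graph that is itself a GradOperation result;
    # this is the recursive pattern exercised by the new test below.
    return grad_op(grad_op(TargetNet()))(x, y)

# d/dx (d(x*y)/dx) == d/dx (y) == 0
print(second_order(Tensor(3, mstype.float32), Tensor(1, mstype.float32)))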
@@ -31,6 +31,7 @@
 #include "utils/ms_context.h"
 #include "pipeline/jit/action.h"
 #include "pipeline/jit/parse/resolve.h"
+#include "debug/anf_ir_dump.h"

 namespace mindspore {
 namespace ad {
@@ -453,6 +454,9 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
   }

   for (auto &fv : free_variables_nodes) {
+    if (IsPrimitiveCNode(fv, prim::kPrimJ)) {  // Ignore if FV is a J CNode.
+      continue;
+    }
     auto fv_adjoint = anfnode_to_adjoin_.find(fv);
     if (fv_adjoint == anfnode_to_adjoin_.end()) {
       MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << ".";
@@ -832,7 +836,31 @@ CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int
   auto &j_users = it->second;
   auto size = j_users.size();
   if (size != 1) {
-    MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}";
+    bool has_multiple_j_call_user = false;
+    CNodePtr j_call_user = nullptr;
+    for (auto &user : j_users) {
+      // If J CNode is used as a FV, the j_users.size may exceed 1 user. It is allowed.
+      if (user.second == 0) {
+        // Real J CNode call user.
+        if (j_call_user == nullptr) {  // First user.
+          j_call_user = user.first->cast<CNodePtr>();
+        } else {  // More than 1 call user. Not allowed.
+          has_multiple_j_call_user = true;
+        }
+      }
+    }
+    if (has_multiple_j_call_user) {  // Has multiple J CNode call user.
+      std::ostringstream user_info;
+      for (auto &user : j_users) {
+        user_info << " user: " << user.first->DebugString() << ", index: " << user.second << "\n";
+      }
+      DumpIR("J_User_Ex_" + cnode->func_graph()->ToString() + ".ir", cnode->func_graph());
+      MS_LOG(EXCEPTION) << "Incorrect J CNode user size: " << size << ", of {" << cnode->DebugString(2) << "/" << index
+                        << "}\nUser Info:\n"
+                        << user_info.str();
+    } else {
+      return j_call_user;
+    }
   }
   return j_users.begin()->first->cast<CNodePtr>();
 }
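Note on the hunk above: each entry in the node-user map is a (user node, input index) pair, and input index 0 means the user actually calls the J CNode; any other index means the J result is only captured, e.g. as a free variable. A rough Python-level sketch of the relaxed selection rule (the helper and its data are illustrative, not MindSpore API):

def pick_j_call_user(j_users):
    """j_users: list of (user_node, input_index) pairs for one J CNode."""
    call_users = [node for node, idx in j_users if idx == 0]
    if len(call_users) > 1:
        # Mirrors the new MS_LOG(EXCEPTION) branch: several real call users.
        raise RuntimeError("more than one call user of the J CNode")
    if call_users:
        return call_users[0]
    # No call user among multiple users: fall back to the first user,
    # as the final `return j_users.begin()->first` does.
    return j_users[0][0]

# The J node is called once and additionally captured as a free variable:
print(pick_j_call_user([("grad_call_cnode", 0), ("nested_graph_fv", 1)]))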
@@ -954,17 +954,19 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
     MS_LOG(DEBUG) << fg->ToString() << " had been checked";
     return false;
   }

+  // Check J FuncGraph input.
   const auto &j_values = fg->j_value_nodes();
   if (!j_values.empty()) {
     auto contains_j =
       std::find_if(j_values.begin(), j_values.end(), [seen_num](const std::pair<AnfNodePtr, int> &iter) {
-        // check g1->J(fg)->g2->g cycle.
+        // Check g1->J(fg)->g2->g cycle.
         if (IsValueNode<FuncGraph>(iter.first)) {
          auto func_graph = GetValueNode<FuncGraphPtr>(iter.first);
          return func_graph->seen_ != seen_num;
         }
         if (IsValueNode<Primitive>(iter.first)) {
-          // exclude the primitive of J itself.
+          // Exclude the primitive of J itself.
          auto prim = GetValueNode<PrimitivePtr>(iter.first);
          return prim->name() != prim::kPrimJ->name();
         }
@@ -975,9 +977,26 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
       return true;
     }
   }
-  fg->seen_ = seen_num;

-  // check if func graphs used contains J(func_graph) or J(Primitive)
+  // Check J CNode as FV.
+  const auto &fv_nodes = fg->free_variables();
+  if (!fv_nodes.empty()) {
+    auto contains_j_cnode =
+      std::find_if(fv_nodes.begin(), fv_nodes.end(), [seen_num](const std::pair<AnfNodePtr, int> &iter) {
+        // Check if the FV is a J call CNode.
+        if (IsPrimitiveCNode(iter.first, prim::kPrimJ)) {
+          return true;
+        }
+        return false;
+      });
+    if (contains_j_cnode != fv_nodes.end()) {
+      MS_LOG(DEBUG) << fg->ToString() << " contains FV J(" << contains_j_cnode->first->DebugString() << ")";
+      return true;
+    }
+  }
+
+  // Check if func graphs used contains J(func_graph) or J(Primitive)
+  fg->seen_ = seen_num;
   for (auto &item : fg->func_graphs_used()) {
     auto used_g = item.first;
     if (SeekJ(used_g, seen_num)) {
@@ -24,7 +24,9 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, Mult
     TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
 from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec, _wrap_func
-from .. import functional as F
+from ..primitive import Primitive
+from ..operations import _grad_ops
+from .. import operations as P
 from .. import signature as sig

 __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
@@ -659,7 +661,12 @@ zip_operation = _ZipOperation('zip_operation')
 env_get = MultitypeFuncGraph("env_get")


+env_getitem = Primitive('env_getitem')
+ref_to_embed = _grad_ops.RefToEmbed()
+zeros_like = P.ZerosLike()
+
+
 @env_get.register("EnvType", "Tensor")
 def _tensor_env_get(env, parameter):
     """Used to get env."""
-    return F.env_getitem(env, F.ref_to_embed(parameter), F.zeros_like(parameter))
+    return env_getitem(env, ref_to_embed(parameter), zeros_like(parameter))
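Note: this hunk (and the array_ops change further below) stops importing `functional` at module level. The apparent reason is the `from .composite import GradOperation` added to functional.py in the next hunk; keeping `from .. import functional as F` here would make the two modules import each other at load time. A trivial, hypothetical smoke check of the new import direction:

# Hypothetical sanity check against a build containing this patch: both modules
# import cleanly, functional exposes the new grad wrapper, and composite still
# exposes GradOperation.
import mindspore.ops.functional as F
import mindspore.ops.composite as C

print(hasattr(F, "grad"), hasattr(C, "GradOperation"))  # expected: True True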
@@ -22,6 +22,7 @@ from mindspore.ops import _constants
 from .primitive import Primitive
 from . import operations as P
 from .operations import _grad_ops
+from .composite import GradOperation

 typeof = Primitive('typeof')
 hastype = Primitive('hastype')
@@ -146,6 +147,23 @@ partial = P.Partial()
 depend = P.Depend()
 identity = P.identity()

+grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
+grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)
+
+def grad(fn, grad_first_param=False):
+    """
+    A wrapper function to generate the gradient function for the input function.
+
+    Args:
+        fn (Function): Function to do GradOperation.
+        grad_first_param (bool): If True, get the gradient with respect to first input.
+            If False, get all the gradients with respect to inputs. Default: False.
+    """
+    if grad_first_param:
+        return grad_first_parameter(fn)
+    return grad_all_parameters(fn)
+
+
 tuple_setitem = Primitive('tuple_setitem')
 tuple_getitem = Primitive(_constants.kTupleGetItem)
 list_getitem = Primitive('list_getitem')
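A short usage sketch of the new `grad` wrapper (assuming graph mode and a simple Mul cell; values are illustrative):

import mindspore.nn as nn
from mindspore import Tensor, ms_function, context
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)

class MulNet(nn.Cell):
    def __init__(self):
        super(MulNet, self).__init__()
        self.mul = P.Mul()

    def construct(self, x, y):
        return self.mul(x, y)

@ms_function
def df_dx(x, y):
    # grad_first_param=True -> gradient w.r.t. the first input only: d(x*y)/dx = y
    return F.grad(MulNet(), grad_first_param=True)(x, y)

@ms_function
def df_dall(x, y):
    # Default grad_first_param=False -> gradients w.r.t. all inputs: (y, x)
    return F.grad(MulNet())(x, y)

x, y = Tensor(3.0, mstype.float32), Tensor(2.0, mstype.float32)
print(df_dx(x, y))    # 2.0
print(df_dall(x, y))  # (2.0, 3.0)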
@@ -22,7 +22,6 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
 from ..._checkparam import Validator as validator, Rel
 from .._utils import get_concat_offset
 from ...common import dtype as mstype
-from .. import functional as F
 from ... import context


@@ -1951,6 +1950,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
     def __init__(self):
         self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
         self.add_prim_attr('primitive_target', 'CPU')
+        self.tuple_setitem = Primitive('tuple_setitem')

     def __infer__(self, dy, split_num):
         """
@@ -1965,7 +1965,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
         dy_shape = tuple(dy['shape'])
         split_num_value = split_num['value']
         validator.check_value_type("split_num_value", split_num_value, [int], self.name)
-        dy_shape_all = F.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
+        dy_shape_all = self.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
         return {'shape': dy_shape_all,
                 'dtype': dy['dtype'],
                 'value': None}
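Note: behaviour is unchanged here; the `tuple_setitem` primitive is simply held on the instance instead of being looked up through `functional`, again to avoid that import. For reference, a plain-Python sketch of the shape arithmetic in `__infer__` (example values only):

# tuple_setitem(dy_shape, 0, dy_shape[0] * 8) replaces element 0 of the shape tuple,
# scaling the first dimension by 8 and leaving the rest untouched.
dy_shape = (16, 32)
dy_shape_all = (dy_shape[0] * 8,) + dy_shape[1:]
print(dy_shape_all)  # (128, 32)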
@@ -22,6 +22,7 @@ import mindspore.nn as nn
 import mindspore.ops.composite as C
 from mindspore import Tensor
 from mindspore import ops, Parameter, context
+from mindspore import ms_function
 from mindspore.common import dtype as mstype
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
@@ -42,6 +43,44 @@ from ....ops_common import convert
 grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)


+class TargetNet(nn.Cell):
+    def __init__(self):
+        super(TargetNet, self).__init__()
+        self.mul = P.Mul()
+
+    def construct(self, x, y):
+        return self.mul(x, y)
+
+# Recursive GradOperation in Cell.
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = C.GradOperation()
+        self.network = network
+
+    def construct(self, x, y):
+        return self.grad(self.network)(x, y)
+
+# Recursive GradOperation with GradOperation object.
+grad1 = C.GradOperation()
+@ms_function
+def f1(x, y):
+    return grad1(grad1(TargetNet()))(x, y)
+
+# Recursive GradOperation with F.grad.
+@ms_function
+def f2(x, y):
+    return F.grad(F.grad(TargetNet()))(x, y)
+
+def test_recursive_grad():
+    x = Tensor(3, mstype.float32)
+    y = Tensor(1, mstype.float32)
+
+    Grad(Grad(TargetNet()))(x, y)
+    f1(x, y)
+    f2(x, y)
+
+
 class InputBackward(nn.Cell):
     def __init__(self, network):
         super(InputBackward, self).__init__()