forked from mindspore-Ecosystem/mindspore
Fix GradOperation description and support tuple/list for 'get_by_list'.
This commit is contained in:
@ -13,7 +13,9 @@ mindspore/mindspore/ccsrc/frontend/optimizer/
@ -32,7 +34,6 @@ mindspore/mindspore/lite/src/runtime/
@ -45,6 +45,11 @@ mindspore.ops.GradOperation
4. 用 `net` 的输入作为参数调用梯度函数,得到关于所有输入和给定参数的梯度:`gradient_function(x, y)` 。
1. 如果仅有一个梯度结果,返回单独值;
2. 如果有多个梯度结果,返回tuple;
3. 如果没有任何梯度结果,返回空tuple;
我们可以设置 `sens_param` 等于True来配置灵敏度(关于输出的梯度),向梯度函数传递一个额外的灵敏度输入值。这个输入值必须与 `net` 的输出具有相同的形状和类型(见样例中的 `GradNetWrtXYWithSensParam` )。
1. 构建一个带有 `get_all=True` 和 `sens_param=True` 参数的 `GradOperation` 高阶函数:`grad_op = GradOperation(get_all=True, sens_param=True)` 。
@ -555,9 +555,9 @@ FuncGraphPtr Tail::GenerateGradFuncGraph(const AbstractTuplePtr &tuple_arg, cons
if (tail_type_ == kGradFirst) {
AnfNodePtr tuple_parameter = fg->add_parameter();
PrimitivePtr getitem_op = prim::kPrimTupleGetItem;
if (CanGradArgument(tuple_arg, 1) || EnableGradFirstForTuple(tuple_arg, enable_tuple_grad_first_)) {
fg->set_output(fg->NewCNode({NewValueNode(getitem_op), tuple_parameter, NewValueNode(SizeToLong(1))}));
fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(1))}));
} else {
@ -572,14 +572,22 @@ FuncGraphPtr Tail::GenerateGradFuncGraph(const AbstractTuplePtr &tuple_arg, cons
if (tail_type_ == kGradAll) {
AnfNodePtr tuple_parameter = fg->add_parameter();
std::vector<AnfNodePtr> elements = {NewValueNode(prim::kPrimMakeTuple)};
PrimitivePtr op = prim::kPrimTupleGetItem;
for (size_t i = 1; i < tuple_arg->size(); ++i) {
if (CanGradArgument(tuple_arg, i)) {
elements.push_back(fg->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))}));
if (elements.size() > 1) {
// We should deal with 'get_all=True' as other options later:
// "The returned result may vary for grad result element number.
// A single value if only one result, a tuple for multiple results, or an empty tuple for no result.
// Notice that even if the user set 'get_all=True' and pass multiple inputs,
// the 'CanGradArgument' may change it to only one gradient output or no gradient."
constexpr size_t args_least_size = 2;
if (elements.size() >= args_least_size) {
return fg;
@ -744,12 +752,8 @@ void CheckPrimBpropReturnSparse(const FuncGraphPtr &primal_graph) {
if (has_sparse_bprop_prim) {
return EXCLUDE;
auto prim = GetCNodePrimitive(node);
auto prim = GetCNodePrimitiveWithoutDoSignature(node);
if (prim != nullptr) {
auto do_signature = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
if (do_signature != nullptr) {
prim = dyn_cast<Primitive>(do_signature->function());
bool sparse_bprop = GetPrimitiveFlag(prim, GRAPH_FLAG_BPROP_RETURN_SPARSE);
if (sparse_bprop) {
MS_LOG(DEBUG) << "prim: " << prim->ToString() << " has attr 'bprop_return_sparse'";
@ -67,26 +67,16 @@ static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::ve
return unpack_graph_node;
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) {
ValuePtr value;
if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function();
} else {
value = GetValueNode(node);
if (value == nullptr) {
return nullptr;
return value->cast<MetaFuncGraphPtr>();
// check if node is a specific meta_fg_opration that registered in the meta_fg_ops
// Check if node is a specific meta_fg_operation that is registered in the meta_fg_ops
bool CheckMetaFgOps(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
auto value = GetValueWithoutDoSignature(node);
if (value == nullptr) {
return false;
auto meta_func_graph_ptr = value->cast<MetaFuncGraphPtr>();
if (meta_func_graph_ptr == nullptr) {
return false;
@ -143,16 +133,29 @@ AnfNodePtr MetaFgVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &
// For general meta_fg_opration, ‘sens_param’ is not involved, and that of GradOperation obtained specifically.
bool sens_param = false;
if (grad_op_->Match(inputs_x[0])) {
auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]);
if (meta_func == nullptr) {
return nullptr;
auto value = GetValueWithoutDoSignature(inputs_x[0]);
auto meta_func = value->cast<MetaFuncGraphPtr>();
auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
sens_param = grad_op_ptr->sens_param();
// Remove the tuple/list inputs from order list for grad(UnpackGraph(..), list/tuple)(..)
if (inputs_x.size() > inputs_x_minimum_size) {
constexpr size_t sequence_input_pos = 2;
auto seq_node = inputs_x[sequence_input_pos];
auto prim = GetCNodePrimitiveWithoutDoSignature(seq_node);
if (prim != nullptr &&
(IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList))) {
auto seq_cnode = dyn_cast<CNode>(seq_node);
inputs_x[1] = GenerateUnpackGraphNode(node, inputs_y, func_node, is_unpack, sens_param);
// construct new meta_fg_opration
// Construct new meta_fg_operation
auto meta_fg_op_cnode = func_graph->NewCNodeBefore(node, inputs_x);
if (unpack_op_->Match(inputs_y[0])) {
inputs_y[1] = meta_fg_op_cnode;
@ -33,9 +33,6 @@
namespace mindspore {
namespace opt {
namespace irpass {
// get metagraph of value node
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node);
class Matcher {
Matcher() {}
@ -59,7 +56,11 @@ class MetaFgMatcher : public Matcher {
if (node == nullptr) {
return false;
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
auto value = GetValueWithoutDoSignature(node);
if (value == nullptr) {
return false;
auto meta_func_graph_ptr = value->cast<MetaFuncGraphPtr>();
if (meta_func_graph_ptr == nullptr) {
return false;
@ -242,13 +242,15 @@ bool ResolveObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, A
MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString();
output = param;
*node = output;
return true;
} else if (py::hasattr(obj, "__parameter_tuple__")) {
auto tuple = obj.cast<py::tuple>();
std::vector<AnfNodePtr> args;
for (size_t it = 0; it < tuple.size(); ++it) {
for (size_t i = 0; i < tuple.size(); ++i) {
AnfNodePtr out = nullptr;
bool success = ResolveObjectToNode(origin_node, tuple[it], &out);
bool success = ResolveObjectToNode(origin_node, tuple[i], &out);
if (!success) {
MS_LOG(ERROR) << "Resolve object to node failed";
return false;
@ -258,27 +260,54 @@ bool ResolveObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, A
// The ParameterTuple will not be added in order list,
// since we don't want to deal with its RefTensor elements during auto_monad procedure.
output = NewCNode(std::move(args), func_graph);
} else {
ValuePtr convert_result = nullptr;
// When the cell is set recomputed, it should not use old scope from cache.
auto scope = origin_node->scope();
bool has_recompute_scope = (scope == nullptr) ? false : scope->name().find(kAttrRecompute) == 0;
bool converted =
ConvertData(obj, &convert_result, python_adapter::UseSignatureInResolve(), nullptr, has_recompute_scope);
if (!converted) {
MS_LOG(ERROR) << "Convert data failed";
return false;
*node = output;
return true;
} else if ((py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) && py::len(obj) != 0) {
auto tuple = obj.cast<py::tuple>();
std::vector<AnfNodePtr> args;
bool all_parameter_sequence = true;
for (size_t i = 0; i < tuple.size(); ++i) {
if (!py::hasattr(tuple[i], "__parameter__") || !py::isinstance<tensor::MetaTensor>(tuple[i])) {
all_parameter_sequence = false;
AnfNodePtr out = nullptr;
bool success = ResolveObjectToNode(origin_node, tuple[i], &out);
if (!success) {
MS_LOG(ERROR) << "Resolve object to node failed";
return false;
if (convert_result->isa<FuncGraph>() && has_recompute_scope) {
UpdateDebugInfo(convert_result->cast<FuncGraphPtr>(), origin_node->scope(), origin_node->debug_info());
ConvertLoadedGraph(func_graph, convert_result);
output = NewValueNode(convert_result);
if (convert_result->isa<tensor::Tensor>()) {
output = GetMixedPrecisionCastHelp(func_graph, output);
if (all_parameter_sequence) {
// The Parameter tuple/list will not be added in order list,
// since we don't want to deal with its RefTensor elements during auto_monad procedure.
output = NewCNode(std::move(args), func_graph);
*node = output;
return true;
// When the cell is set recomputed, it should not use old scope from cache.
auto scope = origin_node->scope();
bool has_recompute_scope = (scope == nullptr) ? false : scope->name().find(kAttrRecompute) == 0;
ValuePtr convert_result = nullptr;
bool converted =
ConvertData(obj, &convert_result, python_adapter::UseSignatureInResolve(), nullptr, has_recompute_scope);
if (!converted) {
MS_LOG(ERROR) << "Convert data failed";
return false;
if (convert_result->isa<FuncGraph>() && has_recompute_scope) {
UpdateDebugInfo(convert_result->cast<FuncGraphPtr>(), origin_node->scope(), origin_node->debug_info());
ConvertLoadedGraph(func_graph, convert_result);
output = NewValueNode(convert_result);
if (convert_result->isa<tensor::Tensor>()) {
output = GetMixedPrecisionCastHelp(func_graph, output);
*node = output;
return true;
@ -139,25 +139,6 @@ bool IsKeepRef(const PrimitivePtr &prim) {
IsPrimitiveEquals(prim, prim::kPrimPull);
// Gets primitive if the node is a primitive value node.
PrimitivePtr GetPrimitive(const AnfNodePtr &node) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
auto do_sig = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
if (do_sig) {
auto val = do_sig->function();
return dyn_cast<Primitive>(val);
return prim;
// Gets primitive from the given cnode, return nullptr if cnode.inputs[0] is not a primitive.
PrimitivePtr GetPrimitive(const CNodePtr &cnode) {
if (cnode == nullptr || cnode->inputs().empty()) {
return nullptr;
return GetPrimitive(cnode->input(0));
// Gets func_graph from the given cnode, return nullptr if it is not a func graph call.
FuncGraphPtr GetFuncGraph(const CNodePtr &cnode) {
if (cnode != nullptr && !cnode->inputs().empty()) {
@ -620,7 +601,7 @@ class SideEffectFinder {
EffectInfo TraceTupleCNodeEffectInfo(const CNodePtr &cnode, std::stack<int64_t> *tuple_indexes) {
auto prim = GetPrimitive(cnode);
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
// Trace MakeTuple.
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
if (tuple_indexes->empty()) {
@ -724,7 +705,7 @@ class SideEffectFinder {
// Trace a cnode for effect info.
EffectInfo TraceEffectInfo(const CNodePtr &cnode) {
auto prim = GetPrimitive(cnode);
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
if (IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
// Special handling for Switch primitive.
return TraceSwitchEffectInfo(cnode);
@ -778,7 +759,7 @@ class SideEffectFinder {
return {EffectInfo::kDetected, false, false, false};
// Trace an ANFNode for effect info.
// Trace an AnfNode for effect info.
EffectInfo TraceEffectInfo(const AnfNodePtr &node) {
if (node != nullptr) {
// Trace cnode.
@ -794,7 +775,7 @@ class SideEffectFinder {
// Trace primitive.
auto prim = GetPrimitive(node);
auto prim = GetPrimitiveWithoutDoSignature(node);
if (prim != nullptr) {
return GetPrimEffectInfo(prim);
@ -886,7 +867,7 @@ class SideEffectFinder {
// Detect effect info by depth first search.
EffectInfo DetectEffectInfo(const CNodePtr &cnode) {
// For primitive, get effect info from its attributes and inputs.
auto prim = GetPrimitive(cnode);
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
if (prim != nullptr) {
// Skip 'return' cnode.
if (IsPrimitiveEquals(prim, prim::kPrimReturn)) {
@ -58,17 +58,6 @@ void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &,
PrimitivePtr GetRealPrimitive(const AnfNodePtr &node) {
const auto &primitive = GetCNodePrimitive(node);
if (primitive != nullptr) {
auto do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(primitive);
if (do_signature_prim != nullptr) {
return dyn_cast<Primitive>(do_signature_prim->function());
return primitive;
} // namespace
bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
@ -92,7 +81,7 @@ bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg)
void BaseFuncGraphEvaluator::CollectSideEffectNodes(const AnfNodePtr &node,
std::vector<AnfNodePtr> *side_effect_nodes) {
const auto &primitive = GetRealPrimitive(node);
const auto &primitive = GetCNodePrimitiveWithoutDoSignature(node);
if (primitive != nullptr) {
auto effect_info = GetPrimEffectInfo(primitive);
if (effect_info.memory || {
@ -201,21 +201,10 @@ class OrderEnforcer {
return IsPrimitiveCNode(node, prim::kPrimExpandDims) || IsPrimitiveCNode(node, prim::kPrimBatchNormGrad);
// Gets primitive if the node is a primitive value node.
PrimitivePtr GetPrimitive(const AnfNodePtr &node) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
auto do_sig = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
if (do_sig) {
auto val = do_sig->function();
return dyn_cast<Primitive>(val);
return prim;
bool IsSpecialParallelPrimitive(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
auto prim = GetPrimitive(cnode->input(0));
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
if (prim == nullptr) {
return false;
@ -218,11 +218,9 @@ bool IrExportBuilder::BuildPrimitives() {
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
if (func != nullptr && func->isa<Primitive>()) {
prim = func->cast<PrimitivePtr>();
auto real_prim = GetValueWithoutDoSignature(prim)->cast<PrimitivePtr>();
if (real_prim != nullptr) {
prim = real_prim;
// Set primitive attributes
@ -249,6 +249,58 @@ PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
return nullptr;
// Return the function Primitive if DoSignaturePrimitive,
// otherwise return the Primitive directly.
PrimitivePtr GetPrimitiveWithoutDoSignature(const AnfNodePtr &node) {
const auto &do_signature_prim = GetValueNode<prim::DoSignaturePrimitivePtr>(node);
if (do_signature_prim != nullptr) {
return dyn_cast<Primitive>(do_signature_prim->function());
return GetValueNode<PrimitivePtr>(node);
// Check the first input of CNode.
// Return the function Primitive if DoSignaturePrimitive,
// otherwise return the Primitive directly.
PrimitivePtr GetCNodePrimitiveWithoutDoSignature(const AnfNodePtr &node) {
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr || cnode->size() == 0) {
return nullptr;
return GetPrimitiveWithoutDoSignature(cnode->input(0));
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
ValuePtr GetValueWithoutDoSignature(const ValuePtr &value) {
auto do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(value);
if (do_signature_prim != nullptr) {
return do_signature_prim->function();
return value;
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
ValuePtr GetValueWithoutDoSignature(const AnfNodePtr &node) {
auto value = GetValueNode(node);
if (value == nullptr) {
return nullptr;
return GetValueWithoutDoSignature(value);
// Check the first input of CNode.
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
ValuePtr GetCNodeValueWithoutDoSignature(const AnfNodePtr &node) {
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr || cnode->size() == 0) {
return nullptr;
return GetValueWithoutDoSignature(cnode->input(0));
std::string GetCNodeFuncName(const CNodePtr cnode) {
if (cnode->inputs().empty()) {
return "";
@ -1071,15 +1071,34 @@ static S GetValue(const ValuePtr &value) {
MS_CORE_API std::string GetCNodeFuncName(CNodePtr cnode);
// used to get FuncGraphPtr from a cnode first input
// Used to get FuncGraphPtr from a cnode first input
MS_CORE_API FuncGraphPtr GetCNodeFuncGraph(const AnfNodePtr &node);
// used to check whether an AnfNode is a cnode with a kind of Primitive as first input
// Used to check whether an AnfNode is a cnode with a kind of Primitive as first input
MS_CORE_API bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr);
// used to get PrimitivePtr from a cnode first input
// Used to get PrimitivePtr from a cnode first input
MS_CORE_API PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
// Return the function Primitive if DoSignaturePrimitive,
// otherwise return the Primitive directly.
MS_CORE_API PrimitivePtr GetPrimitiveWithoutDoSignature(const AnfNodePtr &node);
// Check the first input of CNode.
// Return the function Primitive if DoSignaturePrimitive,
// otherwise return the Primitive directly.
MS_CORE_API PrimitivePtr GetCNodePrimitiveWithoutDoSignature(const AnfNodePtr &node);
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
MS_CORE_API ValuePtr GetValueWithoutDoSignature(const ValuePtr &value);
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
MS_CORE_API ValuePtr GetValueWithoutDoSignature(const AnfNodePtr &node);
// Check the first input of CNode.
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
MS_CORE_API ValuePtr GetCNodeValueWithoutDoSignature(const AnfNodePtr &node);
/// \brief Used to check whether the given node is a ValueNode with some Primitive value.
/// \param[in] node The input node.
@ -373,15 +373,8 @@ bool MSANFModelParser::SetNodeAbstractFromAttrProto(const mind_ir::AttributeProt
void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr) {
auto prim = GetCNodePrimitive(cnode_ptr);
auto prim_to_add_attr = prim;
auto prim_to_add_attr = GetCNodePrimitiveWithoutDoSignature(cnode_ptr);
if (prim_to_add_attr != nullptr) {
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
if (func != nullptr && func->isa<Primitive>()) {
prim_to_add_attr = func->cast<PrimitivePtr>();
prim_to_add_attr->set_attr("is_load", MakeValue(true));
for (int i = 0; i < node_proto.attribute_size(); ++i) {
@ -393,7 +386,8 @@ void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &nod
if (prim_to_add_attr != nullptr && !GetAttrValueForCNode(prim_to_add_attr, attr_proto)) {
MS_LOG(ERROR) << "Parse prim: " << prim->ToString() << " attributes error: " << attr_proto.DebugString();
MS_LOG(ERROR) << "Parse prim: " << prim_to_add_attr->ToString()
<< ", attributes error: " << attr_proto.DebugString();
} else {
// ref_attr_name is removed in newer versions.
@ -402,7 +396,8 @@ void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &nod
if (prim_to_add_attr != nullptr && !SetPrimitiveAttrWithType(prim_to_add_attr, attr_proto)) {
MS_LOG(ERROR) << "Parse prim: " << prim->ToString() << " attributes error: " << attr_proto.DebugString();
MS_LOG(ERROR) << "Parse prim: " << prim_to_add_attr->ToString()
<< ", attributes error: " << attr_proto.DebugString();
@ -1690,17 +1685,11 @@ bool MSANFModelParser::BuildPrimitiveNode(const mind_ir::PrimitiveProto &primiti
prim->set_attr("is_load", MakeValue(true));
// Set primitive attributes
auto prim_to_add_attr = prim;
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
if (func != nullptr && func->isa<Primitive>()) {
prim_to_add_attr = func->cast<PrimitivePtr>();
prim_to_add_attr->set_attr("is_load", MakeValue(true));
auto prim_to_add_attr = GetValueWithoutDoSignature(prim)->cast<PrimitivePtr>();
prim_to_add_attr->set_attr("is_load", MakeValue(true));
for (int i = 0; i < primitive_proto.attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = primitive_proto.attribute(i);
if (!SetPrimitiveAttrWithType(prim_to_add_attr, attr_proto)) {
@ -264,10 +264,10 @@ class ForwardValueAndGrad(Cell):
if not isinstance(get_by_list, bool):
raise TypeError(f"For 'ForwardValueAndGrad', "
f"the type of 'get_by_list' must be bool, but got '{type(get_by_list)}'")
if get_by_list and not isinstance(weights, ParameterTuple):
if get_by_list and not isinstance(weights, (ParameterTuple, tuple, list)):
raise TypeError(f"For 'ForwardValueAndGrad', "
f"when 'get_by_list' is set to True, the argument 'weights' must be "
f"ParameterTuple type, but got '{type(weights)}'")
f"Parameters array, but got '{type(weights)}'")
|||| = network
if isinstance(network, Cell):
@ -168,6 +168,8 @@ class GradOperation(GradOperation_):
4. Call the gradient function with input function's inputs
to get the gradients with respect to all inputs and given parameters: `gradient_function(x, y)`.
Note that for the above gradient functions, the returned result varies with the number of gradient results:
a single value if there is only one result, a tuple for multiple results, or an empty tuple for no result.
We can configure the sensitivity(gradient with respect to output) by setting `sens_param` as True and
passing an extra sensitivity input to the gradient function, the sensitivity input should have the
@ -188,11 +190,11 @@ class GradOperation(GradOperation_):
get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
get_by_list (bool): If True, get all the gradients with respect to Parameter free variables.
If get_all and get_by_list are both False, get the gradient with respect to first input.
If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
at the same time in the form of ((gradients with respect to inputs),
(gradients with respect to parameters)). Default: False.
If get_all and get_by_list are both True, get the gradients with respect to inputs and
Parameter free variables at the same time in the form of ("gradients with respect to inputs",
"gradients with respect to parameter free variables"). Default: False.
sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
Default: False.
@ -232,3 +232,67 @@ def test_train_lenet_with_new_interface(num_classes=10, epoch=20, batch_size=32)
assert losses[-1].asnumpy() < 0.01
assert losses[-1].asnumpy() > 0.001
def test_train_lenet_with_new_interface_tuple(num_classes=10, epoch=20, batch_size=32):
Feature: GradOperation get_by_list pass tuple/list
Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
Expectation: No exception.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
network = LeNet5(num_classes)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_with_criterion = WithLossCell(network, criterion)
weights = tuple(network.trainable_params())
optimizer = nn.Momentum(weights, 0.1, 0.9)
train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
losses = []
for i in range(0, epoch):
data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32))
sens = Tensor(np.ones([1]).astype(np.float32))
loss, grads = train_network(data, label, sens)
grads = F.identity(grads)
assert losses[-1].asnumpy() < 0.01
assert losses[-1].asnumpy() > 0.001
def test_train_lenet_with_new_interface_list(num_classes=10, epoch=20, batch_size=32):
Feature: GradOperation get_by_list pass tuple/list
Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
Expectation: No exception.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
network = LeNet5(num_classes)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_with_criterion = WithLossCell(network, criterion)
weights = list(network.trainable_params())
optimizer = nn.Momentum(weights, 0.1, 0.9)
train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
losses = []
for i in range(0, epoch):
data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32))
sens = Tensor(np.ones([1]).astype(np.float32))
loss, grads = train_network(data, label, sens)
grads = F.identity(grads)
assert losses[-1].asnumpy() < 0.01
assert losses[-1].asnumpy() > 0.001
@ -177,3 +177,69 @@ def test_ascend_lenet2():
loss_output = test_ascend_lenet()
assert loss_output.asnumpy() < 0.004
assert loss_output.asnumpy() > 0.003
class GradWrapTuple(nn.Cell):
GradWrapTuple definition
def __init__(self, network):
super(GradWrapTuple, self).__init__()
|||| = network
self.weights = tuple(filter(lambda x: x.requires_grad, network.get_parameters()))
def construct(self, x, label):
weights = self.weights
return grad_by_list(, weights)(x, label)
def test_ascend_lenet_grad_by_list_tuple():
Feature: GradOperation get_by_list pass tuple/list
Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
Expectation: No exception.
epoch_size = 20
batch_size = 32
inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
labels = Tensor(np.ones([batch_size]).astype(np.int32))
net = LeNet()
criterion = CrossEntropyLoss()
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
net_with_criterion = WithLossCell(net, criterion)
train_network = GradWrapTuple(net_with_criterion)
total_time = 0
for epoch in range(0, epoch_size):
start_time = time.time()
fw_output = net(inputs)
loss_output = criterion(fw_output, labels)
grads = train_network(inputs, labels)
end_time = time.time()
cost_time = end_time - start_time
total_time = total_time + cost_time
print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
return loss_output
def test_ascend_lenet_grad_by_list_tuple1():
Feature: GradOperation get_by_list pass tuple/list
Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
Expectation: No exception.
os.environ['GRAPH_OP_RUN'] = str(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
loss_output = test_ascend_lenet_grad_by_list_tuple()
assert loss_output.asnumpy() < 0.004
assert loss_output.asnumpy() > 0.003
@ -53,16 +53,16 @@ def test_isinstance():
is_tuple_var = isinstance((x, 1, 1.0, y), tuple)
is_list_const = isinstance(self.list_member, list)
is_list_var = isinstance([x, 1, 1.0, y], list)
is_empty_list = isinstance(self.empty_list, list)
is_dict_const = isinstance(self.dict_member, dict)
is_dict_var = isinstance({"x": x, "y": y}, dict)
is_empty_dic = isinstance(self.empty_dict, dict)
is_list_or_tensor = isinstance([x, y], (Tensor, list))
is_int_or_float_or_tensor_or_tuple = isinstance(x, (Tensor, tuple, int, float))
is_list_or_tensor = isinstance([x, y], (Tensor, list))
float_is_int = isinstance(self.float_member, int)
bool_is_string = isinstance(self.bool_member, str)
tensor_is_tuple = isinstance(x, tuple)
tuple_is_list = isinstance(self.tuple_member, list)
is_empty_list = isinstance(self.empty_list, list)
return is_int, is_float, is_bool, bool_is_int, is_string, is_parameter, \
parameter_is_tensor, is_tensor_const, is_tensor_var, \
is_tuple_const, is_tuple_var, is_list_const, is_list_var, is_empty_list, \
@ -229,7 +229,6 @@ def test_grad_parameter_tuple(mode):
GradCellWithParameterTuple(TestCell2(x1, x2))(z)
@pytest.mark.skip(reason='Not support list or tuple of parameters as GradOperation inputs by now')
@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
def test_grad_parameter_list_or_tuple(mode):
Reference in New Issue