support grad and value_and_grad with has_aux

chenzhuo 2022-08-09 15:17:47 +08:00
parent 543f959786
commit 5f39ebd788
21 changed files with 1183 additions and 177 deletions

View File

@ -525,6 +525,7 @@ Parameter Operation Functions
mindspore.ops.derivative mindspore.ops.derivative
mindspore.ops.grad mindspore.ops.grad
mindspore.ops.value_and_grad
mindspore.ops.jet mindspore.ops.jet
mindspore.ops.jvp mindspore.ops.jvp
mindspore.ops.vjp mindspore.ops.vjp

View File

@ -0,0 +1,25 @@
mindspore.ops.grad
==================
.. py:function:: mindspore.ops.grad(fn, grad_position=0, weights=None, has_aux=False)
Generates the gradient function, which is used to compute the gradient of the given function.
Differentiation covers the following three cases:
1. Gradient with respect to the inputs. In this case, `grad_position` is not None and `weights` is None.
2. Gradient with respect to the network variables. In this case, `grad_position` is None and `weights` is not None.
3. Gradient with respect to both the inputs and the network variables. In this case, neither `grad_position` nor `weights` is None.
Args:
- **fn** (Union[Cell, Function]) - The function or network to be differentiated.
- **grad_position** (Union[NoneType, int, tuple[int]]) - Index of the input position(s) to be differentiated. If int, the gradient of a single input is computed. If tuple, the gradients of the inputs at the indices in the tuple are computed, with indices starting from 0. If None, no input is differentiated, in which case `weights` must not be None. Default: 0.
- **weights** (Union[ParameterTuple, Parameter, list[Parameter]]) - The network variables of the training network whose gradients should be returned. They can usually be obtained with `weights = net.trainable_params()`. Default: None.
- **has_aux** (bool) - Whether to return auxiliary outputs. If True, `fn` must return more than one output; only the first output of `fn` takes part in the differentiation, and the other outputs are returned directly. Default: False.
Returns:
Function, the gradient function used to compute the gradient of the given function. For example, given `out1, out2 = fn(*args)`, if `has_aux` is True the gradient function returns a result of the form `(gradient, out2)`, where `out2` does not take part in the differentiation; if False, it returns `gradient` directly.
Raises:
- **ValueError** - If `grad_position` and `weights` are both None.
- **TypeError** - If the type of an argument does not meet the requirements.
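Example (a minimal illustrative sketch of the `has_aux` behaviour described above; the function and values are hypothetical, not part of this commit):

>>> import mindspore
>>> from mindspore import Tensor, ops
>>> def fn(x):
...     loss = x * x      # only this output is differentiated
...     aux = x + 1       # excluded from differentiation, returned as-is
...     return loss, aux
>>> x = Tensor([2.0], mindspore.float32)
>>> gradient, aux = ops.grad(fn, grad_position=0, has_aux=True)(x)
>>> # gradient is d(x*x)/dx = 2*x; the auxiliary output is returned without contributing to the gradient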

View File

@ -0,0 +1,25 @@
mindspore.ops.value_and_grad
============================
.. py:function:: mindspore.ops.value_and_grad(fn, grad_position=0, weights=None, has_aux=False)
Generates the gradient function, which is used to compute both the forward result and the gradient of the given function.
Differentiation covers the following three cases:
1. Gradient with respect to the inputs. In this case, `grad_position` is not None and `weights` is None.
2. Gradient with respect to the network variables. In this case, `grad_position` is None and `weights` is not None.
3. Gradient with respect to both the inputs and the network variables. In this case, neither `grad_position` nor `weights` is None.
Args:
- **fn** (Union[Cell, Function]) - The function or network to be differentiated.
- **grad_position** (Union[NoneType, int, tuple[int]]) - Index of the input position(s) to be differentiated. If int, the gradient of a single input is computed. If tuple, the gradients of the inputs at the indices in the tuple are computed, with indices starting from 0. If None, no input is differentiated, in which case `weights` must not be None. Default: 0.
- **weights** (Union[ParameterTuple, Parameter, list[Parameter]]) - The network variables of the training network whose gradients should be returned. They can usually be obtained with `weights = net.trainable_params()`. Default: None.
- **has_aux** (bool) - Whether to return auxiliary outputs. If True, `fn` must return more than one output; only the first output of `fn` takes part in the differentiation, and the other outputs are returned directly. Default: False.
Returns:
Function, the gradient function used to compute both the forward result and the gradient of the given function. For example, given `out1, out2 = fn(*args)`, the gradient function returns a result of the form `((out1, out2), gradient)`; when `has_aux` is True, only `out1` takes part in the differentiation.
Raises:
- **ValueError** - If `grad_position` and `weights` are both None.
- **TypeError** - If the type of an argument does not meet the requirements.
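Example (an illustrative sketch of the `((out1, out2), gradient)` return structure described above; the function and values are hypothetical, not part of this commit):

>>> import mindspore
>>> from mindspore import Tensor, ops
>>> def fn(x):
...     loss = x * x      # only this output is differentiated
...     aux = x + 1       # excluded from differentiation, returned as-is
...     return loss, aux
>>> x = Tensor([2.0], mindspore.float32)
>>> (loss, aux), gradient = ops.value_and_grad(fn, grad_position=0, has_aux=True)(x)
>>> # loss and aux are the forward results; gradient is d(x*x)/dx = 2*x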

View File

@ -526,6 +526,7 @@ Differential Functions
mindspore.ops.derivative mindspore.ops.derivative
mindspore.ops.grad mindspore.ops.grad
mindspore.ops.value_and_grad
mindspore.ops.jet mindspore.ops.jet
mindspore.ops.jvp mindspore.ops.jvp
mindspore.ops.vjp mindspore.ops.vjp

View File

@ -612,7 +612,7 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list)
if (tail_type_ >= kNotGrad) { if (tail_type_ >= kNotGrad) {
AbstractSequencePtr sequence_arg = dyn_cast<AbstractSequence>(args_spec_list[0]); AbstractSequencePtr sequence_arg = dyn_cast<AbstractSequence>(args_spec_list[0]);
if (sequence_arg == nullptr) { if (sequence_arg == nullptr) {
MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple or list, but got " << sequence_arg->ToString(); MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple or list, but got " << args_spec_list[0]->ToString();
} }
return GenerateTailFuncGraph(sequence_arg); return GenerateTailFuncGraph(sequence_arg);
} }
@ -624,25 +624,92 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list)
} }
AbstractTuplePtr tuple_arg = dyn_cast<AbstractTuple>(args_spec_list[0]); AbstractTuplePtr tuple_arg = dyn_cast<AbstractTuple>(args_spec_list[0]);
if (tuple_arg == nullptr) { if (tuple_arg == nullptr) {
MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple, but got " << tuple_arg->ToString(); MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple, but got " << args_spec_list[0]->ToString();
} }
if (args_spec_list.size() == args_max_size) { if (args_spec_list.size() == args_max_size) {
AbstractTuplePtr pos = dyn_cast<AbstractTuple>(args_spec_list[1]); AbstractTuplePtr pos = dyn_cast<AbstractTuple>(args_spec_list[1]);
if (pos == nullptr) { if (pos == nullptr) {
MS_LOG(EXCEPTION) << "'Tail' arg1 'position' must be tuple, but got " << pos->ToString(); MS_LOG(EXCEPTION) << "'Tail' arg1 'position' must be tuple, but got " << args_spec_list[1]->ToString();
} }
return GenerateGradFuncGraph(tuple_arg, pos); return GenerateGradFuncGraph(tuple_arg, pos);
} }
return GenerateGradFuncGraph(tuple_arg); return GenerateGradFuncGraph(tuple_arg);
} }
namespace {
AnfNodePtr CreateOutputsWithAux(const FuncGraphPtr &k_child, const AnfNodePtr &gradient, const AnfNodePtr &f_app,
bool has_aux, bool get_value) {
if (get_value) {
return k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), f_app, gradient});
}
if (!has_aux) {
return gradient;
}
PrimitivePtr get_tuple_item_op = prim::kPrimTupleGetItem;
PrimitivePtr make_tuple_op = prim::kPrimMakeTuple;
std::vector<AnfNodePtr> elements = {NewValueNode(make_tuple_op)};
elements.emplace_back(
k_child->NewCNodeInOrder({NewValueNode(get_tuple_item_op), f_app, NewValueNode(static_cast<int64_t>(1))}));
auto aux_output = k_child->NewCNodeInOrder(elements);
auto unpack_node =
k_child->NewCNodeInOrder({NewValueNode(get_tuple_item_op), aux_output, NewValueNode(static_cast<int64_t>(0))});
return k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), gradient, unpack_node});
}
} // namespace
// When has_aux is set True, for out1, out2, out3 = fn(inputs), only the first output out1 contributes to the differentiation of fn.
FuncGraphPtr GradAux::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
AbstractTuplePtr tuple_arg = dyn_cast<AbstractTuple>(args_spec_list[0]);
if (tuple_arg == nullptr) {
MS_LOG(EXCEPTION) << "'GradAux' arg0 must be tuple, but got " << args_spec_list[0]->ToString();
}
FuncGraphPtr fg = std::make_shared<FuncGraph>();
fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr tuple_parameter = fg->add_parameter();
// get_value flag
(void)fg->add_parameter();
AbstractScalarPtr get_value_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
bool get_value_flag = GetValue<bool>(get_value_ptr->BuildValue());
std::vector<AnfNodePtr> elements = {NewValueNode(prim::kPrimMakeTuple)};
elements.push_back(
fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(0))}));
if (get_value_flag) {
for (size_t i = 1; i < tuple_arg->size(); i++) {
auto aux_node =
fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))});
auto stop_gradient_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimStopGradient), aux_node});
elements.push_back(stop_gradient_node);
}
} else {
std::vector<AnfNodePtr> aux_elements = {NewValueNode(prim::kPrimMakeTuple)};
for (size_t i = 1; i < tuple_arg->size(); i++) {
auto aux_node =
fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))});
auto stop_gradient_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimStopGradient), aux_node});
aux_elements.push_back(stop_gradient_node);
}
elements.push_back(fg->NewCNodeInOrder(aux_elements));
}
constexpr size_t args_least_size = 2;
if (elements.size() < args_least_size) {
MS_LOG(EXCEPTION) << "When has_aux is True, origin fn requires more than one outputs, but got " << elements.size()
<< " outputs.\n"
<< trace::GetDebugInfo(fg->debug_info());
}
fg->set_output(fg->NewCNodeInOrder(elements));
return fg;
}
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param, GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param,
bool get_by_position) bool get_by_position, bool has_aux, bool get_value)
: MetaFuncGraph(name), : MetaFuncGraph(name),
get_all_(get_all), get_all_(get_all),
get_by_list_(get_by_list), get_by_list_(get_by_list),
sens_param_(sens_param), sens_param_(sens_param),
get_by_position_(get_by_position) { get_by_position_(get_by_position),
has_aux_(has_aux),
get_value_(get_value) {
if (get_by_position) { if (get_by_position) {
signatures_ = signatures_ =
// def grad(func:read, weight_list:ref, position_list:ref): // def grad(func:read, weight_list:ref, position_list:ref):
@ -714,29 +781,27 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
if (get_by_position_) { if (get_by_position_) {
TailPtr tail_grad_by_position = std::make_shared<Tail>("tail_grad_by_position", kGradByPosition); TailPtr tail_grad_by_position = std::make_shared<Tail>("tail_grad_by_position", kGradByPosition);
inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_by_position), b_app, position}); inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_by_position), b_app, position});
k_child->set_output(inputs_bprop); } else if (get_all_) {
return;
}
if (get_all_) {
TailPtr tail_grad_all = std::make_shared<Tail>("tail_grad_all", kGradAll); TailPtr tail_grad_all = std::make_shared<Tail>("tail_grad_all", kGradAll);
inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_all), b_app}); inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_all), b_app});
} }
// Gradients wrt inputs and parameters // Gradients wrt inputs and parameters
if (fv_bprop != nullptr && inputs_bprop != nullptr) { if (fv_bprop != nullptr && inputs_bprop != nullptr) {
k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop})); auto make_tuple = k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop});
k_child->set_output(CreateOutputsWithAux(k_child, make_tuple, f_app, has_aux_, get_value_));
return; return;
} }
// Gradients wrt parameters // Gradients wrt parameters
if (fv_bprop != nullptr) { if (fv_bprop != nullptr) {
k_child->set_output(fv_bprop); k_child->set_output(CreateOutputsWithAux(k_child, fv_bprop, f_app, has_aux_, get_value_));
return; return;
} }
// Gradients wrt inputs // Gradients wrt inputs
if (inputs_bprop != nullptr) { if (inputs_bprop != nullptr) {
k_child->set_output(inputs_bprop); k_child->set_output(CreateOutputsWithAux(k_child, inputs_bprop, f_app, has_aux_, get_value_));
return; return;
} }
// Gradients wrt first input. // Gradients wrt first input.
@ -744,7 +809,8 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
// so obtain first input grad by setting tail_type of Tail to kGradFirst. // so obtain first input grad by setting tail_type of Tail to kGradFirst.
TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst); TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
tail_grad_first->set_enable_tuple_grad_first(enable_tuple_grad); tail_grad_first->set_enable_tuple_grad_first(enable_tuple_grad);
k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app})); auto tail_grad_first_cnode = k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app});
k_child->set_output(CreateOutputsWithAux(k_child, tail_grad_first_cnode, f_app, has_aux_, get_value_));
} }
namespace { namespace {
@ -795,6 +861,14 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
FuncGraphPtr forward_graph = real_fn->func_graph(); FuncGraphPtr forward_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(forward_graph); MS_EXCEPTION_IF_NULL(forward_graph);
if (has_aux_) {
GradAuxPtr aux_fn = std::make_shared<GradAux>("aux_fn");
auto output_cnode = forward_graph->output();
auto aux_fn_cnode = forward_graph->NewCNodeInOrder({NewValueNode(aux_fn), output_cnode, NewValueNode(get_value_)});
forward_graph->set_output(aux_fn_cnode);
}
forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true); forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
// Check if primal func graph has the primitive returned sparse result in its bprop(). // Check if primal func graph has the primitive returned sparse result in its bprop().
@ -814,13 +888,12 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
ParameterPtr param_graph = grad_fg->add_parameter(); ParameterPtr param_graph = grad_fg->add_parameter();
AnfNodePtr weights = nullptr; AnfNodePtr weights = nullptr;
if (get_by_list_) {
weights = grad_fg->add_parameter();
}
AnfNodePtr position = nullptr; AnfNodePtr position = nullptr;
if (get_by_position_) { if (get_by_position_) {
weights = grad_fg->add_parameter(); weights = grad_fg->add_parameter();
position = grad_fg->add_parameter(); position = grad_fg->add_parameter();
} else if (get_by_list_) {
weights = grad_fg->add_parameter();
} }
std::vector<AnfNodePtr> inputs; std::vector<AnfNodePtr> inputs;

View File

@ -144,7 +144,8 @@ using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;
class GradOperation : public MetaFuncGraph { class GradOperation : public MetaFuncGraph {
public: public:
explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
bool sens_param = false, bool get_by_position = false); bool sens_param = false, bool get_by_position = false, bool has_aux = false,
bool get_value = false);
~GradOperation() override = default; ~GradOperation() override = default;
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
@ -158,6 +159,8 @@ class GradOperation : public MetaFuncGraph {
bool get_by_list_; bool get_by_list_;
bool sens_param_; bool sens_param_;
bool get_by_position_; bool get_by_position_;
bool has_aux_;
bool get_value_;
private: private:
void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop, void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
@ -165,6 +168,15 @@ class GradOperation : public MetaFuncGraph {
}; };
using GradOperationPtr = std::shared_ptr<GradOperation>; using GradOperationPtr = std::shared_ptr<GradOperation>;
class GradAux : public MetaFuncGraph {
public:
explicit GradAux(const std::string &name) : MetaFuncGraph(name) {}
~GradAux() override = default;
MS_DECLARE_PARENT(GradAux, MetaFuncGraph);
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
};
using GradAuxPtr = std::shared_ptr<GradAux>;
class TaylorOperation : public MetaFuncGraph { class TaylorOperation : public MetaFuncGraph {
public: public:
explicit TaylorOperation(const std::string &name); explicit TaylorOperation(const std::string &name);

View File

@ -38,8 +38,9 @@ REGISTER_PYBIND_WITH_PARENT_NAME(
// Reg GradOperation // Reg GradOperation
(void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(*m, "GradOperation_") (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(*m, "GradOperation_")
.def(py::init<std::string &>(), py::arg("fn")) .def(py::init<std::string &>(), py::arg("fn"))
.def(py::init<std::string &, bool, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"), py::arg("get_by_list"), .def(py::init<std::string &, bool, bool, bool, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
py::arg("sens_param"), py::arg("get_by_position")); py::arg("get_by_list"), py::arg("sens_param"), py::arg("get_by_position"), py::arg("has_aux"),
py::arg("get_value"));
// Reg VmapOperation // Reg VmapOperation
(void)py::class_<VmapOperation, MetaFuncGraph, std::shared_ptr<VmapOperation>>(*m, "VmapOperation_") (void)py::class_<VmapOperation, MetaFuncGraph, std::shared_ptr<VmapOperation>>(*m, "VmapOperation_")

View File

@ -298,8 +298,8 @@ class KPynativeCellImpl : public KPynativeCell {
void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &sens_out) override; void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &sens_out) override;
// Build a back propagate funcgraph, each cnode in primal funcgraph is replaced by value node or formal cnode, so it // Build a back propagate funcgraph, each cnode in primal funcgraph is replaced by value node or formal cnode, so it
// can be grad again. // can be grad again.
FuncGraphPtr Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position, bool grad_inputs, FuncGraphPtr Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
bool grad_weights, bool has_sens_arg, bool build_formal_param); const GradAttr &grad_attr, bool build_formal_param);
private: private:
bool need_propagate_stop_gradient_{false}; bool need_propagate_stop_gradient_{false};
@ -346,13 +346,16 @@ class KPynativeCellImpl : public KPynativeCell {
AnfNodePtrList *grad_inputs_list); AnfNodePtrList *grad_inputs_list);
// Set return node according to grad flag // Set return node according to grad flag
void SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position, bool grad_inputs, void SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position, bool grad_inputs,
bool grad_weights); bool grad_weights, const bool get_by_position);
// for higher order gradient; // for higher order gradient;
// Build k mapped node owned by tape_ for each cnode in primal funcgraph, so these node can be // Build k mapped node owned by tape_ for each cnode in primal funcgraph, so these node can be
// used in tape_ to keep tracking the cnode dependency. // used in tape_ to keep tracking the cnode dependency.
bool BuildKNode(); bool BuildKNode();
CNodePtr GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args); CNodePtr GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args);
AnfNodePtr GetTapeOutputForPosition(const AnfNodePtrList &grad_inputs_list, const AbstractBasePtr &grad_inputs_spec,
const AnfNodePtrList &grad_weights_list, const AbstractBasePtr &grad_weights_spec,
const std::vector<size_t> &grad_position, const bool grad_weights);
}; };
using KPynativeCellImplPtr = std::shared_ptr<KPynativeCellImpl>; using KPynativeCellImplPtr = std::shared_ptr<KPynativeCellImpl>;
@ -371,19 +374,18 @@ KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
} }
FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights,
const std::vector<size_t> &grad_position, bool grad_inputs, bool grad_weights, const std::vector<size_t> &grad_position, const GradAttr &grad_attr,
bool has_sens_arg, bool build_formal_param) { bool build_formal_param) {
auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell); auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
return k_cell_impl->Finish(weights, grad_position, grad_inputs, grad_weights, has_sens_arg, build_formal_param); return k_cell_impl->Finish(weights, grad_position, grad_attr, build_formal_param);
} }
FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position, FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
bool grad_inputs, bool grad_weights, bool has_sens_arg, const GradAttr &grad_attr, bool build_formal_param) {
bool build_formal_param) {
// propagate stop_gradient flag to cnode before back propagate; // propagate stop_gradient flag to cnode before back propagate;
PropagateStopGradient(); PropagateStopGradient();
// Set sens node and weights node // Set sens node and weights node
SetSensAndWeights(weights, has_sens_arg); SetSensAndWeights(weights, grad_attr.has_sens);
// Build forward CNode; // Build forward CNode;
if (build_formal_param) { if (build_formal_param) {
(void)BuildKNode(); (void)BuildKNode();
@ -393,12 +395,12 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, const std:
(void)BackPropagate(!build_formal_param); (void)BackPropagate(!build_formal_param);
} }
// Return the gradient; // Return the gradient;
if (grad_position.empty()) { if (grad_attr.get_by_position && grad_position.empty()) {
MS_LOG(EXCEPTION) << "grad_position in F.grad is empty!"; MS_LOG(EXCEPTION) << "grad_position in F.grad is empty when grad by position!";
} }
SetOutput(weights, grad_position, grad_inputs, grad_weights); SetOutput(weights, grad_position, grad_attr.grad_all_inputs, grad_attr.grad_weights, grad_attr.get_by_position);
// Replace Parameter of primal funcgraph with parameter of tape_; // Replace Parameter of primal funcgraph with parameter of tape_;
ReplacePrimalParameter(weights, has_sens_arg); ReplacePrimalParameter(weights, grad_attr.has_sens);
#ifdef ENABLE_DUMP_IR #ifdef ENABLE_DUMP_IR
auto save_graphs_flg = MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); auto save_graphs_flg = MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
if (save_graphs_flg) { if (save_graphs_flg) {
@ -1071,7 +1073,7 @@ AbstractBasePtr KPynativeCellImpl::GetGradInputsSpec(const std::vector<size_t> &
} }
void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position, void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
bool grad_inputs, bool grad_weights) { bool grad_inputs, bool grad_weights, const bool get_by_position) {
AnfNodePtrList grad_inputs_list{NewValueNode(prim::kPrimMakeTuple)}; AnfNodePtrList grad_inputs_list{NewValueNode(prim::kPrimMakeTuple)};
AbstractBasePtr grad_inputs_spec = GetGradInputsSpec(grad_position, grad_inputs, &grad_inputs_list); AbstractBasePtr grad_inputs_spec = GetGradInputsSpec(grad_position, grad_inputs, &grad_inputs_list);
AnfNodePtrList grad_weights_list{NewValueNode(prim::kPrimMakeTuple)}; AnfNodePtrList grad_weights_list{NewValueNode(prim::kPrimMakeTuple)};
@ -1098,13 +1100,20 @@ void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vect
grad_weights_spec = std::make_shared<abstract::AbstractTuple>(grad_weights_abs_list); grad_weights_spec = std::make_shared<abstract::AbstractTuple>(grad_weights_abs_list);
} }
if (get_by_position) {
auto tape_output = GetTapeOutputForPosition(grad_inputs_list, grad_inputs_spec, grad_weights_list,
grad_weights_spec, grad_position, grad_weights);
tape_->set_output(tape_output);
return;
}
AnfNodePtr tape_output; AnfNodePtr tape_output;
if (grad_inputs && grad_weights) { if (grad_inputs && grad_weights) {
tape_output = tape_->NewCNode( tape_output = tape_->NewCNode(
{NewValueNode(prim::kPrimMakeTuple), tape_->NewCNode(grad_inputs_list), tape_->NewCNode(grad_weights_list)}); {NewValueNode(prim::kPrimMakeTuple), tape_->NewCNode(grad_inputs_list), tape_->NewCNode(grad_weights_list)});
tape_output->set_abstract( tape_output->set_abstract(
std::make_shared<abstract::AbstractTuple>(abstract::AbstractBasePtrList{grad_inputs_spec, grad_weights_spec})); std::make_shared<abstract::AbstractTuple>(abstract::AbstractBasePtrList{grad_inputs_spec, grad_weights_spec}));
} else if (grad_inputs || (grad_position.size() > 1)) { } else if (grad_inputs) {
tape_output = tape_->NewCNode(grad_inputs_list); tape_output = tape_->NewCNode(grad_inputs_list);
tape_output->set_abstract(grad_inputs_spec); tape_output->set_abstract(grad_inputs_spec);
} else if (grad_weights) { } else if (grad_weights) {
@ -1114,15 +1123,11 @@ void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vect
tape_output = tape_->NewCNode(grad_inputs_list); tape_output = tape_->NewCNode(grad_inputs_list);
tape_output->set_abstract(grad_inputs_spec); tape_output->set_abstract(grad_inputs_spec);
} else { } else {
size_t index = grad_position[0]; auto input_adjoint_iter = anfnode_to_adjoin_.find(cell_inputs_[0]);
if (index >= cell_inputs_.size()) {
MS_LOG(EXCEPTION) << "Position index " << index << " is exceed input size!";
}
auto input_adjoint_iter = anfnode_to_adjoin_.find(cell_inputs_[index]);
if (input_adjoint_iter == anfnode_to_adjoin_.end()) { if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
// If input is not used in the network, just return zeros_like() as dout; // If input is not used in the network, just return zeros_like() as dout;
MS_LOG(WARNING) << "Input is not used in network, input: " << cell_inputs_[index]->ToString(); MS_LOG(WARNING) << "Input is not used in network, input: " << cell_inputs_[0]->ToString();
tape_output = BuildZerosLikeNode(tape_, cell_inputs_[index]); tape_output = BuildZerosLikeNode(tape_, cell_inputs_[0]);
} else { } else {
tape_output = input_adjoint_iter->second->RealDout(); tape_output = input_adjoint_iter->second->RealDout();
} }
@ -1200,5 +1205,53 @@ void ClearKPynativeCellStaticRes() {
zeros_like_funcgraph_cache.clear(); zeros_like_funcgraph_cache.clear();
ones_like_funcgraph_cache.clear(); ones_like_funcgraph_cache.clear();
} }
// Support both grad_position and weight options in grad.
AnfNodePtr KPynativeCellImpl::GetTapeOutputForPosition(const AnfNodePtrList &grad_inputs_list,
const AbstractBasePtr &grad_inputs_spec,
const AnfNodePtrList &grad_weights_list,
const AbstractBasePtr &grad_weights_spec,
const std::vector<size_t> &grad_position,
const bool grad_weights) {
MS_EXCEPTION_IF_NULL(tape_);
AnfNodePtr tape_output;
if (grad_position.size() == 1) {
AnfNodePtr first_tape_output;
size_t index = grad_position[0];
if (index >= cell_inputs_.size()) {
MS_LOG(EXCEPTION) << "Position index " << index << " is exceed input size!";
}
auto input_adjoint_iter = anfnode_to_adjoin_.find(cell_inputs_[index]);
if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
MS_LOG(WARNING) << "Input is not used in network, input: " << cell_inputs_[index]->ToString();
first_tape_output = BuildZerosLikeNode(tape_, cell_inputs_[index]);
} else {
first_tape_output = input_adjoint_iter->second->RealDout();
}
MS_EXCEPTION_IF_NULL(first_tape_output);
if (grad_weights) {
tape_output =
tape_->NewCNode({NewValueNode(prim::kPrimMakeTuple), first_tape_output, tape_->NewCNode(grad_weights_list)});
MS_EXCEPTION_IF_NULL(tape_output);
tape_output->set_abstract(std::make_shared<abstract::AbstractTuple>(
abstract::AbstractBasePtrList{first_tape_output->abstract(), grad_weights_spec}));
return tape_output;
}
return first_tape_output;
}
if (grad_weights) {
tape_output = tape_->NewCNode(
{NewValueNode(prim::kPrimMakeTuple), tape_->NewCNode(grad_inputs_list), tape_->NewCNode(grad_weights_list)});
tape_output->set_abstract(
std::make_shared<abstract::AbstractTuple>(abstract::AbstractBasePtrList{grad_inputs_spec, grad_weights_spec}));
} else {
tape_output = tape_->NewCNode(grad_inputs_list);
tape_output->set_abstract(grad_inputs_spec);
}
return tape_output;
}
} // namespace ad } // namespace ad
} // namespace mindspore } // namespace mindspore

View File

@ -40,6 +40,16 @@ class KPynativeCell {
using KPynativeCellPtr = std::shared_ptr<KPynativeCell>; using KPynativeCellPtr = std::shared_ptr<KPynativeCell>;
struct GradAttr {
bool grad_all_inputs;
bool grad_weights;
bool has_sens;
bool get_by_position;
GradAttr(bool get_all, bool get_by_list, bool sens_param, bool get_by_position)
: grad_all_inputs(get_all), grad_weights(get_by_list), has_sens(sens_param), get_by_position(get_by_position) {}
};
// bprop_fg: user defined back propagate funcgraph or back propagate funcgraph of primitive, it will be passed after // bprop_fg: user defined back propagate funcgraph or back propagate funcgraph of primitive, it will be passed after
// just parsed. will have prototype: // just parsed. will have prototype:
// (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout) // (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
@ -71,8 +81,8 @@ KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
// else: // else:
// each cnode in primal funcgraph is replaced by value node // each cnode in primal funcgraph is replaced by value node
FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights,
const std::vector<size_t> &grad_position, bool grad_inputs, bool grad_weights, const std::vector<size_t> &grad_position, const GradAttr &grad_attr,
bool has_sens_arg = false, bool build_formal_param = false); bool build_formal_param = false);
// Grad for each operation. // Grad for each operation.
// c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim. // c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim.

View File

@ -565,7 +565,7 @@ void GradExecutor::GradNetInner(const py::object *ret, const prim::GradOperation
// Get params(weights) require derivative // Get params(weights) require derivative
auto w_args = GetWeightsArgs(weights, df_builder); auto w_args = GetWeightsArgs(weights, df_builder);
auto p_args = GetGradPositionArgs(grad_position); auto p_args = GetGradPositionArgs(grad_position, grad->get_by_position_);
if (w_args.empty() && !df_builder->parameters().empty()) { if (w_args.empty() && !df_builder->parameters().empty()) {
MS_LOG(DEBUG) << "Add weights params to w_args"; MS_LOG(DEBUG) << "Add weights params to w_args";
(void)w_args.insert(w_args.end(), df_builder->parameters().cbegin(), df_builder->parameters().cend()); (void)w_args.insert(w_args.end(), df_builder->parameters().cbegin(), df_builder->parameters().cend());
@ -650,15 +650,19 @@ std::vector<AnfNodePtr> GradExecutor::GetWeightsArgs(const py::object &weights,
return w_args; return w_args;
} }
std::vector<size_t> GradExecutor::GetGradPositionArgs(const py::object &grad_position) const { std::vector<size_t> GradExecutor::GetGradPositionArgs(const py::object &grad_position,
const bool get_by_position) const {
std::vector<size_t> pos_args; std::vector<size_t> pos_args;
if (!get_by_position) {
return pos_args;
}
if (py::isinstance<py::tuple>(grad_position)) { if (py::isinstance<py::tuple>(grad_position)) {
const auto &tuple = grad_position.cast<py::tuple>(); const auto &tuple = grad_position.cast<py::tuple>();
(void)std::transform(tuple.begin(), tuple.end(), std::back_inserter(pos_args), (void)std::transform(tuple.begin(), tuple.end(), std::back_inserter(pos_args),
[](const py::handle &elem) { return py::cast<int64_t>(elem); }); [](const py::handle &elem) { return py::cast<int64_t>(elem); });
return pos_args; return pos_args;
} }
MS_LOG(EXCEPTION) << "Grad position only support tuple."; MS_LOG(EXCEPTION) << "Grad position only support tuple when grad_by_position is set True.";
} }
void GradExecutor::ShallowCopySensValue(const py::tuple &input_args, bool has_sens, VectorRef *run_args) const { void GradExecutor::ShallowCopySensValue(const py::tuple &input_args, bool has_sens, VectorRef *run_args) const {
@ -778,8 +782,9 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, con
auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr(); auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr); MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
MS_EXCEPTION_IF_NULL(grad); MS_EXCEPTION_IF_NULL(grad);
FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(k_pynative_cell_ptr, weights, grad_position, grad->get_all_, ad::GradAttr grad_attr(grad->get_all_, grad->get_by_list_, grad->sens_param_, grad->get_by_position_);
grad->get_by_list_, grad->sens_param_, build_formal_param); FuncGraphPtr bprop_graph =
ad::GradPynativeCellEnd(k_pynative_cell_ptr, weights, grad_position, grad_attr, build_formal_param);
MS_EXCEPTION_IF_NULL(bprop_graph); MS_EXCEPTION_IF_NULL(bprop_graph);
MS_LOG(DEBUG) << "Top graph input params size " << arg_size; MS_LOG(DEBUG) << "Top graph input params size " << arg_size;

View File

@ -170,7 +170,7 @@ class GradExecutor {
const abstract::AbstractBasePtr &input_abs, const abstract::AbstractBasePtr &input_abs,
const abstract::AbstractBasePtr &param_tensor_abs, const std::string &input_shape); const abstract::AbstractBasePtr &param_tensor_abs, const std::string &input_shape);
void UpdateParamAbsByArgs(const py::list &args, const FuncGraphPtr &bprop_graph); void UpdateParamAbsByArgs(const py::list &args, const FuncGraphPtr &bprop_graph);
std::vector<size_t> GetGradPositionArgs(const py::object &grad_position) const; std::vector<size_t> GetGradPositionArgs(const py::object &grad_position, const bool get_by_position) const;
void ShallowCopySensValue(const py::tuple &input_args, bool has_sens, VectorRef *run_args) const; void ShallowCopySensValue(const py::tuple &input_args, bool has_sens, VectorRef *run_args) const;
// Manage resource for construct forward graph. // Manage resource for construct forward graph.
AnfNodePtr GetObjNode(const ValuePtr &v, const std::string &obj_id) const; AnfNodePtr GetObjNode(const ValuePtr &v, const std::string &obj_id) const;

View File

@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""grad impl.""" """grad impl."""
from __future__ import absolute_import
from mindspore.ops._grad.grad_base import get_bprop_fn, get_taylor_fprop_fn from mindspore.ops._grad.grad_base import get_bprop_fn, get_taylor_fprop_fn
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, grad_clip_ops, \ from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, grad_clip_ops, \
grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse, grad_inner_ops, taylor_rule grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse, grad_inner_ops, taylor_rule

View File

@ -347,7 +347,7 @@ class GradOperation(GradOperation_):
self.get_all = get_all self.get_all = get_all
self.get_by_list = get_by_list self.get_by_list = get_by_list
self.sens_param = sens_param self.sens_param = sens_param
GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False) GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False, False, False)
self.grad_fn = None self.grad_fn = None
self.fn = None self.fn = None
self.weights_id = None self.weights_id = None
@ -453,7 +453,7 @@ class _Grad(GradOperation_):
A higher-order function which is used to generate the gradient function by position for the input function. A higher-order function which is used to generate the gradient function by position for the input function.
""" """
def __init__(self, get_by_list=False, sens_param=False, get_by_position=False): def __init__(self, get_by_list=False, sens_param=False, get_by_position=False, has_aux=False, get_value=False):
"""Initialize _Grad.""" """Initialize _Grad."""
if not isinstance(get_by_position, bool): if not isinstance(get_by_position, bool):
raise TypeError(f"For '_Grad', the 'get_by_position' should be bool, " raise TypeError(f"For '_Grad', the 'get_by_position' should be bool, "
@ -464,10 +464,18 @@ class _Grad(GradOperation_):
if not isinstance(sens_param, bool): if not isinstance(sens_param, bool):
raise TypeError(f"For '_Grad', the 'sens_param' should be bool, " raise TypeError(f"For '_Grad', the 'sens_param' should be bool, "
f"but got {type(sens_param).__name__}") f"but got {type(sens_param).__name__}")
if not isinstance(has_aux, bool):
raise TypeError(f"For '_Grad', the 'has_aux' should be bool, "
f"but got {type(has_aux).__name__}")
if not isinstance(get_value, bool):
raise TypeError(f"For '_Grad', the 'get_value' should be bool, "
f"but got {type(get_value).__name__}")
self.get_by_position = get_by_position self.get_by_position = get_by_position
self.get_by_list = get_by_list self.get_by_list = get_by_list
self.sens_param = sens_param self.sens_param = sens_param
GradOperation_.__init__(self, 'grad', False, get_by_list, sens_param, get_by_position) self.has_aux = has_aux
self.get_value = get_value
GradOperation_.__init__(self, 'grad', False, get_by_list, sens_param, get_by_position, has_aux, get_value)
self.grad_fn = None self.grad_fn = None
self.fn = None self.fn = None
self.pynative_ = False self.pynative_ = False
@ -480,9 +488,19 @@ class _Grad(GradOperation_):
if self.grad_fn is not None and self.fn == fn and self.grad_position == grad_position and \ if self.grad_fn is not None and self.fn == fn and self.grad_position == grad_position and \
self.weights_id == weights_id: self.weights_id == weights_id:
return self.grad_fn return self.grad_fn
self.fn = fn
self.grad_position = grad_position def aux_fn(*args):
grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position) outputs = fn(*args)
if not isinstance(outputs, tuple) or len(outputs) < 2:
raise ValueError("When has_aux is True, origin fn requires more than one outputs, but got "
+ len(outputs) + " outputs.")
res = (outputs[0],)
stop_gradient = Primitive("stop_gradient")
for item in outputs[1:]:
res += (stop_gradient(item),)
return res
grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position, self.has_aux, self.get_value)
# If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE # If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
# If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE # If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
# In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation. # In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
@ -508,23 +526,34 @@ class _Grad(GradOperation_):
@_wrap_func @_wrap_func
def after_grad(*args, **kwargs): def after_grad(*args, **kwargs):
self._pynative_forward_run(fn, grad_, args, kwargs) forward_flag = self.get_value or self.has_aux
res = self._pynative_forward_run(fn, grad_, forward_flag, args, kwargs)
_pynative_executor.grad(fn, grad_, weights, grad_position, *args, **kwargs) _pynative_executor.grad(fn, grad_, weights, grad_position, *args, **kwargs)
out = _pynative_executor(fn, grad_.sens_param, *args, **kwargs) out = _pynative_executor(fn, grad_.sens_param, *args, **kwargs)
_pynative_executor.clear_grad(fn, *args, **kwargs) _pynative_executor.clear_grad(fn, *args, **kwargs)
if self.get_value:
return res, out
if self.has_aux:
return out, res[1:]
return out return out
else: else:
grad_.pynative_ = True grad_.pynative_ = True
# after_grad of this branch can't use @ms_function, just directly call grad_ # after_grad of this branch can't use @ms_function, just directly call grad_
if self.get_by_position: if self.get_by_position:
def after_grad(*args, **kwargs): def after_grad(*args, **kwargs):
if self.has_aux:
return grad_(aux_fn, weights, grad_position)(*args, **kwargs)
return grad_(fn, weights, grad_position)(*args, **kwargs) return grad_(fn, weights, grad_position)(*args, **kwargs)
else: else:
if self.get_by_list: if self.get_by_list:
def after_grad(*args, **kwargs): def after_grad(*args, **kwargs):
if self.has_aux:
return grad_(aux_fn, weights)(*args, **kwargs)
return grad_(fn, weights)(*args, **kwargs) return grad_(fn, weights)(*args, **kwargs)
else: else:
def after_grad(*args, **kwargs): def after_grad(*args, **kwargs):
if self.has_aux:
return grad_(aux_fn)(*args, **kwargs)
return grad_(fn)(*args, **kwargs) return grad_(fn)(*args, **kwargs)
self.grad_fn = after_grad self.grad_fn = after_grad
@ -534,9 +563,10 @@ class _Grad(GradOperation_):
self.grad_hash_id = (grad_position, weights_id) self.grad_hash_id = (grad_position, weights_id)
return self.grad_fn return self.grad_fn
def _pynative_forward_run(self, fn, grad, args, kwargs): def _pynative_forward_run(self, fn, grad, forward_flag, args, kwargs):
""" Pynative forward runs to build grad graph. """ """ Pynative forward runs to build grad graph. """
new_kwargs = kwargs new_kwargs = kwargs
outputs = ()
if self.sens_param: if self.sens_param:
if 'sens' in kwargs.keys(): if 'sens' in kwargs.keys():
new_kwargs = kwargs.copy() new_kwargs = kwargs.copy()
@ -549,12 +579,17 @@ class _Grad(GradOperation_):
_pynative_executor.new_graph(fn, *args, **new_kwargs) _pynative_executor.new_graph(fn, *args, **new_kwargs)
outputs = fn(*args, **new_kwargs) outputs = fn(*args, **new_kwargs)
_pynative_executor.end_graph(fn, outputs, *args, **new_kwargs) _pynative_executor.end_graph(fn, outputs, *args, **new_kwargs)
return outputs
else: else:
# Check if fn has run already. # Check if fn has run already.
if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs): if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
fn.set_grad() fn.set_grad()
fn(*args, **new_kwargs) outputs = fn(*args, **new_kwargs)
fn.set_grad(False) fn.set_grad(False)
return outputs
if forward_flag and not outputs:
outputs = fn(*args, **new_kwargs)
return outputs
class _Vmap(VmapOperation_): class _Vmap(VmapOperation_):

View File

@ -345,6 +345,7 @@ from .random_func import (
from .grad import ( from .grad import (
grad_func, grad_func,
grad, grad,
value_and_grad,
jet, jet,
derivative, derivative,
jvp, jvp,

View File

@ -17,6 +17,7 @@
from .grad_func import ( from .grad_func import (
grad, grad,
value_and_grad,
jet, jet,
derivative, derivative,
jvp, jvp,

View File

@ -57,53 +57,248 @@ def _convert_grad_position_type(grad_position):
return grad_position return grad_position
grad_by_position = _Grad(get_by_list=False, sens_param=False, get_by_position=True) @constexpr
grad_by_position_with_sens = _Grad(get_by_list=False, sens_param=True, get_by_position=True) def _get_grad_op(get_by_list, get_by_position, has_aux, get_value=False):
return _Grad(get_by_list=get_by_list, get_by_position=get_by_position, has_aux=has_aux, get_value=get_value)
def grad(fn, grad_position=0, sens_param=False): def grad(fn, grad_position=0, weights=None, has_aux=False):
r""" """
A wrapper function to generate the gradient function for the input function. A wrapper function to generate the gradient function for the input function.
As for gradient, three typical cases are included:
1. gradient with respect to inputs. In this case, `grad_position` is not None while `weights` is None.
2. gradient with respect to weights. In this case, `grad_position` is None while `weights` is not None.
3. gradient with respect to inputs and weights. In this case, `grad_position` and `weights` are not None.
Args: Args:
fn (Union(Cell, function)): Function to do GradOperation. fn (Union(Cell, function)): Function to do GradOperation.
grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input. grad_position (Union(NoneType, int, tuple[int])): Index to specify which inputs to be differentiated.
If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0. If int, get the gradient with respect to single input.
sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input. If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False. If None, no derivative of any input will be computed, and in this case, `weights` is required.
Default: 0.
weights (Union(ParameterTuple, Parameter, list(Parameter))): The parameters of the training network that need to
calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
Default: None.
has_aux (bool): If True, only the first output of `fn` contributes to the gradient of `fn`, while the other outputs
will be returned directly. In this case, `fn` must return more than one output. 
Default: False.
Returns: Returns:
Function, returns the gradient function for the input function or cell. Function, the gradient function to calculate gradient for the input function or cell.
For example, given `out1, out2 = fn(*args)`, when `has_aux` is set True, the gradient function returns outputs
of the form `(gradient, out2)`, where `out2` does not contribute to the differentiation; otherwise it returns `gradient`.
Raises:
ValueError: If both `grad_position` and `weights` are None.
TypeError: If the type of an argument does not meet the requirements.
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
Examples: Examples:
>>> import numpy as np >>> import numpy as np
>>> import mindspore as ms >>> import mindspore
>>> import mindspore.nn as nn >>> import mindspore.nn as nn
>>> from mindspore import Tensor >>> from mindspore import Tensor, ops
>>> from mindspore.ops.functional import grad >>> from mindspore.ops import grad
>>> ms.set_context(mode=ms.GRAPH_MODE) >>>
>>> # Cell object to be differentiated
>>> class Net(nn.Cell): >>> class Net(nn.Cell):
... def construct(self, x, y, z): ... def construct(self, x, y, z):
... return x*y*z ... return x * y * z
>>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)) >>> x = Tensor([1, 2], mindspore.float32)
>>> y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32)) >>> y = Tensor([-2, 3], mindspore.float32)
>>> z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32)) >>> z = Tensor([0, 3], mindspore.float32)
>>> net = Net() >>> net = Net()
>>> output = grad(net, grad_position=(1, 2))(x, y, z) >>> output = grad(net, grad_position=(1, 2))(x, y, z)
>>> print(output) >>> print(output)
(Tensor(shape=[2, 2], dtype=Float32, value= (Tensor(shape=[2], dtype=Float32, value=[ 0.00000000e+00, 6.00000000e+00]),
[[ 0.00000000e+00, 6.00000000e+00], Tensor(shape=[2], dtype=Float32, value=[-2.00000000e+00, 6.00000000e+00]))
[ 1.50000000e+01, -4.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value= >>>
[[-2.00000000e+00, 6.00000000e+00], >>> # Function object to be differentiated
[-3.00000000e+00, 8.00000000e+00]])) >>> def fn(x, y, z):
... res = x * ops.exp(y) * ops.pow(z, 2)
... return res, z
>>> x = Tensor([3, 3], mindspore.float32)
>>> y = Tensor([0, 0], mindspore.float32)
>>> z = Tensor([5, 5], mindspore.float32)
>>> gradient, aux = grad(fn, (1, 2), None, True)(x, y, z)
>>> print(gradient)
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01, 7.50000000e+01]),
Tensor(shape=[2], dtype=Float32, value= [ 3.00000000e+01, 3.00000000e+01]))
>>> print(aux)
(Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00, 5.00000000e+00]),)
>>>
>>> # For given network to be differentiated with both inputs and weights, there are 3 cases.
>>> net = nn.Dense(10, 1)
>>> loss_fn = nn.MSELoss()
>>> def forward(inputs, labels):
... logits = net(inputs)
... loss = loss_fn(logits, labels)
... return loss, logits
>>> inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
>>> labels = Tensor(np.random.randn(16, 1).astype(np.float32))
>>> weights = net.trainable_params()
>>> # Case 1: gradient with respect to inputs.
>>> # Aux value does not contribute to the gradient.
>>> grad_fn = grad(forward, grad_position=0, weights=None, has_aux=True)
>>> inputs_gradient, (aux_logits,) = grad_fn(inputs, labels)
>>> print(inputs.shape, inputs_gradient.shape)
(16, 10) (16, 10)
>>> print(aux_logits.shape)
(16, 1)
>>>
>>> # Case 2: gradient with respect to weights.
>>> grad_fn = grad(forward, grad_position=None, weights=weights, has_aux=True)
>>> params_gradient, (aux_logits,) = grad_fn(inputs, labels)
>>> print(len(weights), len(params_gradient))
2 2
>>> print(aux_logits.shape)
(16, 1)
>>>
>>> # Case 3: gradient with respect to inputs and weights.
>>> grad_fn = grad(forward, grad_position=0, weights=weights, has_aux=False)
>>> inputs_gradient, params_gradient = grad_fn(inputs, labels)
>>> print(len(weights), len(params_gradient))
2 2
""" """
if grad_position is None and weights is None:
raise ValueError("`grad_position` and `weight` can not be None at the same time.")
if grad_position is None:
return _get_grad_op(True, False, has_aux)(fn, weights)
grad_position = _convert_grad_position_type(grad_position) grad_position = _convert_grad_position_type(grad_position)
if sens_param: if weights is None:
return grad_by_position_with_sens(fn, None, grad_position) return _get_grad_op(False, True, has_aux)(fn, None, grad_position)
return grad_by_position(fn, None, grad_position) return _get_grad_op(True, True, has_aux)(fn, weights, grad_position)
def value_and_grad(fn, grad_position=0, weights=None, has_aux=False):
"""
A wrapper function to generate the function to calculate forward output and gradient for the input function.
As for gradient, three typical cases are included:
1. gradient with respect to inputs. In this case, `grad_position` is not None while `weights` is None.
2. gradient with respect to weights. In this case, `grad_position` is None while `weights` is not None.
3. gradient with respect to inputs and weights. In this case, `grad_position` and `weights` are not None.
Args:
fn (Union(Cell, function)): Function to do GradOperation.
grad_position (Union(NoneType, int, tuple[int])): Index to specify which inputs to be differentiated.
If int, get the gradient with respect to single input.
If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
If None, no derivative of any input will be computed, and in this case, `weights` is required.
Default: 0.
weights (Union(ParameterTuple, Parameter, list(Parameter))): The parameters of the training network that need to
calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
Default: None.
has_aux (bool): If True, only the first output of `fn` contributes to the gradient of `fn`, while the other outputs
will be returned directly. In this case, `fn` must return more than one output.
Default: False.
Returns:
Function, returns the gradient function to calculate forward output and gradient for the input function or cell.
For example, given `out1, out2 = fn(*args)`, the gradient function returns outputs of the form
`((out1, out2), gradient)`. When `has_aux` is set True, only `out1` contributes to the differentiation.
Raises:
ValueError: If both `grad_position` and `weights` are None.
TypeError: If the type of an argument does not meet the requirements.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops, nn
>>> from mindspore.ops import value_and_grad
>>>
>>> # Cell object to be differentiated
>>> class Net(nn.Cell):
... def construct(self, x, y, z):
... return x * y * z
>>> x = Tensor([1, 2], mindspore.float32)
>>> y = Tensor([-2, 3], mindspore.float32)
>>> z = Tensor([0, 3], mindspore.float32)
>>> net = Net()
>>> grad_fn = value_and_grad(net, grad_position=1)
>>> output, inputs_gradient = grad_fn(x, y, z)
>>> print(output)
[ -0. 18.]
>>> print(inputs_gradient)
[0. 6.]
>>>
>>> # Function object to be differentiated
>>> def fn(x, y, z):
... res = x * ops.exp(y) * ops.pow(z, 2)
... return res, z
>>> x = Tensor(np.array([3, 3]).astype(np.float32))
>>> y = Tensor(np.array([0, 0]).astype(np.float32))
>>> z = Tensor(np.array([5, 5]).astype(np.float32))
>>> output, inputs_gradient = value_and_grad(fn, grad_position=(1, 2), weights=None, has_aux=True)(x, y, z)
>>> print(output)
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01, 7.50000000e+01]),
Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00, 5.00000000e+00]))
>>> print(inputs_gradient)
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01, 7.50000000e+01]),
Tensor(shape=[2], dtype=Float32, value= [ 3.00000000e+01, 3.00000000e+01]))
>>>
>>> # For given network to be differentiated with both inputs and weights, there are 3 cases.
>>> net = nn.Dense(10, 1)
>>> loss_fn = nn.MSELoss()
>>> def forward(inputs, labels):
... logits = net(inputs)
... loss = loss_fn(logits, labels)
... return loss, logits
>>> inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
>>> labels = Tensor(np.random.randn(16, 1).astype(np.float32))
>>> weights = net.trainable_params()
>>>
>>> # Case 1: gradient with respect to inputs.
>>> # For has_aux is set True, only loss contributes to the gradient.
>>> grad_fn = value_and_grad(forward, grad_position=0, weights=None, has_aux=True)
>>> (loss, logits), inputs_gradient = grad_fn(inputs, labels)
>>> print(logits.shape)
(16, 1)
>>> print(inputs.shape, inputs_gradient.shape)
(16, 10) (16, 10)
>>>
>>> # Case 2: gradient with respect to weights.
>>> # For has_aux is set True, only loss contributes to the gradient.
>>> grad_fn = value_and_grad(forward, grad_position=None, weights=weights, has_aux=True)
>>> (loss, logits), params_gradient = grad_fn(inputs, labels)
>>> print(logits.shape)
(16, 1)
>>> print(len(weights), len(params_gradient))
2 2
>>>
>>> # Case 3: gradient with respect to inputs and weights.
>>> # For has_aux is set False, both loss and logits contribute to the gradient.
>>> grad_fn = value_and_grad(forward, grad_position=0, weights=weights, has_aux=False)
>>> (loss, logits), (inputs_gradient, params_gradient) = grad_fn(inputs, labels)
>>> print(logits.shape)
(16, 1)
>>> print(inputs.shape, inputs_gradient.shape)
(16, 10) (16, 10)
>>> print(len(weights), len(params_gradient))
2 2
"""
if grad_position is None and weights is None:
raise ValueError("`grad_position` and `weight` can not be None at the same time.")
if grad_position is None:
return _get_grad_op(True, False, has_aux, True)(fn, weights)
grad_position = _convert_grad_position_type(grad_position)
if weights is None:
return _get_grad_op(False, True, has_aux, True)(fn, None, grad_position)
return _get_grad_op(True, True, has_aux, True)(fn, weights, grad_position)
def _trans_jet_inputs(primals_item, series_item): def _trans_jet_inputs(primals_item, series_item):
@ -531,6 +726,7 @@ def vjp(fn, inputs, v):
__all__ = [ __all__ = [
'grad', 'grad',
'value_and_grad',
'jet', 'jet',
'derivative', 'derivative',
'jvp', 'jvp',

View File

@ -19,7 +19,7 @@ import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindspore import ms_function from mindspore import ms_function
from mindspore.ops.functional import grad from mindspore.ops.functional import grad, value_and_grad
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore import Parameter, ParameterTuple from mindspore import Parameter, ParameterTuple
@ -29,19 +29,22 @@ context.set_context(mode=context.GRAPH_MODE)
class SingleInputSingleOutputNet(nn.Cell): class SingleInputSingleOutputNet(nn.Cell):
def construct(self, x): def construct(self, x):
return x**3 return x ** 3
class SingleInputMultipleOutputsNet(nn.Cell): class SingleInputMultipleOutputsNet(nn.Cell):
def construct(self, x): def construct(self, x):
return x**3, 2*x return x ** 3, 2 * x
class MultipleInputsSingleOutputNet(nn.Cell): class MultipleInputsSingleOutputNet(nn.Cell):
def construct(self, x, y, z): def construct(self, x, y, z):
return x*y*z return x * y * z
class MultipleInputsMultipleOutputsNet(nn.Cell): class MultipleInputsMultipleOutputsNet(nn.Cell):
def construct(self, x, y, z): def construct(self, x, y, z):
return x**2 + y**2 + z**2, x*y*z return x ** 2 + y ** 2 + z ** 2, x * y * z
class ParamNet(nn.Cell): class ParamNet(nn.Cell):
@ -56,15 +59,15 @@ class ParamNet(nn.Cell):
def function(x, y, z): def function(x, y, z):
return x**2 + y**2 + z**2, x*y*z return x ** 2 + y ** 2 + z ** 2, x * y * z
def iteration_grad_function(x, y, z): def iteration_grad_function(x, y, z):
return x**2*y*z return x ** 2 * y * z
@ms_function @ms_function
def grad_warp_with_msfunction(x, y, z): def grad_wrap_with_msfunction(x, y, z):
output = grad(function)(x, y, z) output = grad(function)(x, y, z)
return output return output
@ -145,28 +148,6 @@ def test_grad_multiple_inputs_multiple_outputs_cell_graph():
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy()) assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_function_with_sens_graph():
"""
Features: Function grad.
Description: Test F.grad with function setting sens_param in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
expect_grad1 = Tensor(np.array([[4, 36], [26, 0]]).astype(np.float32))
expect_grad2 = Tensor(np.array([[2, 36], [14, 6]]).astype(np.float32))
real_grad = grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))
assert isinstance(real_grad, tuple)
assert len(real_grad) == 2
assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard
@ -191,17 +172,17 @@ def test_grad_iteration_function_graph():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_grad_warp_with_msfunction_graph(): def test_grad_wrap_with_msfunction_graph():
""" """
Features: Function grad. Features: Function grad.
Description: Test F.grad warpped with ms_function in graph mode. Description: Test F.grad wrapped with ms_function in graph mode.
Expectation: No exception. Expectation: No exception.
""" """
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)) x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32)) y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32)) z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32)) expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
real_grad = grad_warp_with_msfunction(x, y, z) real_grad = grad_wrap_with_msfunction(x, y, z)
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy()) assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
@ -246,6 +227,210 @@ def test_grad_with_weights_twice_graph():
assert np.allclose(out2[0].asnumpy(), expect2) assert np.allclose(out2[0].asnumpy(), expect2)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_with_weights_has_aux_graph():
"""
Features: Function grad.
Description: Test F.grad with different weights and has_aux in graph mode.
Expectation: No exception.
"""
class ParamNetAux(nn.Cell):
def __init__(self):
super(ParamNetAux, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
def construct(self, x):
res = x * self.w * self.z
return res, x, self.w
x = Tensor(np.array([1, 2]).astype(np.float32))
net = ParamNetAux()
weights = ParameterTuple(net.trainable_params())
expect_grad_input = np.array([6, 6]).astype(np.float32)
expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
expect_grad_weight2 = np.array([2, 4]).astype(np.float32)
expect_aux1 = np.array([1, 2]).astype(np.float32)
expect_aux2 = np.array([2, 2]).astype(np.float32)
res, aux = grad(net, 0, weights, True)(x)
assert np.allclose(res[0].asnumpy(), expect_grad_input)
assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(res[1][1].asnumpy(), expect_grad_weight2)
assert np.allclose(aux[0].asnumpy(), expect_aux1)
assert np.allclose(aux[1].asnumpy(), expect_aux2)
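The expected values above follow directly from res = x * w * z; a standalone NumPy cross-check of the same arithmetic (no MindSpore required):

import numpy as np

x = np.array([1., 2.], np.float32)
w = np.array([2., 2.], np.float32)
z = np.array([3., 3.], np.float32)

# res = x * w * z is element-wise, so each partial derivative is the product of the other two factors
assert np.allclose(w * z, [6., 6.])   # d(res)/dx -> expect_grad_input
assert np.allclose(x * z, [3., 6.])   # d(res)/dw -> expect_grad_weight1
assert np.allclose(x * w, [2., 4.])   # d(res)/dz -> expect_grad_weight2
# the auxiliary outputs (x, w) are returned unchanged -> expect_aux1, expect_aux2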
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_ms_function_grad_with_weights_has_aux_graph():
"""
Features: Function grad.
Description: Test F.grad with different weights and has_aux in graph mode.
Expectation: No exception.
"""
class ParamMultipleInputNet(nn.Cell):
def __init__(self):
super(ParamMultipleInputNet, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
def construct(self, x, y):
outputs = x * y * self.w
return outputs, x, self.w
net = ParamMultipleInputNet()
weights = net.trainable_params()
@ms_function
def user_fn(x, y):
res, aux = grad(net, 0, weights, True)(x, y)
return res, aux
x = Tensor(np.array([1, 2]).astype(np.float32))
y = Tensor(np.array([3, 3]).astype(np.float32))
res, aux = user_fn(x, y)
expect_grad_input = np.array([6, 6]).astype(np.float32)
expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
expect_aux1 = np.array([1, 2]).astype(np.float32)
expect_aux2 = np.array([2, 2]).astype(np.float32)
assert np.allclose(res[0].asnumpy(), expect_grad_input)
assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(aux[0].asnumpy(), expect_aux1)
assert np.allclose(aux[1].asnumpy(), expect_aux2)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_construct_grad_with_weights_has_aux_graph():
"""
Features: Function grad.
Description: Test F.grad with different weights and has_aux in graph mode.
Expectation: No exception.
"""
class ParamMultipleInputNet(nn.Cell):
def __init__(self):
super(ParamMultipleInputNet, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
def construct(self, x, y):
outputs = x * y * self.w
return outputs, x, self.w
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = net.trainable_params()
def construct(self, x, y):
res, aux = grad(self.net, 0, self.weights, True)(x, y)
return res, aux
x = Tensor(np.array([1, 2]).astype(np.float32))
y = Tensor(np.array([3, 3]).astype(np.float32))
inner_net = ParamMultipleInputNet()
grad_net = GradNet(inner_net)
res, aux = grad_net(x, y)
expect_grad_input = np.array([6, 6]).astype(np.float32)
expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
expect_aux1 = np.array([1, 2]).astype(np.float32)
expect_aux2 = np.array([2, 2]).astype(np.float32)
assert np.allclose(res[0].asnumpy(), expect_grad_input)
assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(aux[0].asnumpy(), expect_aux1)
assert np.allclose(aux[1].asnumpy(), expect_aux2)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_if_with_weights_has_aux_graph():
"""
Features: Function grad.
    Description: Test F.grad with different weights and has_aux as well as an if branch in graph mode.
Expectation: No exception.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
def construct(self, x):
if x[0] == 1:
res = x * self.w * self.z
else:
res = x * x
return res, x, self.w
x = Tensor(np.array([1, 2]).astype(np.float32))
net = Net()
weights = ParameterTuple(net.trainable_params())
expect_grad_input = np.array([6, 6]).astype(np.float32)
expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
expect_grad_weight2 = np.array([2, 4]).astype(np.float32)
expect_aux1 = np.array([1, 2]).astype(np.float32)
expect_aux2 = np.array([2, 2]).astype(np.float32)
res, aux = grad(net, 0, weights, True)(x)
assert np.allclose(res[0].asnumpy(), expect_grad_input)
assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(res[1][1].asnumpy(), expect_grad_weight2)
assert np.allclose(aux[0].asnumpy(), expect_aux1)
assert np.allclose(aux[1].asnumpy(), expect_aux2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_nest_with_weights_has_aux_graph():
"""
    Features: Function grad.
Description: Test F.grad with different weights and has_aux as well as nested nets in graph mode.
Expectation: No exception.
"""
class InnerNet(nn.Cell):
def construct(self, x):
return x * 3, x
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
self.net = net
def construct(self, x):
res1 = x * self.w * self.z
res2 = self.net(res1)
return res2
x = Tensor(np.array([1, 2]).astype(np.float32))
inner_net = InnerNet()
net = Net(inner_net)
weights = ParameterTuple(net.trainable_params())
expect_grad_input = np.array([18, 18]).astype(np.float32)
expect_grad_weight1 = np.array([9, 18]).astype(np.float32)
expect_grad_weight2 = np.array([6, 12]).astype(np.float32)
expect_aux = np.array([6, 12]).astype(np.float32)
res, aux = grad(net, 0, weights, True)(x)
assert np.allclose(res[0].asnumpy(), expect_grad_input)
assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(res[1][1].asnumpy(), expect_grad_weight2)
assert np.allclose(aux[0].asnumpy(), expect_aux)
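With has_aux=True only the first output of InnerNet, 3 * (x * w * z), is differentiated, which is where the expected values above come from; a quick NumPy check of that arithmetic:

import numpy as np

x = np.array([1., 2.], np.float32)
w = np.array([2., 2.], np.float32)
z = np.array([3., 3.], np.float32)

assert np.allclose(3 * w * z, [18., 18.])  # expect_grad_input
assert np.allclose(3 * x * z, [9., 18.])   # expect_grad_weight1
assert np.allclose(3 * x * w, [6., 12.])   # expect_grad_weight2
assert np.allclose(x * w * z, [6., 12.])   # expect_aux: the second InnerNet output, passed through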
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard
@ -255,6 +440,7 @@ def test_grad_if_ith_train_one_step():
Description: Grad a network with each output. A simplification for GAN network. Description: Grad a network with each output. A simplification for GAN network.
Expectation: Compile success. Expectation: Compile success.
""" """
class IthOutputCell(nn.Cell): class IthOutputCell(nn.Cell):
def __init__(self, network, output_index): def __init__(self, network, output_index):
super().__init__() super().__init__()
@ -316,6 +502,7 @@ def test_grad_net_d_net_g():
Description: Grad two different network. A simplification for GAN network. Description: Grad two different network. A simplification for GAN network.
Expectation: Compile success. Expectation: Compile success.
""" """
class NetD(nn.Cell): class NetD(nn.Cell):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -398,3 +585,176 @@ def test_grad_net_d_net_g():
network = Backbone() network = Backbone()
train_one_net = MyTrainOneStepCell(network) train_one_net = MyTrainOneStepCell(network)
train_one_net(x, y) train_one_net(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_and_grad_with_weights_has_aux_graph():
"""
Features: Function value_and_grad.
Description: Test F.value_and_grad with different weights and has_aux in graph mode.
Expectation: No exception.
"""
class ParamNetMultipleOutputs(nn.Cell):
def __init__(self):
super(ParamNetMultipleOutputs, self).__init__()
self.w1 = Parameter(Tensor([2., 2.], mstype.float32), name="w1")
self.w2 = Parameter(Tensor([3., 3.], mstype.float32), name="w2")
def construct(self, x):
res = x * self.w1 * self.w2
return res, x, self.w1
x = Tensor(np.array([1, 2]).astype(np.float32))
net = ParamNetMultipleOutputs()
weights = ParameterTuple(net.trainable_params())
expect_grad_input = np.array([6, 6]).astype(np.float32)
expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
expect_grad_weight2 = np.array([2, 4]).astype(np.float32)
expect_value0 = np.array([6, 12]).astype(np.float32)
expect_value1 = np.array([1, 2]).astype(np.float32)
expect_value2 = np.array([2, 2]).astype(np.float32)
value, gradient = value_and_grad(net, 0, weights, True)(x)
assert np.allclose(value[0].asnumpy(), expect_value0)
assert np.allclose(value[1].asnumpy(), expect_value1)
assert np.allclose(value[2].asnumpy(), expect_value2)
assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(gradient[1][1].asnumpy(), expect_grad_weight2)
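Here value is the untouched forward output tuple (res, x, w1), while gradient covers only the first output since has_aux=True; the expected numbers reduce to the following NumPy arithmetic:

import numpy as np

x = np.array([1., 2.], np.float32)
w1 = np.array([2., 2.], np.float32)
w2 = np.array([3., 3.], np.float32)

assert np.allclose(x * w1 * w2, [6., 12.])  # expect_value0 (res); x and w1 give expect_value1/expect_value2
assert np.allclose(w1 * w2, [6., 6.])       # expect_grad_input
assert np.allclose(x * w2, [3., 6.])        # expect_grad_weight1
assert np.allclose(x * w1, [2., 4.])        # expect_grad_weight2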
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_construct_value_and_grad_with_weights_has_aux_graph():
"""
Features: Function value_and_grad.
Description: Test F.value_and_grad with different weights and has_aux in graph mode.
Expectation: No exception.
"""
class ParamNetMultipleInputsOutputs(nn.Cell):
def __init__(self):
super(ParamNetMultipleInputsOutputs, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
def construct(self, x, y):
res = x * y * self.w
return res, x, self.w
class GradNet2(nn.Cell):
def __init__(self, net):
super(GradNet2, self).__init__()
self.net = net
self.weights = net.trainable_params()
def construct(self, x, y):
value, gradient = value_and_grad(self.net, 0, self.weights, True)(x, y)
return value, gradient
x = Tensor(np.array([1, 2]).astype(np.float32))
y = Tensor(np.array([3, 3]).astype(np.float32))
inner_net = ParamNetMultipleInputsOutputs()
grad_net = GradNet2(inner_net)
value, gradient = grad_net(x, y)
expect_grad_input = np.array([6, 6]).astype(np.float32)
expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
expect_value0 = np.array([6, 12]).astype(np.float32)
expect_value1 = np.array([1, 2]).astype(np.float32)
expect_value2 = np.array([2, 2]).astype(np.float32)
assert np.allclose(value[0].asnumpy(), expect_value0)
assert np.allclose(value[1].asnumpy(), expect_value1)
assert np.allclose(value[2].asnumpy(), expect_value2)
assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_and_grad_nest_with_weights_graph():
"""
Features: Function value_and_grad.
    Description: Test F.value_and_grad with different weights, has_aux=False, and nested nets in graph mode.
Expectation: No exception.
"""
class InnerNet(nn.Cell):
def construct(self, x):
return x * 3, x
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
self.net = net
def construct(self, x):
res1 = x * self.w * self.z
res2 = self.net(res1)
return res2
x = Tensor(np.array([1, 2]).astype(np.float32))
inner_net = InnerNet()
net = Net(inner_net)
weights = ParameterTuple(net.trainable_params())
expect_grad_input = np.array([24, 24]).astype(np.float32)
expect_grad_weight1 = np.array([12, 24]).astype(np.float32)
expect_grad_weight2 = np.array([8, 16]).astype(np.float32)
expect_value0 = np.array([18, 36]).astype(np.float32)
expect_value1 = np.array([6, 12]).astype(np.float32)
value, gradient = value_and_grad(net, 0, weights, False)(x)
assert np.allclose(value[0].asnumpy(), expect_value0)
assert np.allclose(value[1].asnumpy(), expect_value1)
assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(gradient[1][1].asnumpy(), expect_grad_weight2)
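The expected gradients here are larger than in the has_aux case above because, with has_aux=False, both outputs of InnerNet, (3 * r, r) with r = x * w * z, contribute to the derivative; a quick NumPy check:

import numpy as np

x = np.array([1., 2.], np.float32)
w = np.array([2., 2.], np.float32)
z = np.array([3., 3.], np.float32)

# the gradients of 3 * r and of r add up, i.e. a factor of 4 on each partial derivative of r = x * w * z
assert np.allclose(4 * w * z, [24., 24.])  # expect_grad_input
assert np.allclose(4 * x * z, [12., 24.])  # expect_grad_weight1
assert np.allclose(4 * x * w, [8., 16.])   # expect_grad_weight2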
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_and_grad_nest_with_weights_has_aux_graph():
"""
Features: Function value_and_grad.
Description: Test F.value_and_grad with different weights and has_aux as well as nested nets in graph mode.
Expectation: No exception.
"""
class InnerNet(nn.Cell):
def construct(self, x):
return x * 3, x
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
self.net = net
def construct(self, x):
res1 = x * self.w * self.z
res2 = self.net(res1)
return res2
x = Tensor(np.array([1, 2]).astype(np.float32))
inner_net = InnerNet()
net = Net(inner_net)
weights = ParameterTuple(net.trainable_params())
expect_grad_input = np.array([18, 18]).astype(np.float32)
expect_grad_weight1 = np.array([9, 18]).astype(np.float32)
expect_grad_weight2 = np.array([6, 12]).astype(np.float32)
expect_value0 = np.array([18, 36]).astype(np.float32)
expect_value1 = np.array([6, 12]).astype(np.float32)
value, gradient = value_and_grad(net, 0, weights, True)(x)
assert np.allclose(value[0].asnumpy(), expect_value0)
assert np.allclose(value[1].asnumpy(), expect_value1)
assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(gradient[1][1].asnumpy(), expect_grad_weight2)

View File

@ -19,8 +19,9 @@ import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindspore import ms_function from mindspore import ms_function
from mindspore.ops.functional import grad
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops.functional import grad, value_and_grad
from mindspore.common import dtype as mstype
from mindspore import Parameter, ParameterTuple from mindspore import Parameter, ParameterTuple
context.set_context(mode=context.PYNATIVE_MODE) context.set_context(mode=context.PYNATIVE_MODE)
@ -28,19 +29,22 @@ context.set_context(mode=context.PYNATIVE_MODE)
class SingleInputSingleOutputNet(nn.Cell): class SingleInputSingleOutputNet(nn.Cell):
def construct(self, x): def construct(self, x):
return x**3 return x ** 3
class SingleInputMultipleOutputsNet(nn.Cell): class SingleInputMultipleOutputsNet(nn.Cell):
def construct(self, x): def construct(self, x):
return x**3, 2*x return x ** 3, 2 * x
class MultipleInputsSingleOutputNet(nn.Cell): class MultipleInputsSingleOutputNet(nn.Cell):
def construct(self, x, y, z): def construct(self, x, y, z):
return x*y*z return x * y * z
class MultipleInputsMultipleOutputsNet(nn.Cell): class MultipleInputsMultipleOutputsNet(nn.Cell):
def construct(self, x, y, z): def construct(self, x, y, z):
return x**2 + y**2 + z**2, x*y*z return x ** 2 + y ** 2 + z ** 2, x * y * z
class ParamNet(nn.Cell): class ParamNet(nn.Cell):
@ -55,13 +59,15 @@ class ParamNet(nn.Cell):
def function(x, y, z): def function(x, y, z):
return x**2 + y**2 + z**2, x*y*z return x ** 2 + y ** 2 + z ** 2, x * y * z
def iteration_grad_function(x, y, z): def iteration_grad_function(x, y, z):
return x**2*y*z return x ** 2 * y * z
@ms_function @ms_function
def grad_warp_with_msfunction(x, y, z): def grad_wrap_with_msfunction(x, y, z):
output = grad(function)(x, y, z) output = grad(function)(x, y, z)
return output return output
@ -142,28 +148,6 @@ def test_grad_multiple_inputs_multiple_outputs_cell_pynative():
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy()) assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_function_with_sens_pynative():
"""
Features: Function grad.
Description: Test F.grad with function setting sens_param in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
expect_grad1 = Tensor(np.array([[4, 36], [26, 0]]).astype(np.float32))
expect_grad2 = Tensor(np.array([[2, 36], [14, 6]]).astype(np.float32))
real_grad = grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))
assert isinstance(real_grad, tuple)
assert len(real_grad) == 2
assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard
@ -184,22 +168,24 @@ def test_grad_iteration_function_pynative():
assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy()) assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy()) assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_grad_warp_with_msfunction_pynative(): def test_grad_wrap_with_msfunction_pynative():
""" """
Features: Function grad. Features: Function grad.
Description: Test F.grad warpped with ms_function in pynative mode. Description: Test F.grad wrapped with ms_function in pynative mode.
Expectation: No exception. Expectation: No exception.
""" """
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)) x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32)) y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32)) z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32)) expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
real_grad = grad_warp_with_msfunction(x, y, z) real_grad = grad_wrap_with_msfunction(x, y, z)
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy()) assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard
@ -239,3 +225,249 @@ def test_grad_with_weights_twice_pynative():
out2 = grad_fn(net, weights2)(x) out2 = grad_fn(net, weights2)(x)
assert np.allclose(out1[0].asnumpy(), expect1) assert np.allclose(out1[0].asnumpy(), expect1)
assert np.allclose(out2[0].asnumpy(), expect2) assert np.allclose(out2[0].asnumpy(), expect2)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_with_weights_has_aux_pynative():
"""
Features: Function grad.
Description: Test F.grad with different weights and has_aux in pynative mode.
Expectation: No exception.
"""
class ParamNetAux(nn.Cell):
def __init__(self):
super(ParamNetAux, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
def construct(self, x):
res = x * self.w * self.z
return res, x, self.w
x = Tensor(np.array([1, 2]).astype(np.float32))
net = ParamNetAux()
weights = ParameterTuple(net.trainable_params())
expect_grad_input = np.array([6, 6]).astype(np.float32)
expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
expect_grad_weight2 = np.array([2, 4]).astype(np.float32)
expect_aux1 = np.array([1, 2]).astype(np.float32)
expect_aux2 = np.array([2, 2]).astype(np.float32)
res, aux = grad(net, 0, weights, True)(x)
assert np.allclose(res[0].asnumpy(), expect_grad_input)
assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(res[1][1].asnumpy(), expect_grad_weight2)
assert np.allclose(aux[0].asnumpy(), expect_aux1)
assert np.allclose(aux[1].asnumpy(), expect_aux2)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_if_with_weights_has_aux_pynative():
"""
Features: Function grad.
    Description: Test F.grad with different weights and has_aux as well as an if branch in pynative mode.
Expectation: No exception.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
def construct(self, x):
if x[0] == 1:
res = x * self.w * self.z
else:
res = x * x
return res, x, self.w
x = Tensor(np.array([1, 2]).astype(np.float32))
net = Net()
weights = ParameterTuple(net.trainable_params())
expect_grad_input = np.array([6, 6]).astype(np.float32)
expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
expect_grad_weight2 = np.array([2, 4]).astype(np.float32)
expect_aux1 = np.array([1, 2]).astype(np.float32)
expect_aux2 = np.array([2, 2]).astype(np.float32)
res, aux = grad(net, 0, weights, True)(x)
assert np.allclose(res[0].asnumpy(), expect_grad_input)
assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(res[1][1].asnumpy(), expect_grad_weight2)
assert np.allclose(aux[0].asnumpy(), expect_aux1)
assert np.allclose(aux[1].asnumpy(), expect_aux2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_nest_with_weights_has_aux_pynative():
"""
    Features: Function grad.
Description: Test F.grad with different weights and has_aux as well as nested nets in pynative mode.
Expectation: No exception.
"""
class InnerNet(nn.Cell):
def construct(self, x):
return x * 3, x
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
self.net = net
def construct(self, x):
res1 = x * self.w * self.z
res2 = self.net(res1)
return res2
x = Tensor(np.array([1, 2]).astype(np.float32))
inner_net = InnerNet()
net = Net(inner_net)
weights = ParameterTuple(net.trainable_params())
expect_grad_input = np.array([18, 18]).astype(np.float32)
expect_grad_weight1 = np.array([9, 18]).astype(np.float32)
expect_grad_weight2 = np.array([6, 12]).astype(np.float32)
expect_aux = np.array([6, 12]).astype(np.float32)
res, aux = grad(net, 0, weights, True)(x)
assert np.allclose(res[0].asnumpy(), expect_grad_input)
assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(res[1][1].asnumpy(), expect_grad_weight2)
assert np.allclose(aux[0].asnumpy(), expect_aux)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_and_grad_with_weights_has_aux_pynative():
"""
Features: Function value_and_grad.
Description: Test F.value_and_grad with different weights and has_aux in pynative mode.
Expectation: No exception.
"""
class ParamNetMultipleOutputs(nn.Cell):
def __init__(self):
super(ParamNetMultipleOutputs, self).__init__()
self.w1 = Parameter(Tensor([2., 2.], mstype.float32), name="w1")
self.w2 = Parameter(Tensor([3., 3.], mstype.float32), name="w2")
def construct(self, x):
res = x * self.w1 * self.w2
return res, x, self.w1
x = Tensor(np.array([1, 2]).astype(np.float32))
net = ParamNetMultipleOutputs()
weights = ParameterTuple(net.trainable_params())
expect_grad_input = np.array([6, 6]).astype(np.float32)
expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
expect_grad_weight2 = np.array([2, 4]).astype(np.float32)
expect_value0 = np.array([6, 12]).astype(np.float32)
expect_value1 = np.array([1, 2]).astype(np.float32)
expect_value2 = np.array([2, 2]).astype(np.float32)
value, gradient = value_and_grad(net, 0, weights, True)(x)
assert np.allclose(value[0].asnumpy(), expect_value0)
assert np.allclose(value[1].asnumpy(), expect_value1)
assert np.allclose(value[2].asnumpy(), expect_value2)
assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(gradient[1][1].asnumpy(), expect_grad_weight2)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_and_grad_nest_with_weights_pynative():
"""
Features: Function value_and_grad.
    Description: Test F.value_and_grad with different weights, has_aux=False, and nested nets in pynative mode.
Expectation: No exception.
"""
class InnerNet(nn.Cell):
def construct(self, x):
return x * 3, x
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
self.net = net
def construct(self, x):
res1 = x * self.w * self.z
res2 = self.net(res1)
return res2
x = Tensor(np.array([1, 2]).astype(np.float32))
inner_net = InnerNet()
net = Net(inner_net)
weights = ParameterTuple(net.trainable_params())
expect_grad_input = np.array([24, 24]).astype(np.float32)
expect_grad_weight1 = np.array([12, 24]).astype(np.float32)
expect_grad_weight2 = np.array([8, 16]).astype(np.float32)
expect_value0 = np.array([18, 36]).astype(np.float32)
expect_value1 = np.array([6, 12]).astype(np.float32)
value, gradient = value_and_grad(net, 0, weights, False)(x)
assert np.allclose(value[0].asnumpy(), expect_value0)
assert np.allclose(value[1].asnumpy(), expect_value1)
assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(gradient[1][1].asnumpy(), expect_grad_weight2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_and_grad_nest_with_weights_has_aux_pynative():
"""
Features: Function value_and_grad.
Description: Test F.value_and_grad with different weights and has_aux as well as nested nets in pynative mode.
Expectation: No exception.
"""
class InnerNet(nn.Cell):
def construct(self, x):
return x * 3, x
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
self.net = net
def construct(self, x):
res1 = x * self.w * self.z
res2 = self.net(res1)
return res2
x = Tensor(np.array([1, 2]).astype(np.float32))
inner_net = InnerNet()
net = Net(inner_net)
weights = ParameterTuple(net.trainable_params())
expect_grad_input = np.array([18, 18]).astype(np.float32)
expect_grad_weight1 = np.array([9, 18]).astype(np.float32)
expect_grad_weight2 = np.array([6, 12]).astype(np.float32)
expect_value0 = np.array([18, 36]).astype(np.float32)
expect_value1 = np.array([6, 12]).astype(np.float32)
value, gradient = value_and_grad(net, 0, weights, True)(x)
assert np.allclose(value[0].asnumpy(), expect_value0)
assert np.allclose(value[1].asnumpy(), expect_value1)
assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
assert np.allclose(gradient[1][1].asnumpy(), expect_grad_weight2)

View File

@ -98,8 +98,8 @@ class TestKPynative : public UT::Common {
GradPynativeOp(k_pynative_cell, c_node, args, out); GradPynativeOp(k_pynative_cell, c_node, args, out);
} }
} }
auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, std::vector<size_t>{0}, true, false, false, GradAttr grad_attr(true, false, false, false);
true); auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, std::vector<size_t>{0}, grad_attr, true);
return bprop_fg; return bprop_fg;
} }
}; };

View File

@ -58,16 +58,3 @@ def test_grad_multiple_inputs_multiple_outputs_cell_graph():
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32)) z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
net = MultipleInputsMultipleOutputsNet() net = MultipleInputsMultipleOutputsNet()
grad(net, grad_position=(1, 2))(x, y, z) grad(net, grad_position=(1, 2))(x, y, z)
def test_grad_function_with_sens_graph():
"""
Features: Function grad.
Description: Test F.grad with function setting sens_param in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))

View File

@ -58,16 +58,3 @@ def test_grad_multiple_inputs_multiple_outputs_cell_pynative():
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32)) z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
net = MultipleInputsMultipleOutputsNet() net = MultipleInputsMultipleOutputsNet()
grad(net, grad_position=(1, 2))(x, y, z) grad(net, grad_position=(1, 2))(x, y, z)
def test_grad_function_with_sens_pynative():
"""
Features: Function grad.
Description: Test F.grad with function setting sens_param in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))