support grad and value_and_grad with has_aux
This commit is contained in:
parent
543f959786
commit
5f39ebd788
|
@ -525,6 +525,7 @@ Parameter operation functions
|
|||
|
||||
mindspore.ops.derivative
|
||||
mindspore.ops.grad
|
||||
mindspore.ops.value_and_grad
|
||||
mindspore.ops.jet
|
||||
mindspore.ops.jvp
|
||||
mindspore.ops.vjp
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
mindspore.ops.grad
|
||||
==================
|
||||
|
||||
.. py:function:: mindspore.ops.grad(fn, grad_position=0, weights=None, has_aux=False)
|
||||
|
||||
Generates a gradient function that computes the gradient of the given function.
|
||||
|
||||
Differentiation covers the following three scenarios:
|
||||
|
||||
1. Differentiate with respect to inputs: `grad_position` is not None and `weights` is None;
|
||||
2. Differentiate with respect to network parameters: `grad_position` is None and `weights` is not None;
|
||||
3. Differentiate with respect to both inputs and network parameters: neither `grad_position` nor `weights` is None.
|
||||
|
||||
Args:
|
||||
- **fn** (Union[Cell, Function]) - The function or network to be differentiated.
|
||||
- **grad_position** (Union[NoneType, int, tuple[int]]) - Index of the input position(s) to differentiate. If int, differentiate a single input; if tuple, differentiate the inputs at the given indices, where indices start from 0; if None, no input is differentiated, and in this case `weights` must not be None. Default: 0.
|
||||
- **weights** (Union[ParameterTuple, Parameter, list[Parameter]]) - The network parameters whose gradients need to be returned, usually obtained via `weights = net.trainable_params()`. Default: None.
|
||||
- **has_aux** (bool) - Whether to return auxiliary outputs. If True, `fn` must return more than one output; only the first output of `fn` takes part in the differentiation, and the other outputs are returned directly. Default: False.
|
||||
|
||||
Returns:
|
||||
Function, the gradient function that computes the gradient of the given function. For example, for `out1, out2 = fn(*args)`, if `has_aux` is True the gradient function returns a result of the form `(gradient, out2)`, where `out2` does not take part in the differentiation; if False, it returns `gradient` directly.
|
||||
|
||||
Raises:
|
||||
- **ValueError** - If `grad_position` and `weights` are both None.
|
||||
- **TypeError** - If the type of an argument does not meet the requirements.
|
|
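To make the `has_aux` contract above concrete, here is a minimal usage sketch (assuming a MindSpore build that includes this change). It mirrors the function-object example used in the English docstring later in this commit; the tensor values are illustrative only.

import numpy as np
from mindspore import Tensor, ops

def fn(x, y, z):
    # fn returns two outputs; with has_aux=True only the first one (res) is differentiated,
    # while the second one (z) is passed through untouched as an auxiliary value.
    res = x * ops.exp(y) * ops.pow(z, 2)
    return res, z

x = Tensor(np.array([3, 3], np.float32))
y = Tensor(np.array([0, 0], np.float32))
z = Tensor(np.array([5, 5], np.float32))

grad_fn = ops.grad(fn, grad_position=(1, 2), weights=None, has_aux=True)
gradient, aux = grad_fn(x, y, z)  # gradients w.r.t. y and z, plus the auxiliary outputs (z,)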
@ -0,0 +1,25 @@
|
|||
mindspore.ops.value_and_grad
|
||||
============================
|
||||
|
||||
.. py:function:: mindspore.ops.value_and_grad(fn, grad_position=0, weights=None, has_aux=False)
|
||||
|
||||
Generates a gradient function that computes both the forward result and the gradient of the given function.
|
||||
|
||||
Differentiation covers the following three scenarios:
|
||||
|
||||
1. Differentiate with respect to inputs: `grad_position` is not None and `weights` is None;
|
||||
2. Differentiate with respect to network parameters: `grad_position` is None and `weights` is not None;
|
||||
3. Differentiate with respect to both inputs and network parameters: neither `grad_position` nor `weights` is None.
|
||||
|
||||
Args:
|
||||
- **fn** (Union[Cell, Function]) - The function or network to be differentiated.
|
||||
- **grad_position** (Union[NoneType, int, tuple[int]]) - Index of the input position(s) to differentiate. If int, differentiate a single input; if tuple, differentiate the inputs at the given indices, where indices start from 0; if None, no input is differentiated, and in this case `weights` must not be None. Default: 0.
|
||||
- **weights** (Union[ParameterTuple, Parameter, list[Parameter]]) - The network parameters whose gradients need to be returned, usually obtained via `weights = net.trainable_params()`. Default: None.
|
||||
- **has_aux** (bool) - Whether to return auxiliary outputs. If True, `fn` must return more than one output; only the first output of `fn` takes part in the differentiation, and the other outputs are returned directly. Default: False.
|
||||
|
||||
Returns:
|
||||
Function, the gradient function that computes both the forward result and the gradient of the given function. For example, for `out1, out2 = fn(*args)`, the gradient function returns a result of the form `((out1, out2), gradient)`, where `out2` does not take part in the differentiation.
|
||||
|
||||
Raises:
|
||||
- **ValueError** - If `grad_position` and `weights` are both None.
|
||||
- **TypeError** - If the type of an argument does not meet the requirements.
|
|
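Correspondingly, a minimal sketch of `value_and_grad` (again assuming a build with this change): the first element of the result is the full forward output `(loss, logits)`, the second element holds the gradients. It mirrors Case 2 of the English docstring added later in this commit.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, ops

net = nn.Dense(10, 1)
loss_fn = nn.MSELoss()

def forward(inputs, labels):
    logits = net(inputs)
    loss = loss_fn(logits, labels)
    return loss, logits

inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
labels = Tensor(np.random.randn(16, 1).astype(np.float32))
weights = net.trainable_params()

# With has_aux=True only `loss` is differentiated; `logits` is carried along
# as an auxiliary forward value inside the first element of the result.
grad_fn = ops.value_and_grad(forward, grad_position=None, weights=weights, has_aux=True)
(loss, logits), params_gradient = grad_fn(inputs, labels)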
@ -526,6 +526,7 @@ Differential Functions
|
|||
|
||||
mindspore.ops.derivative
|
||||
mindspore.ops.grad
|
||||
mindspore.ops.value_and_grad
|
||||
mindspore.ops.jet
|
||||
mindspore.ops.jvp
|
||||
mindspore.ops.vjp
|
||||
|
|
|
@ -612,7 +612,7 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list)
|
|||
if (tail_type_ >= kNotGrad) {
|
||||
AbstractSequencePtr sequence_arg = dyn_cast<AbstractSequence>(args_spec_list[0]);
|
||||
if (sequence_arg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple or list, but got " << sequence_arg->ToString();
|
||||
MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple or list, but got " << args_spec_list[0]->ToString();
|
||||
}
|
||||
return GenerateTailFuncGraph(sequence_arg);
|
||||
}
|
||||
|
@ -624,25 +624,92 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list)
|
|||
}
|
||||
AbstractTuplePtr tuple_arg = dyn_cast<AbstractTuple>(args_spec_list[0]);
|
||||
if (tuple_arg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple, but got " << tuple_arg->ToString();
|
||||
MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple, but got " << args_spec_list[0]->ToString();
|
||||
}
|
||||
if (args_spec_list.size() == args_max_size) {
|
||||
AbstractTuplePtr pos = dyn_cast<AbstractTuple>(args_spec_list[1]);
|
||||
if (pos == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "'Tail' arg1 'position' must be tuple, but got " << pos->ToString();
|
||||
MS_LOG(EXCEPTION) << "'Tail' arg1 'position' must be tuple, but got " << args_spec_list[1]->ToString();
|
||||
}
|
||||
return GenerateGradFuncGraph(tuple_arg, pos);
|
||||
}
|
||||
return GenerateGradFuncGraph(tuple_arg);
|
||||
}
|
||||
namespace {
|
||||
AnfNodePtr CreateOutputsWithAux(const FuncGraphPtr &k_child, const AnfNodePtr &gradient, const AnfNodePtr &f_app,
|
||||
bool has_aux, bool get_value) {
|
||||
if (get_value) {
|
||||
return k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), f_app, gradient});
|
||||
}
|
||||
if (!has_aux) {
|
||||
return gradient;
|
||||
}
|
||||
PrimitivePtr get_tuple_item_op = prim::kPrimTupleGetItem;
|
||||
PrimitivePtr make_tuple_op = prim::kPrimMakeTuple;
|
||||
std::vector<AnfNodePtr> elements = {NewValueNode(make_tuple_op)};
|
||||
elements.emplace_back(
|
||||
k_child->NewCNodeInOrder({NewValueNode(get_tuple_item_op), f_app, NewValueNode(static_cast<int64_t>(1))}));
|
||||
auto aux_output = k_child->NewCNodeInOrder(elements);
|
||||
auto unpack_node =
|
||||
k_child->NewCNodeInOrder({NewValueNode(get_tuple_item_op), aux_output, NewValueNode(static_cast<int64_t>(0))});
|
||||
return k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), gradient, unpack_node});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// When set aux True, for out1, out2, out3 = fn(inputs), only first out1 contributes to differentiation of fn.
|
||||
FuncGraphPtr GradAux::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
AbstractTuplePtr tuple_arg = dyn_cast<AbstractTuple>(args_spec_list[0]);
|
||||
if (tuple_arg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "'GradAux' arg0 must be tuple, but got " << args_spec_list[0]->ToString();
|
||||
}
|
||||
FuncGraphPtr fg = std::make_shared<FuncGraph>();
|
||||
fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
AnfNodePtr tuple_parameter = fg->add_parameter();
|
||||
// get_value flag
|
||||
(void)fg->add_parameter();
|
||||
|
||||
AbstractScalarPtr get_value_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
|
||||
bool get_value_flag = GetValue<bool>(get_value_ptr->BuildValue());
|
||||
std::vector<AnfNodePtr> elements = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
elements.push_back(
|
||||
fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(0))}));
|
||||
if (get_value_flag) {
|
||||
for (size_t i = 1; i < tuple_arg->size(); i++) {
|
||||
auto aux_node =
|
||||
fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))});
|
||||
auto stop_gradient_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimStopGradient), aux_node});
|
||||
elements.push_back(stop_gradient_node);
|
||||
}
|
||||
} else {
|
||||
std::vector<AnfNodePtr> aux_elements = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
for (size_t i = 1; i < tuple_arg->size(); i++) {
|
||||
auto aux_node =
|
||||
fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))});
|
||||
auto stop_gradient_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimStopGradient), aux_node});
|
||||
aux_elements.push_back(stop_gradient_node);
|
||||
}
|
||||
elements.push_back(fg->NewCNodeInOrder(aux_elements));
|
||||
}
|
||||
|
||||
constexpr size_t args_least_size = 2;
|
||||
if (elements.size() < args_least_size) {
|
||||
MS_LOG(EXCEPTION) << "When has_aux is True, origin fn requires more than one outputs, but got " << elements.size()
|
||||
<< " outputs.\n"
|
||||
<< trace::GetDebugInfo(fg->debug_info());
|
||||
}
|
||||
fg->set_output(fg->NewCNodeInOrder(elements));
|
||||
return fg;
|
||||
}
|
||||
|
||||
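At the Python level, the transformation that `GradAux` performs on the forward graph corresponds roughly to the following hand-written sketch (a simplified analogue, not the actual implementation): keep the first output of `fn` for differentiation and route the remaining outputs through stop_gradient so they are returned unchanged. It parallels the `aux_fn` helper added to `_Grad` for the PyNative path later in this commit.

from mindspore.ops import Primitive

# Same primitive the PyNative path constructs for detaching auxiliary outputs.
stop_gradient = Primitive("stop_gradient")

def wrap_with_aux(fn):
    """Simplified analogue of the has_aux handling: only outputs[0] keeps its gradient."""
    def aux_fn(*args):
        outputs = fn(*args)
        if not isinstance(outputs, tuple) or len(outputs) < 2:
            raise ValueError("When has_aux is True, fn must return more than one output.")
        res = (outputs[0],)
        for item in outputs[1:]:
            res += (stop_gradient(item),)  # auxiliary outputs are returned but not differentiated
        return res
    return aux_fn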
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param,
|
||||
bool get_by_position)
|
||||
bool get_by_position, bool has_aux, bool get_value)
|
||||
: MetaFuncGraph(name),
|
||||
get_all_(get_all),
|
||||
get_by_list_(get_by_list),
|
||||
sens_param_(sens_param),
|
||||
get_by_position_(get_by_position) {
|
||||
get_by_position_(get_by_position),
|
||||
has_aux_(has_aux),
|
||||
get_value_(get_value) {
|
||||
if (get_by_position) {
|
||||
signatures_ =
|
||||
// def grad(func:read, weight_list:ref, position_list:ref):
|
||||
|
@ -714,29 +781,27 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
|
|||
if (get_by_position_) {
|
||||
TailPtr tail_grad_by_position = std::make_shared<Tail>("tail_grad_by_position", kGradByPosition);
|
||||
inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_by_position), b_app, position});
|
||||
k_child->set_output(inputs_bprop);
|
||||
return;
|
||||
}
|
||||
if (get_all_) {
|
||||
} else if (get_all_) {
|
||||
TailPtr tail_grad_all = std::make_shared<Tail>("tail_grad_all", kGradAll);
|
||||
inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_all), b_app});
|
||||
}
|
||||
|
||||
// Gradients wrt inputs and parameters
|
||||
if (fv_bprop != nullptr && inputs_bprop != nullptr) {
|
||||
k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
|
||||
auto make_tuple = k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop});
|
||||
k_child->set_output(CreateOutputsWithAux(k_child, make_tuple, f_app, has_aux_, get_value_));
|
||||
return;
|
||||
}
|
||||
|
||||
// Gradients wrt parameters
|
||||
if (fv_bprop != nullptr) {
|
||||
k_child->set_output(fv_bprop);
|
||||
k_child->set_output(CreateOutputsWithAux(k_child, fv_bprop, f_app, has_aux_, get_value_));
|
||||
return;
|
||||
}
|
||||
|
||||
// Gradients wrt inputs
|
||||
if (inputs_bprop != nullptr) {
|
||||
k_child->set_output(inputs_bprop);
|
||||
k_child->set_output(CreateOutputsWithAux(k_child, inputs_bprop, f_app, has_aux_, get_value_));
|
||||
return;
|
||||
}
|
||||
// Gradients wrt first input.
|
||||
|
@ -744,7 +809,8 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
|
|||
// so obtain first input grad by setting tail_type of Tail to kGradFirst.
|
||||
TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
|
||||
tail_grad_first->set_enable_tuple_grad_first(enable_tuple_grad);
|
||||
k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app}));
|
||||
auto tail_grad_first_cnode = k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app});
|
||||
k_child->set_output(CreateOutputsWithAux(k_child, tail_grad_first_cnode, f_app, has_aux_, get_value_));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -795,6 +861,14 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|||
|
||||
FuncGraphPtr forward_graph = real_fn->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(forward_graph);
|
||||
|
||||
if (has_aux_) {
|
||||
GradAuxPtr aux_fn = std::make_shared<GradAux>("aux_fn");
|
||||
auto output_cnode = forward_graph->output();
|
||||
auto aux_fn_cnode = forward_graph->NewCNodeInOrder({NewValueNode(aux_fn), output_cnode, NewValueNode(get_value_)});
|
||||
forward_graph->set_output(aux_fn_cnode);
|
||||
}
|
||||
|
||||
forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
|
||||
|
||||
// Check if primal func graph has the primitive returned sparse result in its bprop().
|
||||
|
@ -814,13 +888,12 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|||
ParameterPtr param_graph = grad_fg->add_parameter();
|
||||
|
||||
AnfNodePtr weights = nullptr;
|
||||
if (get_by_list_) {
|
||||
weights = grad_fg->add_parameter();
|
||||
}
|
||||
AnfNodePtr position = nullptr;
|
||||
if (get_by_position_) {
|
||||
weights = grad_fg->add_parameter();
|
||||
position = grad_fg->add_parameter();
|
||||
} else if (get_by_list_) {
|
||||
weights = grad_fg->add_parameter();
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
|
|
|
@ -144,7 +144,8 @@ using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;
|
|||
class GradOperation : public MetaFuncGraph {
|
||||
public:
|
||||
explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
|
||||
bool sens_param = false, bool get_by_position = false);
|
||||
bool sens_param = false, bool get_by_position = false, bool has_aux = false,
|
||||
bool get_value = false);
|
||||
~GradOperation() override = default;
|
||||
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
|
||||
|
||||
|
@ -158,6 +159,8 @@ class GradOperation : public MetaFuncGraph {
|
|||
bool get_by_list_;
|
||||
bool sens_param_;
|
||||
bool get_by_position_;
|
||||
bool has_aux_;
|
||||
bool get_value_;
|
||||
|
||||
private:
|
||||
void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
|
||||
|
@ -165,6 +168,15 @@ class GradOperation : public MetaFuncGraph {
|
|||
};
|
||||
using GradOperationPtr = std::shared_ptr<GradOperation>;
|
||||
|
||||
class GradAux : public MetaFuncGraph {
|
||||
public:
|
||||
explicit GradAux(const std::string &name) : MetaFuncGraph(name) {}
|
||||
~GradAux() override = default;
|
||||
MS_DECLARE_PARENT(GradAux, MetaFuncGraph);
|
||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||
};
|
||||
using GradAuxPtr = std::shared_ptr<GradAux>;
|
||||
|
||||
class TaylorOperation : public MetaFuncGraph {
|
||||
public:
|
||||
explicit TaylorOperation(const std::string &name);
|
||||
|
|
|
@ -38,8 +38,9 @@ REGISTER_PYBIND_WITH_PARENT_NAME(
|
|||
// Reg GradOperation
|
||||
(void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(*m, "GradOperation_")
|
||||
.def(py::init<std::string &>(), py::arg("fn"))
|
||||
.def(py::init<std::string &, bool, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"), py::arg("get_by_list"),
|
||||
py::arg("sens_param"), py::arg("get_by_position"));
|
||||
.def(py::init<std::string &, bool, bool, bool, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
|
||||
py::arg("get_by_list"), py::arg("sens_param"), py::arg("get_by_position"), py::arg("has_aux"),
|
||||
py::arg("get_value"));
|
||||
|
||||
// Reg VmapOperation
|
||||
(void)py::class_<VmapOperation, MetaFuncGraph, std::shared_ptr<VmapOperation>>(*m, "VmapOperation_")
|
||||
|
|
|
@ -298,8 +298,8 @@ class KPynativeCellImpl : public KPynativeCell {
|
|||
void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &sens_out) override;
|
||||
// Build a back propagate funcgraph, each cnode in primal funcgraph is replaced by value node or formal cnode, so it
|
||||
// can be grad again.
|
||||
FuncGraphPtr Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position, bool grad_inputs,
|
||||
bool grad_weights, bool has_sens_arg, bool build_formal_param);
|
||||
FuncGraphPtr Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
|
||||
const GradAttr &grad_attr, bool build_formal_param);
|
||||
|
||||
private:
|
||||
bool need_propagate_stop_gradient_{false};
|
||||
|
@ -346,13 +346,16 @@ class KPynativeCellImpl : public KPynativeCell {
|
|||
AnfNodePtrList *grad_inputs_list);
|
||||
// Set return node according to grad flag
|
||||
void SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position, bool grad_inputs,
|
||||
bool grad_weights);
|
||||
bool grad_weights, const bool get_by_position);
|
||||
|
||||
// for higher order gradient;
|
||||
// Build k mapped node owned by tape_ for each cnode in primal funcgraph, so these node can be
|
||||
// used in tape_ to keep tracking the cnode dependency.
|
||||
bool BuildKNode();
|
||||
CNodePtr GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args);
|
||||
AnfNodePtr GetTapeOutputForPosition(const AnfNodePtrList &grad_inputs_list, const AbstractBasePtr &grad_inputs_spec,
|
||||
const AnfNodePtrList &grad_weights_list, const AbstractBasePtr &grad_weights_spec,
|
||||
const std::vector<size_t> &grad_position, const bool grad_weights);
|
||||
};
|
||||
using KPynativeCellImplPtr = std::shared_ptr<KPynativeCellImpl>;
|
||||
|
||||
|
@ -371,19 +374,18 @@ KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
|
|||
}
|
||||
|
||||
FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights,
|
||||
const std::vector<size_t> &grad_position, bool grad_inputs, bool grad_weights,
|
||||
bool has_sens_arg, bool build_formal_param) {
|
||||
const std::vector<size_t> &grad_position, const GradAttr &grad_attr,
|
||||
bool build_formal_param) {
|
||||
auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
|
||||
return k_cell_impl->Finish(weights, grad_position, grad_inputs, grad_weights, has_sens_arg, build_formal_param);
|
||||
return k_cell_impl->Finish(weights, grad_position, grad_attr, build_formal_param);
|
||||
}
|
||||
|
||||
FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
|
||||
bool grad_inputs, bool grad_weights, bool has_sens_arg,
|
||||
bool build_formal_param) {
|
||||
const GradAttr &grad_attr, bool build_formal_param) {
|
||||
// propagate stop_gradient flag to cnode before back propagate;
|
||||
PropagateStopGradient();
|
||||
// Set sens node and weights node
|
||||
SetSensAndWeights(weights, has_sens_arg);
|
||||
SetSensAndWeights(weights, grad_attr.has_sens);
|
||||
// Build forward CNode;
|
||||
if (build_formal_param) {
|
||||
(void)BuildKNode();
|
||||
|
@ -393,12 +395,12 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, const std:
|
|||
(void)BackPropagate(!build_formal_param);
|
||||
}
|
||||
// Return the gradient;
|
||||
if (grad_position.empty()) {
|
||||
MS_LOG(EXCEPTION) << "grad_position in F.grad is empty!";
|
||||
if (grad_attr.get_by_position && grad_position.empty()) {
|
||||
MS_LOG(EXCEPTION) << "grad_position in F.grad is empty when grad by position!";
|
||||
}
|
||||
SetOutput(weights, grad_position, grad_inputs, grad_weights);
|
||||
SetOutput(weights, grad_position, grad_attr.grad_all_inputs, grad_attr.grad_weights, grad_attr.get_by_position);
|
||||
// Replace Parameter of primal funcgraph with parameter of tape_;
|
||||
ReplacePrimalParameter(weights, has_sens_arg);
|
||||
ReplacePrimalParameter(weights, grad_attr.has_sens);
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
auto save_graphs_flg = MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
if (save_graphs_flg) {
|
||||
|
@ -1071,7 +1073,7 @@ AbstractBasePtr KPynativeCellImpl::GetGradInputsSpec(const std::vector<size_t> &
|
|||
}
|
||||
|
||||
void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
|
||||
bool grad_inputs, bool grad_weights) {
|
||||
bool grad_inputs, bool grad_weights, const bool get_by_position) {
|
||||
AnfNodePtrList grad_inputs_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
AbstractBasePtr grad_inputs_spec = GetGradInputsSpec(grad_position, grad_inputs, &grad_inputs_list);
|
||||
AnfNodePtrList grad_weights_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
|
@ -1098,13 +1100,20 @@ void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vect
|
|||
grad_weights_spec = std::make_shared<abstract::AbstractTuple>(grad_weights_abs_list);
|
||||
}
|
||||
|
||||
if (get_by_position) {
|
||||
auto tape_output = GetTapeOutputForPosition(grad_inputs_list, grad_inputs_spec, grad_weights_list,
|
||||
grad_weights_spec, grad_position, grad_weights);
|
||||
tape_->set_output(tape_output);
|
||||
return;
|
||||
}
|
||||
|
||||
AnfNodePtr tape_output;
|
||||
if (grad_inputs && grad_weights) {
|
||||
tape_output = tape_->NewCNode(
|
||||
{NewValueNode(prim::kPrimMakeTuple), tape_->NewCNode(grad_inputs_list), tape_->NewCNode(grad_weights_list)});
|
||||
tape_output->set_abstract(
|
||||
std::make_shared<abstract::AbstractTuple>(abstract::AbstractBasePtrList{grad_inputs_spec, grad_weights_spec}));
|
||||
} else if (grad_inputs || (grad_position.size() > 1)) {
|
||||
} else if (grad_inputs) {
|
||||
tape_output = tape_->NewCNode(grad_inputs_list);
|
||||
tape_output->set_abstract(grad_inputs_spec);
|
||||
} else if (grad_weights) {
|
||||
|
@ -1114,15 +1123,11 @@ void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vect
|
|||
tape_output = tape_->NewCNode(grad_inputs_list);
|
||||
tape_output->set_abstract(grad_inputs_spec);
|
||||
} else {
|
||||
size_t index = grad_position[0];
|
||||
if (index >= cell_inputs_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Position index " << index << " is exceed input size!";
|
||||
}
|
||||
auto input_adjoint_iter = anfnode_to_adjoin_.find(cell_inputs_[index]);
|
||||
auto input_adjoint_iter = anfnode_to_adjoin_.find(cell_inputs_[0]);
|
||||
if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
|
||||
// If input is not used in the network, just return zeros_like() as dout;
|
||||
MS_LOG(WARNING) << "Input is not used in network, input: " << cell_inputs_[index]->ToString();
|
||||
tape_output = BuildZerosLikeNode(tape_, cell_inputs_[index]);
|
||||
MS_LOG(WARNING) << "Input is not used in network, input: " << cell_inputs_[0]->ToString();
|
||||
tape_output = BuildZerosLikeNode(tape_, cell_inputs_[0]);
|
||||
} else {
|
||||
tape_output = input_adjoint_iter->second->RealDout();
|
||||
}
|
||||
|
@ -1200,5 +1205,53 @@ void ClearKPynativeCellStaticRes() {
|
|||
zeros_like_funcgraph_cache.clear();
|
||||
ones_like_funcgraph_cache.clear();
|
||||
}
|
||||
|
||||
// Support both grad_position and weight options in grad.
|
||||
AnfNodePtr KPynativeCellImpl::GetTapeOutputForPosition(const AnfNodePtrList &grad_inputs_list,
|
||||
const AbstractBasePtr &grad_inputs_spec,
|
||||
const AnfNodePtrList &grad_weights_list,
|
||||
const AbstractBasePtr &grad_weights_spec,
|
||||
const std::vector<size_t> &grad_position,
|
||||
const bool grad_weights) {
|
||||
MS_EXCEPTION_IF_NULL(tape_);
|
||||
|
||||
AnfNodePtr tape_output;
|
||||
if (grad_position.size() == 1) {
|
||||
AnfNodePtr first_tape_output;
|
||||
size_t index = grad_position[0];
|
||||
if (index >= cell_inputs_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Position index " << index << " is exceed input size!";
|
||||
}
|
||||
auto input_adjoint_iter = anfnode_to_adjoin_.find(cell_inputs_[index]);
|
||||
if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
|
||||
MS_LOG(WARNING) << "Input is not used in network, input: " << cell_inputs_[index]->ToString();
|
||||
first_tape_output = BuildZerosLikeNode(tape_, cell_inputs_[index]);
|
||||
} else {
|
||||
first_tape_output = input_adjoint_iter->second->RealDout();
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(first_tape_output);
|
||||
if (grad_weights) {
|
||||
tape_output =
|
||||
tape_->NewCNode({NewValueNode(prim::kPrimMakeTuple), first_tape_output, tape_->NewCNode(grad_weights_list)});
|
||||
MS_EXCEPTION_IF_NULL(tape_output);
|
||||
tape_output->set_abstract(std::make_shared<abstract::AbstractTuple>(
|
||||
abstract::AbstractBasePtrList{first_tape_output->abstract(), grad_weights_spec}));
|
||||
return tape_output;
|
||||
}
|
||||
return first_tape_output;
|
||||
}
|
||||
|
||||
if (grad_weights) {
|
||||
tape_output = tape_->NewCNode(
|
||||
{NewValueNode(prim::kPrimMakeTuple), tape_->NewCNode(grad_inputs_list), tape_->NewCNode(grad_weights_list)});
|
||||
tape_output->set_abstract(
|
||||
std::make_shared<abstract::AbstractTuple>(abstract::AbstractBasePtrList{grad_inputs_spec, grad_weights_spec}));
|
||||
} else {
|
||||
tape_output = tape_->NewCNode(grad_inputs_list);
|
||||
tape_output->set_abstract(grad_inputs_spec);
|
||||
}
|
||||
return tape_output;
|
||||
}
|
||||
} // namespace ad
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -40,6 +40,16 @@ class KPynativeCell {
|
|||
|
||||
using KPynativeCellPtr = std::shared_ptr<KPynativeCell>;
|
||||
|
||||
struct GradAttr {
|
||||
bool grad_all_inputs;
|
||||
bool grad_weights;
|
||||
bool has_sens;
|
||||
bool get_by_position;
|
||||
|
||||
GradAttr(bool get_all, bool get_by_list, bool sens_param, bool get_by_position)
|
||||
: grad_all_inputs(get_all), grad_weights(get_by_list), has_sens(sens_param), get_by_position(get_by_position) {}
|
||||
};
|
||||
|
||||
// bprop_fg: user defined back propagate funcgraph or back propagate funcgraph of primitive, it will be passed after
|
||||
// just parsed. will have prototype:
|
||||
// (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
|
||||
|
@ -71,8 +81,8 @@ KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
|
|||
// else:
|
||||
// each cnode in primal funcgraph is replaced by value node
|
||||
FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights,
|
||||
const std::vector<size_t> &grad_position, bool grad_inputs, bool grad_weights,
|
||||
bool has_sens_arg = false, bool build_formal_param = false);
|
||||
const std::vector<size_t> &grad_position, const GradAttr &grad_attr,
|
||||
bool build_formal_param = false);
|
||||
|
||||
// Grad for each operation.
|
||||
// c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim.
|
||||
|
|
|
@ -565,7 +565,7 @@ void GradExecutor::GradNetInner(const py::object *ret, const prim::GradOperation
|
|||
|
||||
// Get params(weights) require derivative
|
||||
auto w_args = GetWeightsArgs(weights, df_builder);
|
||||
auto p_args = GetGradPositionArgs(grad_position);
|
||||
auto p_args = GetGradPositionArgs(grad_position, grad->get_by_position_);
|
||||
if (w_args.empty() && !df_builder->parameters().empty()) {
|
||||
MS_LOG(DEBUG) << "Add weights params to w_args";
|
||||
(void)w_args.insert(w_args.end(), df_builder->parameters().cbegin(), df_builder->parameters().cend());
|
||||
|
@ -650,15 +650,19 @@ std::vector<AnfNodePtr> GradExecutor::GetWeightsArgs(const py::object &weights,
|
|||
return w_args;
|
||||
}
|
||||
|
||||
std::vector<size_t> GradExecutor::GetGradPositionArgs(const py::object &grad_position) const {
|
||||
std::vector<size_t> GradExecutor::GetGradPositionArgs(const py::object &grad_position,
|
||||
const bool get_by_position) const {
|
||||
std::vector<size_t> pos_args;
|
||||
if (!get_by_position) {
|
||||
return pos_args;
|
||||
}
|
||||
if (py::isinstance<py::tuple>(grad_position)) {
|
||||
const auto &tuple = grad_position.cast<py::tuple>();
|
||||
(void)std::transform(tuple.begin(), tuple.end(), std::back_inserter(pos_args),
|
||||
[](const py::handle &elem) { return py::cast<int64_t>(elem); });
|
||||
return pos_args;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Grad position only support tuple.";
|
||||
MS_LOG(EXCEPTION) << "Grad position only support tuple when grad_by_position is set True.";
|
||||
}
|
||||
|
||||
void GradExecutor::ShallowCopySensValue(const py::tuple &input_args, bool has_sens, VectorRef *run_args) const {
|
||||
|
@ -778,8 +782,9 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, con
|
|||
auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
|
||||
MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
|
||||
MS_EXCEPTION_IF_NULL(grad);
|
||||
FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(k_pynative_cell_ptr, weights, grad_position, grad->get_all_,
|
||||
grad->get_by_list_, grad->sens_param_, build_formal_param);
|
||||
ad::GradAttr grad_attr(grad->get_all_, grad->get_by_list_, grad->sens_param_, grad->get_by_position_);
|
||||
FuncGraphPtr bprop_graph =
|
||||
ad::GradPynativeCellEnd(k_pynative_cell_ptr, weights, grad_position, grad_attr, build_formal_param);
|
||||
MS_EXCEPTION_IF_NULL(bprop_graph);
|
||||
|
||||
MS_LOG(DEBUG) << "Top graph input params size " << arg_size;
|
||||
|
|
|
@ -170,7 +170,7 @@ class GradExecutor {
|
|||
const abstract::AbstractBasePtr &input_abs,
|
||||
const abstract::AbstractBasePtr ¶m_tensor_abs, const std::string &input_shape);
|
||||
void UpdateParamAbsByArgs(const py::list &args, const FuncGraphPtr &bprop_graph);
|
||||
std::vector<size_t> GetGradPositionArgs(const py::object &grad_position) const;
|
||||
std::vector<size_t> GetGradPositionArgs(const py::object &grad_position, const bool get_by_position) const;
|
||||
void ShallowCopySensValue(const py::tuple &input_args, bool has_sens, VectorRef *run_args) const;
|
||||
// Manage resource for construct forward graph.
|
||||
AnfNodePtr GetObjNode(const ValuePtr &v, const std::string &obj_id) const;
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
|
||||
"""grad impl."""
|
||||
from __future__ import absolute_import
|
||||
from mindspore.ops._grad.grad_base import get_bprop_fn, get_taylor_fprop_fn
|
||||
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, grad_clip_ops, \
|
||||
grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse, grad_inner_ops, taylor_rule
|
||||
|
|
|
@ -347,7 +347,7 @@ class GradOperation(GradOperation_):
|
|||
self.get_all = get_all
|
||||
self.get_by_list = get_by_list
|
||||
self.sens_param = sens_param
|
||||
GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False)
|
||||
GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False, False, False)
|
||||
self.grad_fn = None
|
||||
self.fn = None
|
||||
self.weights_id = None
|
||||
|
@ -453,7 +453,7 @@ class _Grad(GradOperation_):
|
|||
A higher-order function which is used to generate the gradient function by position for the input function.
|
||||
"""
|
||||
|
||||
def __init__(self, get_by_list=False, sens_param=False, get_by_position=False):
|
||||
def __init__(self, get_by_list=False, sens_param=False, get_by_position=False, has_aux=False, get_value=False):
|
||||
"""Initialize _Grad."""
|
||||
if not isinstance(get_by_position, bool):
|
||||
raise TypeError(f"For '_Grad', the 'get_by_position' should be bool, "
|
||||
|
@ -464,10 +464,18 @@ class _Grad(GradOperation_):
|
|||
if not isinstance(sens_param, bool):
|
||||
raise TypeError(f"For '_Grad', the 'sens_param' should be bool, "
|
||||
f"but got {type(sens_param).__name__}")
|
||||
if not isinstance(has_aux, bool):
|
||||
raise TypeError(f"For '_Grad', the 'has_aux' should be bool, "
|
||||
f"but got {type(has_aux).__name__}")
|
||||
if not isinstance(get_value, bool):
|
||||
raise TypeError(f"For '_Grad', the 'get_value' should be bool, "
|
||||
f"but got {type(get_value).__name__}")
|
||||
self.get_by_position = get_by_position
|
||||
self.get_by_list = get_by_list
|
||||
self.sens_param = sens_param
|
||||
GradOperation_.__init__(self, 'grad', False, get_by_list, sens_param, get_by_position)
|
||||
self.has_aux = has_aux
|
||||
self.get_value = get_value
|
||||
GradOperation_.__init__(self, 'grad', False, get_by_list, sens_param, get_by_position, has_aux, get_value)
|
||||
self.grad_fn = None
|
||||
self.fn = None
|
||||
self.pynative_ = False
|
||||
|
@ -480,9 +488,19 @@ class _Grad(GradOperation_):
|
|||
if self.grad_fn is not None and self.fn == fn and self.grad_position == grad_position and \
|
||||
self.weights_id == weights_id:
|
||||
return self.grad_fn
|
||||
self.fn = fn
|
||||
self.grad_position = grad_position
|
||||
grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position)
|
||||
|
||||
def aux_fn(*args):
|
||||
outputs = fn(*args)
|
||||
if not isinstance(outputs, tuple) or len(outputs) < 2:
|
||||
raise ValueError("When has_aux is True, origin fn requires more than one outputs, but got "
|
||||
+ len(outputs) + " outputs.")
|
||||
res = (outputs[0],)
|
||||
stop_gradient = Primitive("stop_gradient")
|
||||
for item in outputs[1:]:
|
||||
res += (stop_gradient(item),)
|
||||
return res
|
||||
|
||||
grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position, self.has_aux, self.get_value)
|
||||
# If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
|
||||
# If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
|
||||
# In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
|
||||
|
@ -508,23 +526,34 @@ class _Grad(GradOperation_):
|
|||
|
||||
@_wrap_func
|
||||
def after_grad(*args, **kwargs):
|
||||
self._pynative_forward_run(fn, grad_, args, kwargs)
|
||||
forward_flag = self.get_value or self.has_aux
|
||||
res = self._pynative_forward_run(fn, grad_, forward_flag, args, kwargs)
|
||||
_pynative_executor.grad(fn, grad_, weights, grad_position, *args, **kwargs)
|
||||
out = _pynative_executor(fn, grad_.sens_param, *args, **kwargs)
|
||||
_pynative_executor.clear_grad(fn, *args, **kwargs)
|
||||
if self.get_value:
|
||||
return res, out
|
||||
if self.has_aux:
|
||||
return out, res[1:]
|
||||
return out
|
||||
else:
|
||||
grad_.pynative_ = True
|
||||
# after_grad of this branch can't use @ms_function, just directly call grad_
|
||||
if self.get_by_position:
|
||||
def after_grad(*args, **kwargs):
|
||||
if self.has_aux:
|
||||
return grad_(aux_fn, weights, grad_position)(*args, **kwargs)
|
||||
return grad_(fn, weights, grad_position)(*args, **kwargs)
|
||||
else:
|
||||
if self.get_by_list:
|
||||
def after_grad(*args, **kwargs):
|
||||
if self.has_aux:
|
||||
return grad_(aux_fn, weights)(*args, **kwargs)
|
||||
return grad_(fn, weights)(*args, **kwargs)
|
||||
else:
|
||||
def after_grad(*args, **kwargs):
|
||||
if self.has_aux:
|
||||
return grad_(aux_fn)(*args, **kwargs)
|
||||
return grad_(fn)(*args, **kwargs)
|
||||
|
||||
self.grad_fn = after_grad
|
||||
|
@ -534,9 +563,10 @@ class _Grad(GradOperation_):
|
|||
self.grad_hash_id = (grad_position, weights_id)
|
||||
return self.grad_fn
|
||||
|
||||
def _pynative_forward_run(self, fn, grad, args, kwargs):
|
||||
def _pynative_forward_run(self, fn, grad, forward_flag, args, kwargs):
|
||||
""" Pynative forward runs to build grad graph. """
|
||||
new_kwargs = kwargs
|
||||
outputs = ()
|
||||
if self.sens_param:
|
||||
if 'sens' in kwargs.keys():
|
||||
new_kwargs = kwargs.copy()
|
||||
|
@ -549,12 +579,17 @@ class _Grad(GradOperation_):
|
|||
_pynative_executor.new_graph(fn, *args, **new_kwargs)
|
||||
outputs = fn(*args, **new_kwargs)
|
||||
_pynative_executor.end_graph(fn, outputs, *args, **new_kwargs)
|
||||
return outputs
|
||||
else:
|
||||
# Check if fn has run already.
|
||||
if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
|
||||
fn.set_grad()
|
||||
fn(*args, **new_kwargs)
|
||||
outputs = fn(*args, **new_kwargs)
|
||||
fn.set_grad(False)
|
||||
return outputs
|
||||
if forward_flag and not outputs:
|
||||
outputs = fn(*args, **new_kwargs)
|
||||
return outputs
|
||||
|
||||
|
||||
class _Vmap(VmapOperation_):
|
||||
|
|
|
@ -345,6 +345,7 @@ from .random_func import (
|
|||
from .grad import (
|
||||
grad_func,
|
||||
grad,
|
||||
value_and_grad,
|
||||
jet,
|
||||
derivative,
|
||||
jvp,
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
from .grad_func import (
|
||||
grad,
|
||||
value_and_grad,
|
||||
jet,
|
||||
derivative,
|
||||
jvp,
|
||||
|
|
|
@ -57,53 +57,248 @@ def _convert_grad_position_type(grad_position):
|
|||
return grad_position
|
||||
|
||||
|
||||
grad_by_position = _Grad(get_by_list=False, sens_param=False, get_by_position=True)
|
||||
grad_by_position_with_sens = _Grad(get_by_list=False, sens_param=True, get_by_position=True)
|
||||
@constexpr
|
||||
def _get_grad_op(get_by_list, get_by_position, has_aux, get_value=False):
|
||||
return _Grad(get_by_list=get_by_list, get_by_position=get_by_position, has_aux=has_aux, get_value=get_value)
|
||||
|
||||
|
||||
def grad(fn, grad_position=0, sens_param=False):
|
||||
r"""
|
||||
def grad(fn, grad_position=0, weights=None, has_aux=False):
|
||||
"""
|
||||
A wrapper function to generate the gradient function for the input function.
|
||||
|
||||
As for gradient, three typical cases are included:
|
||||
|
||||
1. gradient with respect to inputs. In this case, `grad_position` is not None while `weights` is None.
|
||||
2. gradient with respect to weights. In this case, `grad_position` is None while `weights` is not None.
|
||||
3. gradient with respect to inputs and weights. In this case, `grad_position` and `weights` are not None.
|
||||
|
||||
Args:
|
||||
fn (Union(Cell, function)): Function to do GradOperation.
|
||||
grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
|
||||
If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
|
||||
sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
|
||||
If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False.
|
||||
grad_position (Union(NoneType, int, tuple[int])): Index to specify which inputs to be differentiated.
|
||||
If int, get the gradient with respect to a single input.
|
||||
If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
|
||||
If None, gradients of the inputs will not be calculated, and in this case, `weights` is required.
|
||||
Default: 0.
|
||||
weights (Union(ParameterTuple, Parameter, list(Parameter))): The parameters of the training network that need to
|
||||
calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
|
||||
Default: None.
|
||||
has_aux (bool): If True, only the first output of `fn` contributes to the gradient of `fn`, while the other outputs
|
||||
will be returned directly. In this case, `fn` must return more than one output.
|
||||
Default: False.
|
||||
|
||||
Returns:
|
||||
Function, returns the gradient function for the input function or cell.
|
||||
Function, the gradient function to calculate gradient for the input function or cell.
|
||||
For example, for `out1, out2 = fn(*args)`, when `has_aux` is set True the gradient function returns outputs
|
||||
of the form `(gradient, out2)`, where `out2` does not contribute to the differentiation; otherwise it returns `gradient` directly.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `grad_position` and `weights` are None.
|
||||
TypeError: If the type of an argument does not meet the requirements.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> import mindspore
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.ops.functional import grad
|
||||
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>> from mindspore.ops import grad
|
||||
>>>
|
||||
>>> # Cell object to be differentiated
|
||||
>>> class Net(nn.Cell):
|
||||
... def construct(self, x, y, z):
|
||||
... return x*y*z
|
||||
>>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
>>> y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
|
||||
>>> z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
|
||||
... return x * y * z
|
||||
>>> x = Tensor([1, 2], mindspore.float32)
|
||||
>>> y = Tensor([-2, 3], mindspore.float32)
|
||||
>>> z = Tensor([0, 3], mindspore.float32)
|
||||
>>> net = Net()
|
||||
>>> output = grad(net, grad_position=(1, 2))(x, y, z)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[2, 2], dtype=Float32, value=
|
||||
[[ 0.00000000e+00, 6.00000000e+00],
|
||||
[ 1.50000000e+01, -4.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
|
||||
[[-2.00000000e+00, 6.00000000e+00],
|
||||
[-3.00000000e+00, 8.00000000e+00]]))
|
||||
(Tensor(shape=[2], dtype=Float32, value=[ 0.00000000e+00, 6.00000000e+00]),
|
||||
Tensor(shape=[2], dtype=Float32, value=[-2.00000000e+00, 6.00000000e+00]))
|
||||
>>>
|
||||
>>> # Function object to be differentiated
|
||||
>>> def fn(x, y, z):
|
||||
... res = x * ops.exp(y) * ops.pow(z, 2)
|
||||
... return res, z
|
||||
>>> x = Tensor([3, 3], mindspore.float32)
|
||||
>>> y = Tensor([0, 0], mindspore.float32)
|
||||
>>> z = Tensor([5, 5], mindspore.float32)
|
||||
>>> gradient, aux = grad(fn, (1, 2), None, True)(x, y, z)
|
||||
>>> print(gradient)
|
||||
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01, 7.50000000e+01]),
|
||||
Tensor(shape=[2], dtype=Float32, value= [ 3.00000000e+01, 3.00000000e+01]))
|
||||
>>> print(aux)
|
||||
(Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00, 5.00000000e+00]),)
|
||||
>>>
|
||||
>>> # For a given network to be differentiated with respect to both inputs and weights, there are 3 cases.
|
||||
>>> net = nn.Dense(10, 1)
|
||||
>>> loss_fn = nn.MSELoss()
|
||||
>>> def forward(inputs, labels):
|
||||
... logits = net(inputs)
|
||||
... loss = loss_fn(logits, labels)
|
||||
... return loss, logits
|
||||
>>> inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
|
||||
>>> labels = Tensor(np.random.randn(16, 1).astype(np.float32))
|
||||
>>> weights = net.trainable_params()
|
||||
>>> # Case 1: gradient with respect to inputs.
|
||||
>>> # Aux value does not contribute to the gradient.
|
||||
>>> grad_fn = grad(forward, grad_position=0, weights=None, has_aux=True)
|
||||
>>> inputs_gradient, (aux_logits,) = grad_fn(inputs, labels)
|
||||
>>> print(len(inputs_gradient))
|
||||
2
|
||||
>>> print(aux_logits.shape)
|
||||
(16, 1)
|
||||
>>>
|
||||
>>> # Case 2: gradient with respect to weights.
|
||||
>>> grad_fn = grad(forward, grad_position=None, weights=weights, has_aux=True)
|
||||
>>> params_gradient, (aux_logits,) = grad_fn(inputs, labels)
|
||||
>>> print(len(weights), len(params_gradient))
|
||||
2 2
|
||||
>>> print(aux_logits.shape)
|
||||
(16, 1)
|
||||
>>>
|
||||
>>> # Case 3: gradient with respect to inputs and weights.
|
||||
>>> grad_fn = grad(forward, grad_position=0, weights=weights, has_aux=False)
|
||||
>>> inputs_gradient, params_gradient = grad_fn(inputs, labels)
|
||||
>>> print(len(weights), len(params_gradient))
|
||||
2 2
|
||||
"""
|
||||
if grad_position is None and weights is None:
|
||||
raise ValueError("`grad_position` and `weight` can not be None at the same time.")
|
||||
|
||||
if grad_position is None:
|
||||
return _get_grad_op(True, False, has_aux)(fn, weights)
|
||||
|
||||
grad_position = _convert_grad_position_type(grad_position)
|
||||
if sens_param:
|
||||
return grad_by_position_with_sens(fn, None, grad_position)
|
||||
return grad_by_position(fn, None, grad_position)
|
||||
if weights is None:
|
||||
return _get_grad_op(False, True, has_aux)(fn, None, grad_position)
|
||||
return _get_grad_op(True, True, has_aux)(fn, weights, grad_position)
|
||||
|
||||
|
||||
def value_and_grad(fn, grad_position=0, weights=None, has_aux=False):
|
||||
"""
|
||||
A wrapper function to generate the function to calculate forward output and gradient for the input function.
|
||||
|
||||
As for gradient, three typical cases are included:
|
||||
|
||||
1. gradient with respect to inputs. In this case, `grad_position` is not None while `weights` is None.
|
||||
2. gradient with respect to weights. In this case, `grad_position` is None while `weights` is not None.
|
||||
3. gradient with respect to inputs and weights. In this case, `grad_position` and `weights` are not None.
|
||||
|
||||
Args:
|
||||
fn (Union(Cell, function)): Function to do GradOperation.
|
||||
grad_position (Union(NoneType, int, tuple[int])): Index to specify which inputs to be differentiated.
|
||||
If int, get the gradient with respect to a single input.
|
||||
If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
|
||||
If None, gradients of the inputs will not be calculated, and in this case, `weights` is required.
|
||||
Default: 0.
|
||||
weights (Union(ParameterTuple, Parameter, list(Parameter))): The parameters of the training network that need to
|
||||
calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
|
||||
Default: None.
|
||||
has_aux (bool): If True, only the first output of `fn` contributes to the gradient of `fn`, while the other outputs
|
||||
will be returned directly. In this case, `fn` must return more than one output.
|
||||
Default: False.
|
||||
|
||||
Returns:
|
||||
Function, returns the gradient function to calculate forward output and gradient for the input function or cell.
|
||||
For example, for `out1, out2 = fn(*args)`, the gradient function will return outputs of the form
|
||||
`((out1, out2), gradient)`. When `has_aux` is set True, only `out1` contributes to the differentiation.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `grad_position` and `weights` are None.
|
||||
TypeError: If the type of an argument does not meet the requirements.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore
|
||||
>>> from mindspore import Tensor, ops, nn
|
||||
>>> from mindspore.ops import value_and_grad
|
||||
>>>
|
||||
>>> # Cell object to be differentiated
|
||||
>>> class Net(nn.Cell):
|
||||
... def construct(self, x, y, z):
|
||||
... return x * y * z
|
||||
>>> x = Tensor([1, 2], mindspore.float32)
|
||||
>>> y = Tensor([-2, 3], mindspore.float32)
|
||||
>>> z = Tensor([0, 3], mindspore.float32)
|
||||
>>> net = Net()
|
||||
>>> grad_fn = value_and_grad(net, grad_position=1)
|
||||
>>> output, inputs_gradient = grad_fn(x, y, z)
|
||||
>>> print(output)
|
||||
[ -0. 18.]
|
||||
>>> print(inputs_gradient)
|
||||
[ 0. 6.]
|
||||
>>>
|
||||
>>> # Function object to be differentiated
|
||||
>>> def fn(x, y, z):
|
||||
... res = x * ops.exp(y) * ops.pow(z, 2)
|
||||
... return res, z
|
||||
>>> x = Tensor(np.array([3, 3]).astype(np.float32))
|
||||
>>> y = Tensor(np.array([0, 0]).astype(np.float32))
|
||||
>>> z = Tensor(np.array([5, 5]).astype(np.float32))
|
||||
>>> output, inputs_gradient = value_and_grad(fn, grad_position=(1, 2), weights=None, has_aux=True)(x, y, z)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01, 7.50000000e+01]),
|
||||
Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00, 5.00000000e+00]))
|
||||
>>> print(inputs_gradient)
|
||||
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01, 7.50000000e+01]),
|
||||
Tensor(shape=[2], dtype=Float32, value= [ 3.00000000e+01, 3.00000000e+01]))
|
||||
>>>
|
||||
>>> # For a given network to be differentiated with respect to both inputs and weights, there are 3 cases.
|
||||
>>> net = nn.Dense(10, 1)
|
||||
>>> loss_fn = nn.MSELoss()
|
||||
>>> def forward(inputs, labels):
|
||||
... logits = net(inputs)
|
||||
... loss = loss_fn(logits, labels)
|
||||
... return loss, logits
|
||||
>>> inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
|
||||
>>> labels = Tensor(np.random.randn(16, 1).astype(np.float32))
|
||||
>>> weights = net.trainable_params()
|
||||
>>>
|
||||
>>> # Case 1: gradient with respect to inputs.
|
||||
>>> # When has_aux is set True, only loss contributes to the gradient.
|
||||
>>> grad_fn = value_and_grad(forward, grad_position=0, weights=None, has_aux=True)
|
||||
>>> (loss, logits), inputs_gradient = grad_fn(inputs, labels)
|
||||
>>> print(logits.shape)
|
||||
(16, 1)
|
||||
>>> print(inputs.shape, inputs_gradient.shape)
|
||||
(16, 10) (16, 10)
|
||||
>>>
|
||||
>>> # Case 2: gradient with respect to weights.
|
||||
>>> # When has_aux is set True, only loss contributes to the gradient.
|
||||
>>> grad_fn = value_and_grad(forward, grad_position=None, weights=weights, has_aux=True)
|
||||
>>> (loss, logits), params_gradient = grad_fn(inputs, labels)
|
||||
>>> print(logits.shape)
|
||||
(16, 1)
|
||||
>>> print(len(weights), len(params_gradient))
|
||||
2 2
|
||||
>>>
|
||||
>>> # Case 3: gradient with respect to inputs and weights.
|
||||
>>> # When has_aux is set False, both loss and logits contribute to the gradient.
|
||||
>>> grad_fn = value_and_grad(forward, grad_position=0, weights=weights, has_aux=False)
|
||||
>>> (loss, logits), (inputs_gradient, params_gradient) = grad_fn(inputs, labels)
|
||||
>>> print(logits.shape)
|
||||
(16, 1)
|
||||
>>> print(inputs.shape, inputs_gradient.shape)
|
||||
(16, 10) (16, 10)
|
||||
>>> print(len(weights), len(params_gradient))
|
||||
2 2
|
||||
"""
|
||||
if grad_position is None and weights is None:
|
||||
raise ValueError("`grad_position` and `weight` can not be None at the same time.")
|
||||
|
||||
if grad_position is None:
|
||||
return _get_grad_op(True, False, has_aux, True)(fn, weights)
|
||||
|
||||
grad_position = _convert_grad_position_type(grad_position)
|
||||
if weights is None:
|
||||
return _get_grad_op(False, True, has_aux, True)(fn, None, grad_position)
|
||||
return _get_grad_op(True, True, has_aux, True)(fn, weights, grad_position)
|
||||
|
||||
|
||||
def _trans_jet_inputs(primals_item, series_item):
|
||||
|
@ -531,6 +726,7 @@ def vjp(fn, inputs, v):
|
|||
|
||||
__all__ = [
|
||||
'grad',
|
||||
'value_and_grad',
|
||||
'jet',
|
||||
'derivative',
|
||||
'jvp',
|
||||
|
|
|
@ -19,7 +19,7 @@ import mindspore.nn as nn
|
|||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore import ms_function
|
||||
from mindspore.ops.functional import grad
|
||||
from mindspore.ops.functional import grad, value_and_grad
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import Parameter, ParameterTuple
|
||||
|
@ -29,19 +29,22 @@ context.set_context(mode=context.GRAPH_MODE)
|
|||
|
||||
class SingleInputSingleOutputNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3
|
||||
return x ** 3
|
||||
|
||||
|
||||
class SingleInputMultipleOutputsNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x**3, 2*x
|
||||
return x ** 3, 2 * x
|
||||
|
||||
|
||||
class MultipleInputsSingleOutputNet(nn.Cell):
|
||||
def construct(self, x, y, z):
|
||||
return x*y*z
|
||||
return x * y * z
|
||||
|
||||
|
||||
class MultipleInputsMultipleOutputsNet(nn.Cell):
|
||||
def construct(self, x, y, z):
|
||||
return x**2 + y**2 + z**2, x*y*z
|
||||
return x ** 2 + y ** 2 + z ** 2, x * y * z
|
||||
|
||||
|
||||
class ParamNet(nn.Cell):
|
||||
|
@ -56,15 +59,15 @@ class ParamNet(nn.Cell):
|
|||
|
||||
|
||||
def function(x, y, z):
|
||||
return x**2 + y**2 + z**2, x*y*z
|
||||
return x ** 2 + y ** 2 + z ** 2, x * y * z
|
||||
|
||||
|
||||
def iteration_grad_function(x, y, z):
|
||||
return x**2*y*z
|
||||
return x ** 2 * y * z
|
||||
|
||||
|
||||
@ms_function
|
||||
def grad_warp_with_msfunction(x, y, z):
|
||||
def grad_wrap_with_msfunction(x, y, z):
|
||||
output = grad(function)(x, y, z)
|
||||
return output
|
||||
|
||||
|
@ -145,28 +148,6 @@ def test_grad_multiple_inputs_multiple_outputs_cell_graph():
|
|||
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_function_with_sens_graph():
|
||||
"""
|
||||
Features: Function grad.
|
||||
Description: Test F.grad with function setting sens_param in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
|
||||
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
|
||||
v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
|
||||
expect_grad1 = Tensor(np.array([[4, 36], [26, 0]]).astype(np.float32))
|
||||
expect_grad2 = Tensor(np.array([[2, 36], [14, 6]]).astype(np.float32))
|
||||
real_grad = grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))
|
||||
assert isinstance(real_grad, tuple)
|
||||
assert len(real_grad) == 2
|
||||
assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
|
||||
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -191,17 +172,17 @@ def test_grad_iteration_function_graph():
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_warp_with_msfunction_graph():
|
||||
def test_grad_wrap_with_msfunction_graph():
|
||||
"""
|
||||
Features: Function grad.
|
||||
Description: Test F.grad warpped with ms_function in graph mode.
|
||||
Description: Test F.grad wrapped with ms_function in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
|
||||
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
|
||||
real_grad = grad_warp_with_msfunction(x, y, z)
|
||||
real_grad = grad_wrap_with_msfunction(x, y, z)
|
||||
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
|
||||
|
@ -246,6 +227,210 @@ def test_grad_with_weights_twice_graph():
|
|||
assert np.allclose(out2[0].asnumpy(), expect2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_with_weights_has_aux_graph():
|
||||
"""
|
||||
Features: Function grad.
|
||||
Description: Test F.grad with different weights and has_aux in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class ParamNetAux(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ParamNetAux, self).__init__()
|
||||
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
|
||||
self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
|
||||
|
||||
def construct(self, x):
|
||||
res = x * self.w * self.z
|
||||
return res, x, self.w
|
||||
|
||||
x = Tensor(np.array([1, 2]).astype(np.float32))
|
||||
net = ParamNetAux()
|
||||
weights = ParameterTuple(net.trainable_params())
|
||||
expect_grad_input = np.array([6, 6]).astype(np.float32)
|
||||
expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
|
||||
expect_grad_weight2 = np.array([2, 4]).astype(np.float32)
|
||||
expect_aux1 = np.array([1, 2]).astype(np.float32)
|
||||
expect_aux2 = np.array([2, 2]).astype(np.float32)
|
||||
res, aux = grad(net, 0, weights, True)(x)
|
||||
assert np.allclose(res[0].asnumpy(), expect_grad_input)
|
||||
assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
|
||||
assert np.allclose(res[1][1].asnumpy(), expect_grad_weight2)
|
||||
assert np.allclose(aux[0].asnumpy(), expect_aux1)
|
||||
assert np.allclose(aux[1].asnumpy(), expect_aux2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_ms_function_grad_with_weights_has_aux_graph():
|
||||
"""
|
||||
Features: Function grad.
|
||||
Description: Test F.grad with different weights and has_aux in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class ParamMultipleInputNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ParamMultipleInputNet, self).__init__()
|
||||
self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
|
||||
|
||||
def construct(self, x, y):
|
||||
outputs = x * y * self.w
|
||||
return outputs, x, self.w
|
||||
|
||||
net = ParamMultipleInputNet()
|
||||
weights = net.trainable_params()
|
||||
|
||||
@ms_function
|
||||
def user_fn(x, y):
|
||||
res, aux = grad(net, 0, weights, True)(x, y)
|
||||
return res, aux
|
||||
|
||||
x = Tensor(np.array([1, 2]).astype(np.float32))
|
||||
y = Tensor(np.array([3, 3]).astype(np.float32))
|
||||
res, aux = user_fn(x, y)
|
||||
expect_grad_input = np.array([6, 6]).astype(np.float32)
|
||||
expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
|
||||
expect_aux1 = np.array([1, 2]).astype(np.float32)
|
||||
expect_aux2 = np.array([2, 2]).astype(np.float32)
|
||||
assert np.allclose(res[0].asnumpy(), expect_grad_input)
|
||||
assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
|
||||
assert np.allclose(aux[0].asnumpy(), expect_aux1)
|
||||
assert np.allclose(aux[1].asnumpy(), expect_aux2)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_construct_grad_with_weights_has_aux_graph():
    """
    Features: Function grad.
    Description: Test F.grad with different weights and has_aux in graph mode.
    Expectation: No exception.
    """

    class ParamMultipleInputNet(nn.Cell):
        def __init__(self):
            super(ParamMultipleInputNet, self).__init__()
            self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")

        def construct(self, x, y):
            outputs = x * y * self.w
            return outputs, x, self.w

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = net.trainable_params()

        def construct(self, x, y):
            res, aux = grad(self.net, 0, self.weights, True)(x, y)
            return res, aux

    x = Tensor(np.array([1, 2]).astype(np.float32))
    y = Tensor(np.array([3, 3]).astype(np.float32))
    inner_net = ParamMultipleInputNet()
    grad_net = GradNet(inner_net)
    res, aux = grad_net(x, y)
    expect_grad_input = np.array([6, 6]).astype(np.float32)
    expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
    expect_aux1 = np.array([1, 2]).astype(np.float32)
    expect_aux2 = np.array([2, 2]).astype(np.float32)
    assert np.allclose(res[0].asnumpy(), expect_grad_input)
    assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
    assert np.allclose(aux[0].asnumpy(), expect_aux1)
    assert np.allclose(aux[1].asnumpy(), expect_aux2)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_if_with_weights_has_aux_graph():
    """
    Features: Function grad.
    Description: Test F.grad with different weights and has_aux as well as if case in graph mode.
    Expectation: No exception.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
            self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")

        def construct(self, x):
            if x[0] == 1:
                res = x * self.w * self.z
            else:
                res = x * x
            return res, x, self.w

    x = Tensor(np.array([1, 2]).astype(np.float32))
    net = Net()
    weights = ParameterTuple(net.trainable_params())
    expect_grad_input = np.array([6, 6]).astype(np.float32)
    expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
    expect_grad_weight2 = np.array([2, 4]).astype(np.float32)
    expect_aux1 = np.array([1, 2]).astype(np.float32)
    expect_aux2 = np.array([2, 2]).astype(np.float32)
    res, aux = grad(net, 0, weights, True)(x)
    assert np.allclose(res[0].asnumpy(), expect_grad_input)
    assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
    assert np.allclose(res[1][1].asnumpy(), expect_grad_weight2)
    assert np.allclose(aux[0].asnumpy(), expect_aux1)
    assert np.allclose(aux[1].asnumpy(), expect_aux2)

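Editorial note: because x[0] == 1 at the test input, the branch that is exercised is res = x * w * z, so the expected gradients coincide with the branch-free test above. A tiny numpy sketch of both branches (illustration only, not part of the patch):

import numpy as np

x = np.array([1., 2.])
w = np.array([2., 2.])
z = np.array([3., 3.])
assert np.allclose(w * z, [6., 6.])   # d(res)/dx along the taken branch (x * w * z)
assert np.allclose(2 * x, [2., 4.])   # what the else branch (x * x) would give instead
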
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_nest_with_weights_has_aux_graph():
    """
    Features: Function grad.
    Description: Test F.grad with different weights and has_aux as well as nested nets in graph mode.
    Expectation: No exception.
    """

    class InnerNet(nn.Cell):
        def construct(self, x):
            return x * 3, x

    class Net(nn.Cell):
        def __init__(self, net):
            super(Net, self).__init__()
            self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
            self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
            self.net = net

        def construct(self, x):
            res1 = x * self.w * self.z
            res2 = self.net(res1)
            return res2

    x = Tensor(np.array([1, 2]).astype(np.float32))
    inner_net = InnerNet()
    net = Net(inner_net)
    weights = ParameterTuple(net.trainable_params())
    expect_grad_input = np.array([18, 18]).astype(np.float32)
    expect_grad_weight1 = np.array([9, 18]).astype(np.float32)
    expect_grad_weight2 = np.array([6, 12]).astype(np.float32)
    expect_aux = np.array([6, 12]).astype(np.float32)
    res, aux = grad(net, 0, weights, True)(x)
    assert np.allclose(res[0].asnumpy(), expect_grad_input)
    assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
    assert np.allclose(res[1][1].asnumpy(), expect_grad_weight2)
    assert np.allclose(aux[0].asnumpy(), expect_aux)

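Editorial note: the expected values here are consistent with only the first output of InnerNet (3 * res1) driving the differentiation while res1 itself is returned as the auxiliary output; by the chain rule every gradient is three times the plain-product gradient of res1 = x * w * z. A standalone numpy sketch (illustration only, not part of the patch):

import numpy as np

x = np.array([1., 2.])
w = np.array([2., 2.])
z = np.array([3., 3.])
res1 = x * w * z                            # [6, 12], the auxiliary output
assert np.allclose(3 * w * z, [18., 18.])   # matches expect_grad_input
assert np.allclose(3 * x * z, [9., 18.])    # matches expect_grad_weight1
assert np.allclose(3 * x * w, [6., 12.])    # matches expect_grad_weight2
assert np.allclose(res1, [6., 12.])         # matches expect_aux
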
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -255,6 +440,7 @@ def test_grad_if_ith_train_one_step():
    Description: Grad a network with each output. A simplification for GAN network.
    Expectation: Compile success.
    """

    class IthOutputCell(nn.Cell):
        def __init__(self, network, output_index):
            super().__init__()
@@ -316,6 +502,7 @@ def test_grad_net_d_net_g():
    Description: Grad two different network. A simplification for GAN network.
    Expectation: Compile success.
    """

    class NetD(nn.Cell):
        def __init__(self):
            super().__init__()
@@ -398,3 +585,176 @@ def test_grad_net_d_net_g():
    network = Backbone()
    train_one_net = MyTrainOneStepCell(network)
    train_one_net(x, y)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_and_grad_with_weights_has_aux_graph():
    """
    Features: Function value_and_grad.
    Description: Test F.value_and_grad with different weights and has_aux in graph mode.
    Expectation: No exception.
    """

    class ParamNetMultipleOutputs(nn.Cell):
        def __init__(self):
            super(ParamNetMultipleOutputs, self).__init__()
            self.w1 = Parameter(Tensor([2., 2.], mstype.float32), name="w1")
            self.w2 = Parameter(Tensor([3., 3.], mstype.float32), name="w2")

        def construct(self, x):
            res = x * self.w1 * self.w2
            return res, x, self.w1

    x = Tensor(np.array([1, 2]).astype(np.float32))
    net = ParamNetMultipleOutputs()
    weights = ParameterTuple(net.trainable_params())
    expect_grad_input = np.array([6, 6]).astype(np.float32)
    expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
    expect_grad_weight2 = np.array([2, 4]).astype(np.float32)
    expect_value0 = np.array([6, 12]).astype(np.float32)
    expect_value1 = np.array([1, 2]).astype(np.float32)
    expect_value2 = np.array([2, 2]).astype(np.float32)
    value, gradient = value_and_grad(net, 0, weights, True)(x)
    assert np.allclose(value[0].asnumpy(), expect_value0)
    assert np.allclose(value[1].asnumpy(), expect_value1)
    assert np.allclose(value[2].asnumpy(), expect_value2)
    assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
    assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
    assert np.allclose(gradient[1][1].asnumpy(), expect_grad_weight2)

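Editorial note: in this value_and_grad test the value that is checked is the complete forward output tuple (res, x, w1), while the gradient tuple carries the same constants as the grad-based tests above. A standalone numpy check of those constants (illustration only, not part of the patch):

import numpy as np

x = np.array([1., 2.])
w1 = np.array([2., 2.])
w2 = np.array([3., 3.])
# Forward values checked by the test: the full output tuple of the network.
assert np.allclose(x * w1 * w2, [6., 12.])   # value[0]
assert np.allclose(x, [1., 2.])              # value[1]
assert np.allclose(w1, [2., 2.])             # value[2]
# Gradients of the first output only, as in the grad(..., has_aux=True) tests.
assert np.allclose(w1 * w2, [6., 6.])        # w.r.t. x
assert np.allclose(x * w2, [3., 6.])         # w.r.t. w1
assert np.allclose(x * w1, [2., 4.])         # w.r.t. w2
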
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_construct_value_and_grad_with_weights_has_aux_graph():
    """
    Features: Function value_and_grad.
    Description: Test F.value_and_grad with different weights and has_aux in graph mode.
    Expectation: No exception.
    """

    class ParamNetMultipleInputsOutputs(nn.Cell):
        def __init__(self):
            super(ParamNetMultipleInputsOutputs, self).__init__()
            self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")

        def construct(self, x, y):
            res = x * y * self.w
            return res, x, self.w

    class GradNet2(nn.Cell):
        def __init__(self, net):
            super(GradNet2, self).__init__()
            self.net = net
            self.weights = net.trainable_params()

        def construct(self, x, y):
            value, gradient = value_and_grad(self.net, 0, self.weights, True)(x, y)
            return value, gradient

    x = Tensor(np.array([1, 2]).astype(np.float32))
    y = Tensor(np.array([3, 3]).astype(np.float32))
    inner_net = ParamNetMultipleInputsOutputs()
    grad_net = GradNet2(inner_net)
    value, gradient = grad_net(x, y)
    expect_grad_input = np.array([6, 6]).astype(np.float32)
    expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
    expect_value0 = np.array([6, 12]).astype(np.float32)
    expect_value1 = np.array([1, 2]).astype(np.float32)
    expect_value2 = np.array([2, 2]).astype(np.float32)
    assert np.allclose(value[0].asnumpy(), expect_value0)
    assert np.allclose(value[1].asnumpy(), expect_value1)
    assert np.allclose(value[2].asnumpy(), expect_value2)
    assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
    assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_and_grad_nest_with_weights_graph():
    """
    Features: Function value_and_grad.
    Description: Test F.value_and_grad with different weights and has_aux as well as nested nets in graph mode.
    Expectation: No exception.
    """

    class InnerNet(nn.Cell):
        def construct(self, x):
            return x * 3, x

    class Net(nn.Cell):
        def __init__(self, net):
            super(Net, self).__init__()
            self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
            self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
            self.net = net

        def construct(self, x):
            res1 = x * self.w * self.z
            res2 = self.net(res1)
            return res2

    x = Tensor(np.array([1, 2]).astype(np.float32))
    inner_net = InnerNet()
    net = Net(inner_net)
    weights = ParameterTuple(net.trainable_params())
    expect_grad_input = np.array([24, 24]).astype(np.float32)
    expect_grad_weight1 = np.array([12, 24]).astype(np.float32)
    expect_grad_weight2 = np.array([8, 16]).astype(np.float32)
    expect_value0 = np.array([18, 36]).astype(np.float32)
    expect_value1 = np.array([6, 12]).astype(np.float32)
    value, gradient = value_and_grad(net, 0, weights, False)(x)
    assert np.allclose(value[0].asnumpy(), expect_value0)
    assert np.allclose(value[1].asnumpy(), expect_value1)
    assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
    assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
    assert np.allclose(gradient[1][1].asnumpy(), expect_grad_weight2)

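Editorial note: here has_aux is False, and the expected constants are consistent with the gradient accumulating over both outputs of InnerNet (3 * res1 and res1), which is why expect_grad_input is [24, 24] = 4 * w * z; the has_aux=True variant that follows expects [18, 18] = 3 * w * z instead. A small numpy comparison of the two readings (illustration only, derived from the constants in these two tests):

import numpy as np

x = np.array([1., 2.])
w = np.array([2., 2.])
z = np.array([3., 3.])
# has_aux=False: both outputs (3*res1 and res1) contribute to the gradient.
assert np.allclose((3 + 1) * w * z, [24., 24.])
assert np.allclose((3 + 1) * x * z, [12., 24.])
assert np.allclose((3 + 1) * x * w, [8., 16.])
# has_aux=True (next test): only the first output, 3*res1, is differentiated.
assert np.allclose(3 * w * z, [18., 18.])
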
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_and_grad_nest_with_weights_has_aux_graph():
    """
    Features: Function value_and_grad.
    Description: Test F.value_and_grad with different weights and has_aux as well as nested nets in graph mode.
    Expectation: No exception.
    """

    class InnerNet(nn.Cell):
        def construct(self, x):
            return x * 3, x

    class Net(nn.Cell):
        def __init__(self, net):
            super(Net, self).__init__()
            self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
            self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
            self.net = net

        def construct(self, x):
            res1 = x * self.w * self.z
            res2 = self.net(res1)
            return res2

    x = Tensor(np.array([1, 2]).astype(np.float32))
    inner_net = InnerNet()
    net = Net(inner_net)
    weights = ParameterTuple(net.trainable_params())
    expect_grad_input = np.array([18, 18]).astype(np.float32)
    expect_grad_weight1 = np.array([9, 18]).astype(np.float32)
    expect_grad_weight2 = np.array([6, 12]).astype(np.float32)
    expect_value0 = np.array([18, 36]).astype(np.float32)
    expect_value1 = np.array([6, 12]).astype(np.float32)
    value, gradient = value_and_grad(net, 0, weights, True)(x)
    assert np.allclose(value[0].asnumpy(), expect_value0)
    assert np.allclose(value[1].asnumpy(), expect_value1)
    assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
    assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
    assert np.allclose(gradient[1][1].asnumpy(), expect_grad_weight2)

@@ -19,8 +19,9 @@ import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore import ms_function
from mindspore.ops.functional import grad
from mindspore.ops import composite as C
from mindspore.ops.functional import grad, value_and_grad
from mindspore.common import dtype as mstype
from mindspore import Parameter, ParameterTuple

context.set_context(mode=context.PYNATIVE_MODE)
@@ -28,19 +29,22 @@ context.set_context(mode=context.PYNATIVE_MODE)

class SingleInputSingleOutputNet(nn.Cell):
    def construct(self, x):
        return x**3
        return x ** 3


class SingleInputMultipleOutputsNet(nn.Cell):
    def construct(self, x):
        return x**3, 2*x
        return x ** 3, 2 * x


class MultipleInputsSingleOutputNet(nn.Cell):
    def construct(self, x, y, z):
        return x*y*z
        return x * y * z


class MultipleInputsMultipleOutputsNet(nn.Cell):
    def construct(self, x, y, z):
        return x**2 + y**2 + z**2, x*y*z
        return x ** 2 + y ** 2 + z ** 2, x * y * z


class ParamNet(nn.Cell):

@@ -55,13 +59,15 @@ class ParamNet(nn.Cell):


def function(x, y, z):
    return x**2 + y**2 + z**2, x*y*z
    return x ** 2 + y ** 2 + z ** 2, x * y * z


def iteration_grad_function(x, y, z):
    return x**2*y*z
    return x ** 2 * y * z


@ms_function
def grad_warp_with_msfunction(x, y, z):
def grad_wrap_with_msfunction(x, y, z):
    output = grad(function)(x, y, z)
    return output

@@ -142,28 +148,6 @@ def test_grad_multiple_inputs_multiple_outputs_cell_pynative():
    assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_function_with_sens_pynative():
    """
    Features: Function grad.
    Description: Test F.grad with function setting sens_param in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
    expect_grad1 = Tensor(np.array([[4, 36], [26, 0]]).astype(np.float32))
    expect_grad2 = Tensor(np.array([[2, 36], [14, 6]]).astype(np.float32))
    real_grad = grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))
    assert isinstance(real_grad, tuple)
    assert len(real_grad) == 2
    assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
    assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
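Editorial note: the expected tensors in this (removed) sens_param test can be reproduced by hand if the sensitivity v is read as an elementwise weight applied to both outputs of function: the gradient w.r.t. y is then v * 2y + v * x * z and w.r.t. z is v * 2z + v * x * y. A standalone numpy check of that reading (illustration only, not part of the patch):

import numpy as np

x = np.array([[1., 2.], [3., 4.]])
y = np.array([[-2., 3.], [-1., 2.]])
z = np.array([[0., 3.], [5., -1.]])
v = np.array([[-1., 3.], [2., 1.]])
# Sensitivity-weighted gradients of (x**2 + y**2 + z**2, x*y*z).
assert np.allclose(v * 2 * y + v * x * z, [[4., 36.], [26., 0.]])   # w.r.t. y
assert np.allclose(v * 2 * z + v * x * y, [[2., 36.], [14., 6.]])   # w.r.t. z
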

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -184,22 +168,24 @@ def test_grad_iteration_function_pynative():
    assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
    assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_warp_with_msfunction_pynative():
def test_grad_wrap_with_msfunction_pynative():
    """
    Features: Function grad.
    Description: Test F.grad warpped with ms_function in pynative mode.
    Description: Test F.grad wrapped with ms_function in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
    real_grad = grad_warp_with_msfunction(x, y, z)
    real_grad = grad_wrap_with_msfunction(x, y, z)
    assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
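Editorial note: the expected gradient in the renamed test follows from differentiating the sum of both outputs of function with respect to x (the default grad_position=0): d/dx[(x^2 + y^2 + z^2) + x*y*z] = 2x + y*z, which at the given inputs is [[2, 13], [1, 6]]. A standalone numpy check (illustration only, not part of the patch):

import numpy as np

x = np.array([[1., 2.], [3., 4.]])
y = np.array([[-2., 3.], [-1., 2.]])
z = np.array([[0., 3.], [5., -1.]])
# function returns (x**2 + y**2 + z**2, x*y*z); the gradient w.r.t. x
# accumulates over both outputs: 2*x + y*z.
assert np.allclose(2 * x + y * z, [[2., 13.], [1., 6.]])
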

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -239,3 +225,249 @@ def test_grad_with_weights_twice_pynative():
    out2 = grad_fn(net, weights2)(x)
    assert np.allclose(out1[0].asnumpy(), expect1)
    assert np.allclose(out2[0].asnumpy(), expect2)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_with_weights_has_aux_pynative():
    """
    Features: Function grad.
    Description: Test F.grad with different weights and has_aux in pynative mode.
    Expectation: No exception.
    """

    class ParamNetAux(nn.Cell):
        def __init__(self):
            super(ParamNetAux, self).__init__()
            self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
            self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")

        def construct(self, x):
            res = x * self.w * self.z
            return res, x, self.w

    x = Tensor(np.array([1, 2]).astype(np.float32))
    net = ParamNetAux()
    weights = ParameterTuple(net.trainable_params())
    expect_grad_input = np.array([6, 6]).astype(np.float32)
    expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
    expect_grad_weight2 = np.array([2, 4]).astype(np.float32)
    expect_aux1 = np.array([1, 2]).astype(np.float32)
    expect_aux2 = np.array([2, 2]).astype(np.float32)
    res, aux = grad(net, 0, weights, True)(x)
    assert np.allclose(res[0].asnumpy(), expect_grad_input)
    assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
    assert np.allclose(res[1][1].asnumpy(), expect_grad_weight2)
    assert np.allclose(aux[0].asnumpy(), expect_aux1)
    assert np.allclose(aux[1].asnumpy(), expect_aux2)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_if_with_weights_has_aux_pynative():
    """
    Features: Function grad.
    Description: Test F.grad with different weights and has_aux as well as if case in pynative mode.
    Expectation: No exception.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
            self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")

        def construct(self, x):
            if x[0] == 1:
                res = x * self.w * self.z
            else:
                res = x * x
            return res, x, self.w

    x = Tensor(np.array([1, 2]).astype(np.float32))
    net = Net()
    weights = ParameterTuple(net.trainable_params())
    expect_grad_input = np.array([6, 6]).astype(np.float32)
    expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
    expect_grad_weight2 = np.array([2, 4]).astype(np.float32)
    expect_aux1 = np.array([1, 2]).astype(np.float32)
    expect_aux2 = np.array([2, 2]).astype(np.float32)
    res, aux = grad(net, 0, weights, True)(x)
    assert np.allclose(res[0].asnumpy(), expect_grad_input)
    assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
    assert np.allclose(res[1][1].asnumpy(), expect_grad_weight2)
    assert np.allclose(aux[0].asnumpy(), expect_aux1)
    assert np.allclose(aux[1].asnumpy(), expect_aux2)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_nest_with_weights_has_aux_pynative():
    """
    Features: Function grad.
    Description: Test F.grad with different weights and has_aux as well as nested nets in pynative mode.
    Expectation: No exception.
    """

    class InnerNet(nn.Cell):
        def construct(self, x):
            return x * 3, x

    class Net(nn.Cell):
        def __init__(self, net):
            super(Net, self).__init__()
            self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
            self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
            self.net = net

        def construct(self, x):
            res1 = x * self.w * self.z
            res2 = self.net(res1)
            return res2

    x = Tensor(np.array([1, 2]).astype(np.float32))
    inner_net = InnerNet()
    net = Net(inner_net)
    weights = ParameterTuple(net.trainable_params())
    expect_grad_input = np.array([18, 18]).astype(np.float32)
    expect_grad_weight1 = np.array([9, 18]).astype(np.float32)
    expect_grad_weight2 = np.array([6, 12]).astype(np.float32)
    expect_aux = np.array([6, 12]).astype(np.float32)
    res, aux = grad(net, 0, weights, True)(x)
    assert np.allclose(res[0].asnumpy(), expect_grad_input)
    assert np.allclose(res[1][0].asnumpy(), expect_grad_weight1)
    assert np.allclose(res[1][1].asnumpy(), expect_grad_weight2)
    assert np.allclose(aux[0].asnumpy(), expect_aux)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_and_grad_with_weights_has_aux_pynative():
    """
    Features: Function value_and_grad.
    Description: Test F.value_and_grad with different weights and has_aux in pynative mode.
    Expectation: No exception.
    """

    class ParamNetMultipleOutputs(nn.Cell):
        def __init__(self):
            super(ParamNetMultipleOutputs, self).__init__()
            self.w1 = Parameter(Tensor([2., 2.], mstype.float32), name="w1")
            self.w2 = Parameter(Tensor([3., 3.], mstype.float32), name="w2")

        def construct(self, x):
            res = x * self.w1 * self.w2
            return res, x, self.w1

    x = Tensor(np.array([1, 2]).astype(np.float32))
    net = ParamNetMultipleOutputs()
    weights = ParameterTuple(net.trainable_params())
    expect_grad_input = np.array([6, 6]).astype(np.float32)
    expect_grad_weight1 = np.array([3, 6]).astype(np.float32)
    expect_grad_weight2 = np.array([2, 4]).astype(np.float32)
    expect_value0 = np.array([6, 12]).astype(np.float32)
    expect_value1 = np.array([1, 2]).astype(np.float32)
    expect_value2 = np.array([2, 2]).astype(np.float32)
    value, gradient = value_and_grad(net, 0, weights, True)(x)
    assert np.allclose(value[0].asnumpy(), expect_value0)
    assert np.allclose(value[1].asnumpy(), expect_value1)
    assert np.allclose(value[2].asnumpy(), expect_value2)
    assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
    assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
    assert np.allclose(gradient[1][1].asnumpy(), expect_grad_weight2)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_and_grad_nest_with_weights_pynative():
    """
    Features: Function value_and_grad.
    Description: Test F.value_and_grad with different weights and has_aux as well as nested nets in pynative mode.
    Expectation: No exception.
    """

    class InnerNet(nn.Cell):
        def construct(self, x):
            return x * 3, x

    class Net(nn.Cell):
        def __init__(self, net):
            super(Net, self).__init__()
            self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
            self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
            self.net = net

        def construct(self, x):
            res1 = x * self.w * self.z
            res2 = self.net(res1)
            return res2

    x = Tensor(np.array([1, 2]).astype(np.float32))
    inner_net = InnerNet()
    net = Net(inner_net)
    weights = ParameterTuple(net.trainable_params())
    expect_grad_input = np.array([24, 24]).astype(np.float32)
    expect_grad_weight1 = np.array([12, 24]).astype(np.float32)
    expect_grad_weight2 = np.array([8, 16]).astype(np.float32)
    expect_value0 = np.array([18, 36]).astype(np.float32)
    expect_value1 = np.array([6, 12]).astype(np.float32)
    value, gradient = value_and_grad(net, 0, weights, False)(x)
    assert np.allclose(value[0].asnumpy(), expect_value0)
    assert np.allclose(value[1].asnumpy(), expect_value1)
    assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
    assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
    assert np.allclose(gradient[1][1].asnumpy(), expect_grad_weight2)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_and_grad_nest_with_weights_has_aux_pynative():
    """
    Features: Function value_and_grad.
    Description: Test F.value_and_grad with different weights and has_aux as well as nested nets in pynative mode.
    Expectation: No exception.
    """

    class InnerNet(nn.Cell):
        def construct(self, x):
            return x * 3, x

    class Net(nn.Cell):
        def __init__(self, net):
            super(Net, self).__init__()
            self.w = Parameter(Tensor([2., 2.], mstype.float32), name="w")
            self.z = Parameter(Tensor([3., 3.], mstype.float32), name="z")
            self.net = net

        def construct(self, x):
            res1 = x * self.w * self.z
            res2 = self.net(res1)
            return res2

    x = Tensor(np.array([1, 2]).astype(np.float32))
    inner_net = InnerNet()
    net = Net(inner_net)
    weights = ParameterTuple(net.trainable_params())
    expect_grad_input = np.array([18, 18]).astype(np.float32)
    expect_grad_weight1 = np.array([9, 18]).astype(np.float32)
    expect_grad_weight2 = np.array([6, 12]).astype(np.float32)
    expect_value0 = np.array([18, 36]).astype(np.float32)
    expect_value1 = np.array([6, 12]).astype(np.float32)
    value, gradient = value_and_grad(net, 0, weights, True)(x)
    assert np.allclose(value[0].asnumpy(), expect_value0)
    assert np.allclose(value[1].asnumpy(), expect_value1)
    assert np.allclose(gradient[0].asnumpy(), expect_grad_input)
    assert np.allclose(gradient[1][0].asnumpy(), expect_grad_weight1)
    assert np.allclose(gradient[1][1].asnumpy(), expect_grad_weight2)

@@ -98,8 +98,8 @@ class TestKPynative : public UT::Common {
        GradPynativeOp(k_pynative_cell, c_node, args, out);
      }
    }
    auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, std::vector<size_t>{0}, true, false, false,
                                        true);
    GradAttr grad_attr(true, false, false, false);
    auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, std::vector<size_t>{0}, grad_attr, true);
    return bprop_fg;
  }
};
@@ -58,16 +58,3 @@ def test_grad_multiple_inputs_multiple_outputs_cell_graph():
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    net = MultipleInputsMultipleOutputsNet()
    grad(net, grad_position=(1, 2))(x, y, z)


def test_grad_function_with_sens_graph():
    """
    Features: Function grad.
    Description: Test F.grad with function setting sens_param in graph mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
    grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))
@@ -58,16 +58,3 @@ def test_grad_multiple_inputs_multiple_outputs_cell_pynative():
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    net = MultipleInputsMultipleOutputsNet()
    grad(net, grad_position=(1, 2))(x, y, z)


def test_grad_function_with_sens_pynative():
    """
    Features: Function grad.
    Description: Test F.grad with function setting sens_param in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
    grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))