forked from mindspore-Ecosystem/mindspore
!12071 Refactor grad operation implementation.
From: @zh_qh Reviewed-by: Signed-off-by:
commit 623cca4214
@@ -525,105 +525,88 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
   }
 }
 
-FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
-                                    const std::vector<AnfNodePtr> &params_list, const std::vector<AnfNodePtr> &args,
-                                    bool applyJ) {
-  FuncGraphPtr ret = std::make_shared<FuncGraph>();
-  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
+                                    const std::vector<AnfNodePtr> &forward_graph_params,
+                                    const std::vector<AnfNodePtr> &weight_args) {
+  FuncGraphPtr k_child = std::make_shared<FuncGraph>();
+  k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
 
-  auto weights_node = weights;
-  if (weights == nullptr && !args.empty()) {
-    weights_node = ret->NewCNode(args);
+  AnfNodePtr weights_node = nullptr;
+  if (weights != nullptr) {
+    weights_node = weights;
+  } else if (!weight_args.empty()) {
+    weights_node = k_child->NewCNode(weight_args);
   }
 
-  ValueNodePtr opsJ = NewValueNode(prim::kPrimJ);
-  ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem);
-
   std::vector<AnfNodePtr> inputs;
-  if (applyJ) {
-    inputs.push_back(opsJ);
-    inputs.push_back(node);
-    node = ret->NewCNode(inputs);
+  inputs.push_back(k);
+  for (size_t i = 0; i < forward_graph_params.size(); ++i) {
+    inputs.push_back(k_child->add_parameter());
   }
+  auto k_app = k_child->NewCNode(inputs);
 
-  std::vector<AnfNodePtr> params;
-  for (size_t i = 0; i < params_list.size(); ++i) {
-    params.push_back(ret->add_parameter());
-  }
+  auto tuple_get_item = NewValueNode(prim::kPrimTupleGetItem);
+  auto f_app = k_child->NewCNode({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
+  auto bprop = k_child->NewCNode({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});
 
-  inputs.clear();
-  inputs.push_back(node);
-  (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs));
-  AnfNodePtr cnode = ret->NewCNode(inputs);
-
-  inputs.clear();
-  inputs.push_back(opsTupleItem);
-  inputs.push_back(cnode);
-  inputs.push_back(NewValueNode(static_cast<int64_t>(0)));
-  auto out = ret->NewCNode(inputs);
-
-  inputs.clear();
-  inputs.push_back(opsTupleItem);
-  inputs.push_back(cnode);
-  inputs.push_back(NewValueNode(static_cast<int64_t>(1)));
-  AnfNodePtr ptr_bprop = ret->NewCNode(inputs);
-
-  doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem);
-  return ret;
+  GradByParameter(k_child, f_app, bprop, weights_node);
+  return k_child;
 }
 
-void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights,
-                              ValueNodePtr opsTupleItem) {
-  MS_EXCEPTION_IF_NULL(func_graph);
+// Do grad by the parameter of GradOperation.
+void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
+                                    const AnfNodePtr &weights) {
+  MS_EXCEPTION_IF_NULL(k_child);
 
-  AnfNodePtr ptr_bprop_arg = nullptr;
+  AnfNodePtr bprop_arg = nullptr;
   if (sens_param_) {
-    ptr_bprop_arg = func_graph->add_parameter();
+    bprop_arg = k_child->add_parameter();
   } else {
     auto ones_like = prim::GetPythonOps("ones_like");
-    ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out});
+    bprop_arg = k_child->NewCNode({NewValueNode(ones_like), f_app});
   }
 
-  AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg});
+  AnfNodePtr b_app = k_child->NewCNode({bprop, bprop_arg});
 
   CNodePtr fv_bprop = nullptr;
   if (get_by_list_) {
     // python code: grads = hyper_map(F.partial(env_get, env), weights)
     AnfNodePtr env =
-      func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(static_cast<int64_t>(0))});
+      k_child->NewCNode({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(0))});
     AnfNodePtr partial_env_get =
-      func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
+      k_child->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
     MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
-    fv_bprop = func_graph->NewCNode({NewValueNode(hyper_map), partial_env_get, weights});
+    fv_bprop = k_child->NewCNode({NewValueNode(hyper_map), partial_env_get, weights});
   }
 
   CNodePtr inputs_bprop = nullptr;
   if (get_all_) {
     TailPtr tail = std::make_shared<Tail>("tail", true);
-    inputs_bprop = func_graph->NewCNode({NewValueNode(tail), ptr_bapp});
+    inputs_bprop = k_child->NewCNode({NewValueNode(tail), b_app});
   }
 
   // Gradients wrt inputs and parameters
   if (fv_bprop != nullptr && inputs_bprop != nullptr) {
-    func_graph->set_output(func_graph->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
+    k_child->set_output(k_child->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
     return;
   }
 
   // Gradients wrt parameters
   if (fv_bprop != nullptr) {
-    func_graph->set_output(fv_bprop);
+    k_child->set_output(fv_bprop);
     return;
   }
 
   // Gradients wrt inputs
   if (inputs_bprop != nullptr) {
-    func_graph->set_output(inputs_bprop);
+    k_child->set_output(inputs_bprop);
     return;
   }
 
   // Gradients wrt first input.
-  // ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
-  func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(static_cast<int64_t>(1))}));
+  // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
+  k_child->set_output(
+    k_child->NewCNode({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(1))}));
 }
 
 // Generate the graph.
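For context on what the refactored GetGrad/GradByParameter pair builds: applying J to the forward graph yields a node whose call returns a pair of (forward output, bprop closure). GradByParameter feeds that closure either a caller-supplied sensitivity (when sens_param_ is set) or ones_like of the forward output, and then selects what the child graph returns: all input gradients (get_all_, via Tail), the weight gradients (get_by_list_, via env_get/hyper_map), both as a tuple, or by default the gradient of the first input (index 1, since index 0 of the bprop result is the weight-gradient environment). The standalone C++17 sketch below models only that control flow with plain closures; BpropResult, KFunction, GradConfig and grad_by_parameter are names invented for this sketch and are not MindSpore APIs.

#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// Conceptual model of the J transform: calling the transformed function yields the
// forward output plus a bprop closure; bprop(sens) returns weight gradients (the
// "EnvInstance" slot) followed by the gradient of every input.
struct BpropResult {
  std::vector<double> weight_grads;  // stands in for the EnvInstance of parameter gradients
  std::vector<double> input_grads;   // gradients wrt input0, input1, ...
};

using Bprop = std::function<BpropResult(double)>;
using KFunction = std::function<std::pair<double, Bprop>(const std::vector<double> &)>;

// Flags mirroring GradOperation's get_all_, get_by_list_ and sens_param_.
struct GradConfig {
  bool get_all = false;
  bool get_by_list = false;
  bool sens_param = false;
};

// Mirrors GradByParameter's selection logic for a scalar-output function:
// feed the sensitivity into bprop, then pick which gradients to return.
std::pair<std::vector<double>, std::vector<double>> grad_by_parameter(
    const KFunction &k, const std::vector<double> &inputs, const GradConfig &cfg, double sens = 1.0) {
  Bprop bprop = k(inputs).second;                  // "k_app", then take the bprop half
  double bprop_arg = cfg.sens_param ? sens : 1.0;  // ones_like(forward output) for a scalar output
  BpropResult b_app = bprop(bprop_arg);

  std::vector<double> input_part;
  std::vector<double> weight_part;
  if (cfg.get_all) {
    input_part = b_app.input_grads;                // "Tail": drop the env element, keep all input grads
  }
  if (cfg.get_by_list) {
    weight_part = b_app.weight_grads;              // hyper_map(partial(env_get, env), weights)
  }
  if (!cfg.get_all && !cfg.get_by_list) {
    input_part = {b_app.input_grads.at(0)};        // default: gradient wrt the first input
  }
  return {input_part, weight_part};
}

int main() {
  double w = 3.0;  // one trainable "weight"
  // f(x, y) = w * x * y, hand-written together with its bprop.
  KFunction k = [&w](const std::vector<double> &in) {
    double x = in.at(0), y = in.at(1);
    double out = w * x * y;
    Bprop bprop = [x, y, wv = w](double sens) {
      return BpropResult{{sens * x * y}, {sens * wv * y, sens * wv * x}};
    };
    return std::make_pair(out, bprop);
  };

  GradConfig cfg;
  cfg.get_all = true;      // gradients of both inputs
  cfg.get_by_list = true;  // plus the weight gradient
  auto grads = grad_by_parameter(k, {2.0, 5.0}, cfg);
  std::cout << "dx=" << grads.first[0] << " dy=" << grads.first[1]
            << " dw=" << grads.second[0] << "\n";  // dx=15 dy=6 dw=10
  return 0;
}

With both flags set the sketch mirrors the kPrimMakeTuple branch above: input gradients and weight gradients are returned side by side.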
@@ -643,39 +626,39 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
   auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
   MS_EXCEPTION_IF_NULL(real_fn);
 
-  FuncGraphPtr ptr_graph = real_fn->func_graph();
-  MS_EXCEPTION_IF_NULL(ptr_graph);
-  FuncGraphPtr df_builder = nullptr;
+  FuncGraphPtr forward_graph = real_fn->func_graph();
+  MS_EXCEPTION_IF_NULL(forward_graph);
+  FuncGraphPtr grad_fg = nullptr;
   {
-    TraceGuard g(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
-    df_builder = std::make_shared<FuncGraph>();
+    TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
+    grad_fg = std::make_shared<FuncGraph>();
   }
-  auto nparam = ptr_graph->parameters().size();
+  auto nparam = forward_graph->parameters().size();
 
   std::ostringstream ss;
   ss << "grad{" << nparam << "}";
-  df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-  df_builder->debug_info()->set_name(ss.str());
-  ParameterPtr param_graph = df_builder->add_parameter();
+  grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  grad_fg->debug_info()->set_name(ss.str());
+  ParameterPtr param_graph = grad_fg->add_parameter();
 
   AnfNodePtr weights = nullptr;
   if (get_by_list_) {
-    weights = df_builder->add_parameter();
+    weights = grad_fg->add_parameter();
   }
 
   std::vector<AnfNodePtr> inputs;
   inputs.push_back(NewValueNode(prim::kPrimJ));
   inputs.push_back(param_graph);
-  auto jf = df_builder->NewCNode(inputs);
+  auto j = grad_fg->NewCNode(inputs);
   // df is checked in GetGrad
-  FuncGraphPtr df = nullptr;
+  FuncGraphPtr k_child = nullptr;
   {
-    TraceGuard guard(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
-    df = GetGrad(jf, weights, ptr_graph->parameters());
+    TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
+    k_child = GetGrad(j, weights, forward_graph->parameters());
   }
-  df_builder->set_output(NewValueNode(df));
+  grad_fg->set_output(NewValueNode(k_child));
 
-  return df_builder;
+  return grad_fg;
 }
 
 REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
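The get_by_list_ path that GenerateFuncGraph wires up (the extra weights parameter on grad_fg) is consumed in GradByParameter above: element 0 of b_app is an environment mapping each parameter to its gradient, env_get reads one parameter's gradient out of it (falling back to a zeros_like value when the parameter was not used), and hyper_map applies that lookup across the weights, matching the in-source comment "python code: grads = hyper_map(F.partial(env_get, env), weights)". A rough standalone analogue, with the environment modeled as a name-to-value map; EnvModel, Weight and hyper_map_env_get are names invented for this sketch, not MindSpore code:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Conceptual stand-in for an EnvInstance: parameter name -> gradient value.
using EnvModel = std::unordered_map<std::string, double>;

struct Weight {
  std::string name;
  double value;
};

// env_get: fetch the gradient recorded for a weight, defaulting to zeros_like(weight).
double env_get(const EnvModel &env, const Weight &w) {
  auto it = env.find(w.name);
  return it != env.end() ? it->second : 0.0;
}

// hyper_map(partial(env_get, env), weights): apply the lookup to every weight.
std::vector<double> hyper_map_env_get(const EnvModel &env, const std::vector<Weight> &weights) {
  std::vector<double> grads;
  grads.reserve(weights.size());
  for (const auto &w : weights) {
    grads.push_back(env_get(env, w));
  }
  return grads;
}

int main() {
  EnvModel env = {{"w0", 0.5}, {"w2", -1.25}};  // gradients produced by the bprop application
  std::vector<Weight> weights = {{"w0", 1.0}, {"w1", 2.0}, {"w2", 3.0}};
  for (double g : hyper_map_env_get(env, weights)) {
    std::cout << g << " ";  // prints: 0.5 0 -1.25
  }
  std::cout << "\n";
  return 0;
}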
@@ -140,8 +140,10 @@ class GradOperation : public MetaFuncGraph {
   ~GradOperation() override = default;
   MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
 
-  FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams,
-                       const std::vector<AnfNodePtr> &args = {}, bool applyJ = false);
+  FuncGraphPtr GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
+                       const std::vector<AnfNodePtr> &forward_graph_params,
+                       const std::vector<AnfNodePtr> &weight_args = {});
 
   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
   bool sens_param() const { return sens_param_; }
   bool get_all_;
@@ -149,8 +151,8 @@ class GradOperation : public MetaFuncGraph {
   bool sens_param_;
 
  private:
-  void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights,
-                 ValueNodePtr opsTupleItem);
+  void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
+                       const AnfNodePtr &weights);
 };
 using GradOperationPtr = std::shared_ptr<GradOperation>;
@@ -54,13 +54,13 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
   f->MapObject();
   f->MapMorphism();
   f->Finish();
-  auto ret = f->k_graph();
+  auto res = f->k_graph();
   if (is_top) {
     DFunctor::Clear();
   }
 
-  multi_graph_sink(ret);
-  return ret;
+  multi_graph_sink(res);
+  return res;
 }
 
 FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {