!12071 Refactor grad operation implementation.

From: @zh_qh
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-04 21:23:07 +08:00 committed by Gitee
commit 623cca4214
3 changed files with 62 additions and 77 deletions

View File

@ -525,105 +525,88 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
}
}
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
const std::vector<AnfNodePtr> &params_list, const std::vector<AnfNodePtr> &args,
bool applyJ) {
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
const std::vector<AnfNodePtr> &forward_graph_params,
const std::vector<AnfNodePtr> &weight_args) {
FuncGraphPtr k_child = std::make_shared<FuncGraph>();
k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
auto weights_node = weights;
if (weights == nullptr && !args.empty()) {
weights_node = ret->NewCNode(args);
AnfNodePtr weights_node = nullptr;
if (weights != nullptr) {
weights_node = weights;
} else if (!weight_args.empty()) {
weights_node = k_child->NewCNode(weight_args);
}
ValueNodePtr opsJ = NewValueNode(prim::kPrimJ);
ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem);
std::vector<AnfNodePtr> inputs;
if (applyJ) {
inputs.push_back(opsJ);
inputs.push_back(node);
node = ret->NewCNode(inputs);
inputs.push_back(k);
for (size_t i = 0; i < forward_graph_params.size(); ++i) {
inputs.push_back(k_child->add_parameter());
}
auto k_app = k_child->NewCNode(inputs);
std::vector<AnfNodePtr> params;
for (size_t i = 0; i < params_list.size(); ++i) {
params.push_back(ret->add_parameter());
}
auto tuple_get_item = NewValueNode(prim::kPrimTupleGetItem);
auto f_app = k_child->NewCNode({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
auto bprop = k_child->NewCNode({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});
inputs.clear();
inputs.push_back(node);
(void)std::copy(params.begin(), params.end(), std::back_inserter(inputs));
AnfNodePtr cnode = ret->NewCNode(inputs);
inputs.clear();
inputs.push_back(opsTupleItem);
inputs.push_back(cnode);
inputs.push_back(NewValueNode(static_cast<int64_t>(0)));
auto out = ret->NewCNode(inputs);
inputs.clear();
inputs.push_back(opsTupleItem);
inputs.push_back(cnode);
inputs.push_back(NewValueNode(static_cast<int64_t>(1)));
AnfNodePtr ptr_bprop = ret->NewCNode(inputs);
doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem);
return ret;
GradByParameter(k_child, f_app, bprop, weights_node);
return k_child;
}
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights,
ValueNodePtr opsTupleItem) {
MS_EXCEPTION_IF_NULL(func_graph);
// Do grad by the parameter of GradOperation.
void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
const AnfNodePtr &weights) {
MS_EXCEPTION_IF_NULL(k_child);
AnfNodePtr ptr_bprop_arg = nullptr;
AnfNodePtr bprop_arg = nullptr;
if (sens_param_) {
ptr_bprop_arg = func_graph->add_parameter();
bprop_arg = k_child->add_parameter();
} else {
auto ones_like = prim::GetPythonOps("ones_like");
ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out});
bprop_arg = k_child->NewCNode({NewValueNode(ones_like), f_app});
}
AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg});
AnfNodePtr b_app = k_child->NewCNode({bprop, bprop_arg});
CNodePtr fv_bprop = nullptr;
if (get_by_list_) {
// python code: grads = hyper_map(F.partial(env_get, env), weights)
AnfNodePtr env =
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(static_cast<int64_t>(0))});
k_child->NewCNode({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(0))});
AnfNodePtr partial_env_get =
func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
k_child->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
fv_bprop = func_graph->NewCNode({NewValueNode(hyper_map), partial_env_get, weights});
fv_bprop = k_child->NewCNode({NewValueNode(hyper_map), partial_env_get, weights});
}
CNodePtr inputs_bprop = nullptr;
if (get_all_) {
TailPtr tail = std::make_shared<Tail>("tail", true);
inputs_bprop = func_graph->NewCNode({NewValueNode(tail), ptr_bapp});
inputs_bprop = k_child->NewCNode({NewValueNode(tail), b_app});
}
// Gradients wrt inputs and parameters
if (fv_bprop != nullptr && inputs_bprop != nullptr) {
func_graph->set_output(func_graph->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
k_child->set_output(k_child->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
return;
}
// Gradients wrt parameters
if (fv_bprop != nullptr) {
func_graph->set_output(fv_bprop);
k_child->set_output(fv_bprop);
return;
}
// Gradients wrt inputs
if (inputs_bprop != nullptr) {
func_graph->set_output(inputs_bprop);
k_child->set_output(inputs_bprop);
return;
}
// Gradients wrt first input.
// ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(static_cast<int64_t>(1))}));
// b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
k_child->set_output(
k_child->NewCNode({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(1))}));
}
// Generate the graph.
@ -643,39 +626,39 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
MS_EXCEPTION_IF_NULL(real_fn);
FuncGraphPtr ptr_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(ptr_graph);
FuncGraphPtr df_builder = nullptr;
FuncGraphPtr forward_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(forward_graph);
FuncGraphPtr grad_fg = nullptr;
{
TraceGuard g(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
df_builder = std::make_shared<FuncGraph>();
TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
grad_fg = std::make_shared<FuncGraph>();
}
auto nparam = ptr_graph->parameters().size();
auto nparam = forward_graph->parameters().size();
std::ostringstream ss;
ss << "grad{" << nparam << "}";
df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
df_builder->debug_info()->set_name(ss.str());
ParameterPtr param_graph = df_builder->add_parameter();
grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
grad_fg->debug_info()->set_name(ss.str());
ParameterPtr param_graph = grad_fg->add_parameter();
AnfNodePtr weights = nullptr;
if (get_by_list_) {
weights = df_builder->add_parameter();
weights = grad_fg->add_parameter();
}
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimJ));
inputs.push_back(param_graph);
auto jf = df_builder->NewCNode(inputs);
auto j = grad_fg->NewCNode(inputs);
// df is checked in GetGrad
FuncGraphPtr df = nullptr;
FuncGraphPtr k_child = nullptr;
{
TraceGuard guard(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
df = GetGrad(jf, weights, ptr_graph->parameters());
TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
k_child = GetGrad(j, weights, forward_graph->parameters());
}
df_builder->set_output(NewValueNode(df));
grad_fg->set_output(NewValueNode(k_child));
return df_builder;
return grad_fg;
}
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {

View File

@ -140,8 +140,10 @@ class GradOperation : public MetaFuncGraph {
~GradOperation() override = default;
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams,
const std::vector<AnfNodePtr> &args = {}, bool applyJ = false);
FuncGraphPtr GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
const std::vector<AnfNodePtr> &forward_graph_params,
const std::vector<AnfNodePtr> &weight_args = {});
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
bool sens_param() const { return sens_param_; }
bool get_all_;
@ -149,8 +151,8 @@ class GradOperation : public MetaFuncGraph {
bool sens_param_;
private:
void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights,
ValueNodePtr opsTupleItem);
void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
const AnfNodePtr &weights);
};
using GradOperationPtr = std::shared_ptr<GradOperation>;

View File

@ -54,13 +54,13 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
f->MapObject();
f->MapMorphism();
f->Finish();
auto ret = f->k_graph();
auto res = f->k_graph();
if (is_top) {
DFunctor::Clear();
}
multi_graph_sink(ret);
return ret;
multi_graph_sink(res);
return res;
}
FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {