!22923 Optimize codes for ms function

Merge pull request !22923 from JoyLvliang/optimize_codes_for_ms_function
This commit is contained in:
i-robot 2021-09-08 12:31:31 +00:00 committed by Gitee
commit ddd17bc09d
6 changed files with 53 additions and 39 deletions

View File

@ -58,6 +58,27 @@
namespace mindspore {
namespace pipeline {
namespace {
void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_paras;
for (const auto &param : func_graph->parameters()) {
auto param_node = param->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_node);
if (param_node->has_default()) {
new_paras.push_back(param_node);
continue;
}
AbstractBasePtr par_abs = param_node->abstract();
MS_EXCEPTION_IF_NULL(par_abs);
if (par_abs->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
par_abs->BuildType()->isa<Number>())) {
new_paras.push_back(param_node);
}
}
func_graph->set_parameters(new_paras);
}
// Disable mindRT in the control flow scenario.
void ResetMindRTEnable(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
@ -161,6 +182,7 @@ void ModifyOutputNode(const FuncGraphPtr &func_graph) {
func_graph->set_output(merge_node);
// Clear
func_graph->set_modify_output(true);
func_graph->ClearUsedForwardNodes();
}
} // namespace
@ -541,6 +563,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
if (loaded_graph_ptr != nullptr) {
CheckRootInputShapeAndType(res, loaded_graph_ptr);
}
UpdateFuncGraphParameter(new_fg);
MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
return true;
}

View File

@ -632,33 +632,11 @@ bool CconvPass(const ResourcePtr &res) {
bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); }
void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_paras;
for (const auto &param : func_graph->parameters()) {
auto param_node = param->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_node);
if (param_node->has_default()) {
new_paras.push_back(param_node);
continue;
}
AbstractBasePtr par_abs = param_node->abstract();
MS_EXCEPTION_IF_NULL(par_abs);
if (par_abs->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
par_abs->BuildType()->isa<Number>())) {
new_paras.push_back(param_node);
}
}
func_graph->set_parameters(new_paras);
}
bool ValidatePass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
Validate(func_graph);
UpdateFuncGraphParameter(func_graph);
return true;
}

View File

@ -1574,6 +1574,9 @@ void GradExecutor::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, con
auto same_param = graph_info->params.at(param_name);
manage->Replace(anf_node, same_param);
param = same_param;
} else {
(void)df_builder->add_parameter(param);
param->debug_info()->set_name(param_name);
}
new_params.push_back(param);
input_nodes.emplace_back(param);
@ -2917,7 +2920,6 @@ void GradExecutor::GradMsFunctionInner(const std::string &phase, const py::objec
UpdateMsFunctionForwardTensors(op_exec_info, added_out);
return;
}
MS_LOG(DEBUG) << "Ms func graph run firstly. The graph phase is: " << graph_phase();
if (!need_construct_graph()) {
MS_LOG(EXCEPTION) << "The flag of need construct graph is False.";
@ -2935,26 +2937,35 @@ void GradExecutor::GradMsFunctionInner(const std::string &phase, const py::objec
MakeAdjointForMsFunction(new_ms_func_graph, new_grad_graph, actual_out, args, actual_out_v);
}
void GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
if (!grad_flag_) {
MS_LOG(DEBUG) << "Only run forward infer computation, no need to construct grad graph.";
return;
}
py::object GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
// Get actual forward output object.
if (graph_phase().empty()) {
MS_LOG(EXCEPTION) << "The graph phase is empty, can not obtain ms_function func graph.";
}
// Get ms_function func graph and grad graph.
const auto &phase = graph_phase();
MS_LOG(DEBUG) << "ms_function func graph phase: " << phase;
auto executor = pipeline::GraphExecutorPy::GetInstance();
MS_EXCEPTION_IF_NULL(executor);
FuncGraphPtr ms_func_graph = executor->GetFuncGraph(phase);
MS_EXCEPTION_IF_NULL(ms_func_graph);
py::object ret = out;
if (ms_func_graph->modify_output()) {
auto tuple_out = py::cast<py::tuple>(out);
ret = tuple_out[0];
}
// Make Adjoint for grad graph of ms_function.
if (!grad_flag_) {
MS_LOG(DEBUG) << "Only run forward infer computation, no need to construct grad graph.";
set_graph_phase("");
return ret;
}
FuncGraphPtr grad_graph = executor->GetGradGraph(phase);
MS_EXCEPTION_IF_NULL(grad_graph);
GradMsFunctionInner(phase, out, args, ms_func_graph, grad_graph);
set_graph_phase("");
return ret;
}
void GradExecutor::ClearGrad(const py::object &cell, const py::args &args) {
@ -3083,8 +3094,8 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
MS_LOG(DEBUG) << "Leave end graph process.";
}
void PynativeExecutor::GradMsFunction(const py::object &out, const py::args &args) {
grad_executor()->GradMsFunction(out, args);
py::object PynativeExecutor::GradMsFunction(const py::object &out, const py::args &args) {
return grad_executor()->GradMsFunction(out, args);
}
void PynativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,

View File

@ -194,7 +194,7 @@ class GradExecutor {
// Construct grad graph for ms_function
bool eliminate_forward() const { return eliminate_forward_; }
void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; }
void GradMsFunction(const py::object &out, const py::args &args);
py::object GradMsFunction(const py::object &out, const py::args &args);
void GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph);
void UpdateMsFunctionForwardTensors(const OpExecInfoPtr &op_exec_info, const ValuePtr &new_forward_value);
@ -370,11 +370,11 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void set_graph_phase(const std::string &graph_phase);
void set_py_exe_path(const py::object &py_exe_path);
void set_kernel_build_server_dir(const py::object &kernel_build_server_dir);
void GradMsFunction(const py::object &out, const py::args &args);
void NewGraph(const py::object &cell, const py::args &args);
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
void GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args);
py::object GradMsFunction(const py::object &out, const py::args &args);
py::object CheckGraph(const py::object &cell, const py::args &args);
py::object CheckAlreadyRun(const py::object &cell, const py::args &args);
py::object Run(const py::object &cell, const py::tuple &args);

View File

@ -209,9 +209,7 @@ class _MindsporeFunctionExecutor:
if context.get_context("mode") == context.PYNATIVE_MODE:
_pynative_executor.set_graph_phase(phase)
_pynative_executor.grad_ms_function(output, *new_inputs)
if phase.find("export") == -1:
output = output[0]
output = _pynative_executor.grad_ms_function(output, *new_inputs)
return output
@ -395,7 +393,7 @@ class _PynativeExecutor:
self._executor.sync()
def grad_ms_function(self, output, *args):
self._executor.grad_ms_function(output, *args)
return self._executor.grad_ms_function(output, *args)
def set_graph_phase(self, phase):
self._executor.set_graph_phase(phase)

View File

@ -405,6 +405,8 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo
std::string bprop_hash() const { return bprop_hash_; }
void set_bprop_hash(const std::string &bprop_hash) { bprop_hash_ = bprop_hash; }
bool modify_output() const { return modify_output_; }
void set_modify_output(bool modify_output) { modify_output_ = modify_output; }
const std::unordered_set<AnfNodePtr> &used_forward_nodes() const { return used_forward_nodes_; }
void set_used_forward_nodes(const std::vector<AnfNodePtr> &used_forward_nodes);
void ClearUsedForwardNodes() { used_forward_nodes_.clear(); }
@ -496,6 +498,7 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo
// If the graph is decorated by @ms_function and runs grad process in pynative mode,
// forward nodes used in grad graph will be added to output for holding output values.
bool modify_output_ = false;
std::unordered_set<AnfNodePtr> used_forward_nodes_;
};