forked from mindspore-Ecosystem/mindspore
!22923 Optimize codes for ms function
Merge pull request !22923 from JoyLvliang/optimize_codes_for_ms_function
This commit is contained in: commit ddd17bc09d
@@ -58,6 +58,27 @@
 namespace mindspore {
 namespace pipeline {
 namespace {
+void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  std::vector<AnfNodePtr> new_paras;
+  for (const auto &param : func_graph->parameters()) {
+    auto param_node = param->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param_node);
+    if (param_node->has_default()) {
+      new_paras.push_back(param_node);
+      continue;
+    }
+    AbstractBasePtr par_abs = param_node->abstract();
+    MS_EXCEPTION_IF_NULL(par_abs);
+    if (par_abs->isa<abstract::AbstractUndetermined>() ||
+        (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
+         par_abs->BuildType()->isa<Number>())) {
+      new_paras.push_back(param_node);
+    }
+  }
+  func_graph->set_parameters(new_paras);
+}
+
 // Disable mindRT in the control flow scenario.
 void ResetMindRTEnable(const ResourcePtr &res) {
   MS_EXCEPTION_IF_NULL(res);
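In plain terms, the hunk above keeps a parameter only if it is a weight (has a default value), an input whose abstract is still undetermined (a tensor fed at run time), or a plain scalar when grad-for-scalar is enabled; everything else is pruned. A minimal Python mirror of that rule, with hypothetical attribute names for illustration (not MindSpore API):

# Hypothetical mirror of UpdateFuncGraphParameter's filter; attribute
# names below are illustrative, not MindSpore API.
def keep_parameter(param, grad_for_scalar):
    if param.has_default:        # weights are always kept
        return True
    if param.is_undetermined:    # run-time tensor inputs, value not yet known
        return True
    # plain scalars survive only when MS_CTX_GRAD_FOR_SCALAR is enabled
    return grad_for_scalar and param.is_number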
@@ -161,6 +182,7 @@ void ModifyOutputNode(const FuncGraphPtr &func_graph) {
   func_graph->set_output(merge_node);
 
   // Clear
+  func_graph->set_modify_output(true);
   func_graph->ClearUsedForwardNodes();
 }
 }  // namespace
@@ -541,6 +563,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
   if (loaded_graph_ptr != nullptr) {
     CheckRootInputShapeAndType(res, loaded_graph_ptr);
   }
+
+  UpdateFuncGraphParameter(new_fg);
   MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
   return true;
 }
@@ -632,33 +632,11 @@ bool CconvPass(const ResourcePtr &res) {
 
 bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); }
 
-void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  std::vector<AnfNodePtr> new_paras;
-  for (const auto &param : func_graph->parameters()) {
-    auto param_node = param->cast<ParameterPtr>();
-    MS_EXCEPTION_IF_NULL(param_node);
-    if (param_node->has_default()) {
-      new_paras.push_back(param_node);
-      continue;
-    }
-    AbstractBasePtr par_abs = param_node->abstract();
-    MS_EXCEPTION_IF_NULL(par_abs);
-    if (par_abs->isa<abstract::AbstractUndetermined>() ||
-        (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
-         par_abs->BuildType()->isa<Number>())) {
-      new_paras.push_back(param_node);
-    }
-  }
-  func_graph->set_parameters(new_paras);
-}
-
 bool ValidatePass(const ResourcePtr &res) {
   MS_EXCEPTION_IF_NULL(res);
   MS_EXCEPTION_IF_NULL(res->func_graph());
   FuncGraphPtr func_graph = res->func_graph();
   Validate(func_graph);
-  UpdateFuncGraphParameter(func_graph);
   return true;
 }
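Taken together with the action.cc hunks above, this relocates UpdateFuncGraphParameter out of the final validate pass and into AbstractSpecializeAction, so dead non-weight parameters are pruned once, right after abstract specialization, instead of during validation at the end of the pipeline.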
@@ -1574,6 +1574,9 @@ void GradExecutor::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, con
+      auto same_param = graph_info->params.at(param_name);
+      manage->Replace(anf_node, same_param);
+      param = same_param;
     } else {
       (void)df_builder->add_parameter(param);
       param->debug_info()->set_name(param_name);
     }
     new_params.push_back(param);
     input_nodes.emplace_back(param);
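The three added lines deduplicate weights while building the grad graph: if a parameter with the same name is already registered, the manager replaces the fresh node with the existing one rather than adding a duplicate. A hedged Python sketch of the idea (hypothetical helper, not the actual C++ API):

# Hypothetical sketch: reuse an already-registered parameter by name rather
# than registering a second copy of it.
def get_or_add_param(params_by_name, name, param):
    if name in params_by_name:
        return params_by_name[name]   # reuse existing node
    params_by_name[name] = param      # first sighting: register it
    return param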
@@ -2917,7 +2920,6 @@ void GradExecutor::GradMsFunctionInner(const std::string &phase, const py::objec
     UpdateMsFunctionForwardTensors(op_exec_info, added_out);
     return;
   }
-
   MS_LOG(DEBUG) << "Ms func graph run firstly. The graph phase is: " << graph_phase();
   if (!need_construct_graph()) {
     MS_LOG(EXCEPTION) << "The flag of need construct graph is False.";
@@ -2935,26 +2937,35 @@
   MakeAdjointForMsFunction(new_ms_func_graph, new_grad_graph, actual_out, args, actual_out_v);
 }
 
-void GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
-  if (!grad_flag_) {
-    MS_LOG(DEBUG) << "Only run forward infer computation, no need to construct grad graph.";
-    return;
-  }
+py::object GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
+  // Get actual forward output object.
+  if (graph_phase().empty()) {
+    MS_LOG(EXCEPTION) << "The graph phase is empty, can not obtain ms_function func graph.";
+  }
+
   // Get ms_function func graph and grad graph.
   const auto &phase = graph_phase();
   MS_LOG(DEBUG) << "ms_function func graph phase: " << phase;
   auto executor = pipeline::GraphExecutorPy::GetInstance();
   MS_EXCEPTION_IF_NULL(executor);
   FuncGraphPtr ms_func_graph = executor->GetFuncGraph(phase);
   MS_EXCEPTION_IF_NULL(ms_func_graph);
+  py::object ret = out;
+  if (ms_func_graph->modify_output()) {
+    auto tuple_out = py::cast<py::tuple>(out);
+    ret = tuple_out[0];
+  }
+
   // Make Adjoint for grad graph of ms_function.
+  if (!grad_flag_) {
+    MS_LOG(DEBUG) << "Only run forward infer computation, no need to construct grad graph.";
+    set_graph_phase("");
+    return ret;
+  }
   FuncGraphPtr grad_graph = executor->GetGradGraph(phase);
   MS_EXCEPTION_IF_NULL(grad_graph);
 
   GradMsFunctionInner(phase, out, args, ms_func_graph, grad_graph);
 
   set_graph_phase("");
+  return ret;
 }
 
 void GradExecutor::ClearGrad(const py::object &cell, const py::args &args) {
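Behaviorally, GradMsFunction now hands the caller the real forward result: when the graph output was extended to carry forward-used values (modify_output), it unwraps the first tuple element, and it clears the graph phase on every exit path. A hedged usage sketch of how this surfaces to users, assuming the MindSpore 1.x ms_function API in PyNative mode:

# Hedged usage sketch; assumes the 1.x-era ms_function / GradOperation API.
import numpy as np
import mindspore.ops as ops
from mindspore import Tensor, context, ms_function

context.set_context(mode=context.PYNATIVE_MODE)

@ms_function
def square(x):
    return x * x

x = Tensor(np.array([2.0], np.float32))
out = square(x)              # forward value, already unwrapped by the executor
dx = ops.GradOperation()(square)(x)  # grad graph recorded for the same phase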
@@ -3083,8 +3094,8 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
   MS_LOG(DEBUG) << "Leave end graph process.";
 }
 
-void PynativeExecutor::GradMsFunction(const py::object &out, const py::args &args) {
-  grad_executor()->GradMsFunction(out, args);
+py::object PynativeExecutor::GradMsFunction(const py::object &out, const py::args &args) {
+  return grad_executor()->GradMsFunction(out, args);
 }
 
 void PynativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
@@ -194,7 +194,7 @@ class GradExecutor {
   // Construct grad graph for ms_function
   bool eliminate_forward() const { return eliminate_forward_; }
   void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; }
-  void GradMsFunction(const py::object &out, const py::args &args);
+  py::object GradMsFunction(const py::object &out, const py::args &args);
   void GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
                            const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph);
   void UpdateMsFunctionForwardTensors(const OpExecInfoPtr &op_exec_info, const ValuePtr &new_forward_value);
@@ -370,11 +370,11 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void set_graph_phase(const std::string &graph_phase);
   void set_py_exe_path(const py::object &py_exe_path);
   void set_kernel_build_server_dir(const py::object &kernel_build_server_dir);
-  void GradMsFunction(const py::object &out, const py::args &args);
   void NewGraph(const py::object &cell, const py::args &args);
   void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
   void GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                const py::args &args);
+  py::object GradMsFunction(const py::object &out, const py::args &args);
   py::object CheckGraph(const py::object &cell, const py::args &args);
   py::object CheckAlreadyRun(const py::object &cell, const py::args &args);
   py::object Run(const py::object &cell, const py::tuple &args);
@@ -209,9 +209,7 @@ class _MindsporeFunctionExecutor:
 
         if context.get_context("mode") == context.PYNATIVE_MODE:
             _pynative_executor.set_graph_phase(phase)
-            _pynative_executor.grad_ms_function(output, *new_inputs)
-            if phase.find("export") == -1:
-                output = output[0]
+            output = _pynative_executor.grad_ms_function(output, *new_inputs)
 
         return output
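The Python caller no longer guesses from the phase name whether the graph output was padded; grad_ms_function now returns the user-visible output, and the unwrap decision lives in C++ behind FuncGraph::modify_output(). A hedged before/after sketch with illustrative signatures only:

# Illustrative only, not the real executor API surface.
def run_old(executor, output, inputs, phase):
    executor.grad_ms_function(output, *inputs)
    if phase.find("export") == -1:   # fragile, name-based guess
        output = output[0]
    return output

def run_new(executor, output, inputs):
    # the executor consults the graph's modify_output flag itself
    return executor.grad_ms_function(output, *inputs)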
@@ -395,7 +393,7 @@ class _PynativeExecutor:
         self._executor.sync()
 
     def grad_ms_function(self, output, *args):
-        self._executor.grad_ms_function(output, *args)
+        return self._executor.grad_ms_function(output, *args)
 
     def set_graph_phase(self, phase):
         self._executor.set_graph_phase(phase)
@@ -405,6 +405,8 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo
   std::string bprop_hash() const { return bprop_hash_; }
   void set_bprop_hash(const std::string &bprop_hash) { bprop_hash_ = bprop_hash; }
 
+  bool modify_output() const { return modify_output_; }
+  void set_modify_output(bool modify_output) { modify_output_ = modify_output; }
   const std::unordered_set<AnfNodePtr> &used_forward_nodes() const { return used_forward_nodes_; }
   void set_used_forward_nodes(const std::vector<AnfNodePtr> &used_forward_nodes);
   void ClearUsedForwardNodes() { used_forward_nodes_.clear(); }
@@ -496,6 +498,7 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo
 
   // If the graph is decorated by @ms_function and runs grad process in pynative mode,
   // forward nodes used in grad graph will be added to output for holding output values.
+  bool modify_output_ = false;
   std::unordered_set<AnfNodePtr> used_forward_nodes_;
 };
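The new modify_output_ flag records that ModifyOutputNode appended forward-used nodes to the graph output; GradExecutor::GradMsFunction reads it to decide whether the user-visible result is the first element of the output tuple. An illustrative one-line mirror of that rule:

# Illustrative mirror of the unwrap rule keyed on modify_output().
def user_visible_output(out, graph_modified):
    return out[0] if graph_modified else out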