forked from mindspore-Ecosystem/mindspore
!46108 Add pynative front dynamic detect function from master
Merge pull request !46108 from wanghenchang/front-dynamic-detect-master1
commit c189903d04
@ -688,6 +688,7 @@ void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_i
|
||||||
graph_compiler_->CalculateForwardOpOutputCount(graph, inputs[graph_index], &forward_op_output_tensor_id_);
|
graph_compiler_->CalculateForwardOpOutputCount(graph, inputs[graph_index], &forward_op_output_tensor_id_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool use_dynamic_shape_process = root_graph_->has_flag(kFlagUseDynamicShapeProcess);
|
||||||
py::gil_scoped_release release;
|
py::gil_scoped_release release;
|
||||||
for (const auto &kernel : graph->execution_order()) {
|
for (const auto &kernel : graph->execution_order()) {
|
||||||
InputTensorInfo input_tensor_info;
|
InputTensorInfo input_tensor_info;
|
||||||
|
@ -714,9 +715,8 @@ void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_i
|
||||||
GraphInfo graph_info;
|
GraphInfo graph_info;
|
||||||
graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
|
graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
|
||||||
&input_tensor_info);
|
&input_tensor_info);
|
||||||
graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info, &op_run_info, &graph_info,
|
graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info, use_dynamic_shape_process,
|
||||||
&graph_output_info);
|
&op_run_info, &graph_info, &graph_output_info);
|
||||||
bool use_dynamic_shape_process = op_run_info->base_op_run_info.use_dynamic_shape_process;
|
|
||||||
if (use_dynamic_shape_process) {
|
if (use_dynamic_shape_process) {
|
||||||
RunOpDynamic(op_run_info, &op_outputs);
|
RunOpDynamic(op_run_info, &op_outputs);
|
||||||
} else {
|
} else {
|
||||||
|
@ -751,7 +751,8 @@ void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const Graph
|
||||||
}
|
}
|
||||||
|
|
||||||
if (contain_cut_graph || root_graph_->has_flag(kFlagIsDynamicStructure) ||
|
if (contain_cut_graph || root_graph_->has_flag(kFlagIsDynamicStructure) ||
|
||||||
(enable_backend_dynamic_detect_ && root_graph_->has_flag(kFlagIsPynativeBpropGraph) && is_dynamic)) {
|
(enable_backend_dynamic_detect_ && root_graph_->has_flag(kFlagIsPynativeBpropGraph) && is_dynamic) ||
|
||||||
|
root_graph_->has_flag(kFlagUseDynamicShapeProcess)) {
|
||||||
RunGraphBySingleOp(graph_compiler_info, args, outputs);
|
RunGraphBySingleOp(graph_compiler_info, args, outputs);
|
||||||
} else {
|
} else {
|
||||||
RunGraphByActors(actor_info, graph_compiler_info, args, outputs);
|
RunGraphByActors(actor_info, graph_compiler_info, args, outputs);
|
||||||
|
|
|
@ -149,7 +149,7 @@ class BACKEND_EXPORT MindRTBackendBase : public Backend {
|
||||||
|
|
||||||
// Save the mapping between cell id and actor info.
|
// Save the mapping between cell id and actor info.
|
||||||
mindspore::HashMap<std::string, ActorInfo> graph_actor_infos_;
|
mindspore::HashMap<std::string, ActorInfo> graph_actor_infos_;
|
||||||
bool enable_backend_dynamic_detect_{true};
|
bool enable_backend_dynamic_detect_{false};
|
||||||
FuncGraphPtr root_graph_;
|
FuncGraphPtr root_graph_;
|
||||||
GraphPartitionPtr graph_partition_;
|
GraphPartitionPtr graph_partition_;
|
||||||
std::shared_ptr<GraphCompiler> graph_compiler_;
|
std::shared_ptr<GraphCompiler> graph_compiler_;
|
||||||
|
|
|
@ -46,11 +46,11 @@ struct GradParam {
|
||||||
: cnode(cnode), op_args(op_args), out(out), fprop_fg(std::move(fprop_fg)) {}
|
: cnode(cnode), op_args(op_args), out(out), fprop_fg(std::move(fprop_fg)) {}
|
||||||
|
|
||||||
// Primal CNode create by op forward process
|
// Primal CNode create by op forward process
|
||||||
const CNodePtr &cnode;
|
const CNodePtr cnode;
|
||||||
// Input value for cnode
|
// Input value for cnode
|
||||||
const ValuePtrList &op_args;
|
const ValuePtrList op_args;
|
||||||
// Output of op
|
// Output of op
|
||||||
const ValuePtr &out;
|
const ValuePtr out;
|
||||||
// Bprop func graph
|
// Bprop func graph
|
||||||
const FuncGraphPtr fprop_fg;
|
const FuncGraphPtr fprop_fg;
|
||||||
// High order used this, which
|
// High order used this, which
|
||||||
|
|
|
@ -907,6 +907,7 @@ constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph";
|
||||||
constexpr auto kFlagPyNativeRunInGraph = "pynative_run_in_graph";
|
constexpr auto kFlagPyNativeRunInGraph = "pynative_run_in_graph";
|
||||||
constexpr auto kFlagNeedRenormalize = "need_renormalize";
|
constexpr auto kFlagNeedRenormalize = "need_renormalize";
|
||||||
constexpr auto kFlagEnableZeroCopyInGraph = "enable_zero_copy_in_graph";
|
constexpr auto kFlagEnableZeroCopyInGraph = "enable_zero_copy_in_graph";
|
||||||
|
constexpr auto kFlagUseDynamicShapeProcess = "use_dynamic_shape_process";
|
||||||
// TODO(dsj): for ms_function running in graph_mode. should be delete later
|
// TODO(dsj): for ms_function running in graph_mode. should be delete later
|
||||||
constexpr auto kAttrMSFunction = "ms_function_graph";
|
constexpr auto kAttrMSFunction = "ms_function_graph";
|
||||||
|
|
||||||
|
|
|
@ -59,6 +59,7 @@ struct FrontendOpRunInfo {
|
||||||
bool grad_flag = false;
|
bool grad_flag = false;
|
||||||
bool output_get_by_infer_value = false;
|
bool output_get_by_infer_value = false;
|
||||||
int mix_type{0};
|
int mix_type{0};
|
||||||
|
size_t op_index = 0;
|
||||||
size_t input_size = 0;
|
size_t input_size = 0;
|
||||||
size_t custom_bprop_cell_count = 0;
|
size_t custom_bprop_cell_count = 0;
|
||||||
PrimitivePyPtr op_prim{nullptr};
|
PrimitivePyPtr op_prim{nullptr};
|
||||||
|
@ -88,6 +89,8 @@ struct InputArgsInfo {
|
||||||
size_t input_size;
|
size_t input_size;
|
||||||
std::string obj_id;
|
std::string obj_id;
|
||||||
bool has_sens{false};
|
bool has_sens{false};
|
||||||
|
bool is_run_cell{false};
|
||||||
|
bool use_dynamic_shape_process = false;
|
||||||
PrimitivePyPtr custom_bprp_prim{nullptr};
|
PrimitivePyPtr custom_bprp_prim{nullptr};
|
||||||
ValuePtr out_value{nullptr};
|
ValuePtr out_value{nullptr};
|
||||||
std::string cell_id;
|
std::string cell_id;
|
||||||
|
|
|
@ -274,6 +274,7 @@ ValuePtr CastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, cons
|
||||||
cast_run_info->base_op_run_info.next_op_name = op_name;
|
cast_run_info->base_op_run_info.next_op_name = op_name;
|
||||||
cast_run_info->base_op_run_info.next_input_index = index;
|
cast_run_info->base_op_run_info.next_input_index = index;
|
||||||
cast_run_info->base_op_run_info.lazy_build = op_run_info->base_op_run_info.lazy_build;
|
cast_run_info->base_op_run_info.lazy_build = op_run_info->base_op_run_info.lazy_build;
|
||||||
|
cast_run_info->base_op_run_info.use_dynamic_shape_process = op_run_info->base_op_run_info.use_dynamic_shape_process;
|
||||||
(void)cast_run_info->input_value.emplace_back(v);
|
(void)cast_run_info->input_value.emplace_back(v);
|
||||||
(void)cast_run_info->input_value.emplace_back(GetDstType(type_id));
|
(void)cast_run_info->input_value.emplace_back(GetDstType(type_id));
|
||||||
cast_run_info->input_size = input_size;
|
cast_run_info->input_size = input_size;
|
||||||
|
|
|
@ -183,10 +183,21 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
|
||||||
if (!op_run_info->output_get_by_infer_value) {
|
if (!op_run_info->output_get_by_infer_value) {
|
||||||
GetOutput(op_run_info);
|
GetOutput(op_run_info);
|
||||||
}
|
}
|
||||||
|
if (!op_run_info->grad_flag) {
|
||||||
|
MS_LOG(DEBUG) << "Grad flag is false";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set forward output flag for release memory,
|
||||||
|
// Because tensor address may change, it should set in main thread to ensure consistency.
|
||||||
|
PyNativeAlgo::Common::SetForwardOutputFlag(op_run_info->out_value);
|
||||||
|
|
||||||
|
// Const value no need do op grad
|
||||||
|
if (op_run_info->output_get_by_infer_value) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
// 4. Do op grad and record op info
|
// 4. Do op grad and record op info
|
||||||
if (enable_async_) {
|
if (!is_ms_function_compiling_) {
|
||||||
grad()->AsyncProcessOpGradInfo(op_run_info);
|
|
||||||
} else {
|
|
||||||
grad()->ProcessOpGradInfo(op_run_info);
|
grad()->ProcessOpGradInfo(op_run_info);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -199,10 +210,13 @@ FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args) co
|
||||||
// Used for async run
|
// Used for async run
|
||||||
op_run_info->grad_flag = grad()->grad_flag();
|
op_run_info->grad_flag = grad()->grad_flag();
|
||||||
op_run_info->custom_bprop_cell_count = grad()->custom_bprop_cell_count();
|
op_run_info->custom_bprop_cell_count = grad()->custom_bprop_cell_count();
|
||||||
|
op_run_info->base_op_run_info.use_dynamic_shape_process =
|
||||||
|
(device_target_ == kAscendDevice ? false : grad()->use_dynamic_shape_process());
|
||||||
op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
|
op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
|
||||||
op_run_info->base_op_run_info.lazy_build = lazy_build_;
|
op_run_info->base_op_run_info.lazy_build = lazy_build_;
|
||||||
PyNativeAlgo::PyParser::SetPrim(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_PRIM)]);
|
PyNativeAlgo::PyParser::SetPrim(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_PRIM)]);
|
||||||
PyNativeAlgo::PyParser::ParseOpInputByPythonObj(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_INPUTS)]);
|
PyNativeAlgo::PyParser::ParseOpInputByPythonObj(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_INPUTS)]);
|
||||||
|
(void)op_run_prim_py_list_.emplace_back(op_run_info->op_prim);
|
||||||
return op_run_info;
|
return op_run_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -412,6 +426,7 @@ void ForwardExecutor::Sync() {
|
||||||
MS_EXCEPTION_IF_NULL(item.second);
|
MS_EXCEPTION_IF_NULL(item.second);
|
||||||
item.second->SyncStream();
|
item.second->SyncStream();
|
||||||
}
|
}
|
||||||
|
op_run_prim_py_list_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr ForwardExecutor::RunOpInMs(const FrontendOpRunInfoPtr &op_run_info) {
|
ValuePtr ForwardExecutor::RunOpInMs(const FrontendOpRunInfoPtr &op_run_info) {
|
||||||
|
@ -466,6 +481,7 @@ void ForwardExecutor::ClearRes() {
|
||||||
infer_operation()->ClearConstFlagPrimCache();
|
infer_operation()->ClearConstFlagPrimCache();
|
||||||
std::stack<CellPtr>().swap(forward_cell_stack_);
|
std::stack<CellPtr>().swap(forward_cell_stack_);
|
||||||
mindrt_backends_.clear();
|
mindrt_backends_.clear();
|
||||||
|
op_run_prim_py_list_.clear();
|
||||||
}
|
}
|
||||||
} // namespace pynative
|
} // namespace pynative
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <stack>
|
#include <stack>
|
||||||
|
#include <vector>
|
||||||
#include "pipeline/pynative/forward/do_cast.h"
|
#include "pipeline/pynative/forward/do_cast.h"
|
||||||
#include "pipeline/pynative/forward/do_infer.h"
|
#include "pipeline/pynative/forward/do_infer.h"
|
||||||
#include "backend/graph_compiler/backend.h"
|
#include "backend/graph_compiler/backend.h"
|
||||||
|
@ -71,6 +72,10 @@ class ForwardExecutor {
|
||||||
MS_EXCEPTION_IF_NULL(infer_operation_);
|
MS_EXCEPTION_IF_NULL(infer_operation_);
|
||||||
return infer_operation_;
|
return infer_operation_;
|
||||||
}
|
}
|
||||||
|
inline void set_is_ms_function_compiling(bool is_ms_function_compiling) {
|
||||||
|
is_ms_function_compiling_ = is_ms_function_compiling;
|
||||||
|
}
|
||||||
|
inline std::string device_target() { return device_target_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GradExecutorPtr grad() const;
|
GradExecutorPtr grad() const;
|
||||||
|
@ -94,6 +99,7 @@ class ForwardExecutor {
|
||||||
private:
|
private:
|
||||||
bool init_{false};
|
bool init_{false};
|
||||||
bool lazy_build_{false};
|
bool lazy_build_{false};
|
||||||
|
bool is_ms_function_compiling_{false};
|
||||||
uint32_t device_id_{0};
|
uint32_t device_id_{0};
|
||||||
std::string last_target_{"Unknown"};
|
std::string last_target_{"Unknown"};
|
||||||
std::string device_target_;
|
std::string device_target_;
|
||||||
|
@ -103,6 +109,7 @@ class ForwardExecutor {
|
||||||
InferOperationPtr infer_operation_;
|
InferOperationPtr infer_operation_;
|
||||||
MindrtBackendMap mindrt_backends_;
|
MindrtBackendMap mindrt_backends_;
|
||||||
bool enable_async_ = false;
|
bool enable_async_ = false;
|
||||||
|
mutable std::vector<PrimitivePyPtr> op_run_prim_py_list_;
|
||||||
};
|
};
|
||||||
} // namespace pynative
|
} // namespace pynative
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -63,8 +63,8 @@ std::string GetCellId(const py::object &obj, const py::args &args, const InputAr
|
||||||
return cell_id;
|
return cell_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args, bool is_grad_top_cell,
|
InputArgsInfoPtr ParsePyArgsToInputArgsInfo(const py::object &obj, const py::args &args, bool is_grad_top_cell,
|
||||||
bool is_high_order_top_cell) {
|
bool is_high_order_top_cell) {
|
||||||
bool has_custom_bprop = py::hasattr(obj, parse::CUSTOM_BPROP_NAME);
|
bool has_custom_bprop = py::hasattr(obj, parse::CUSTOM_BPROP_NAME);
|
||||||
const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
|
const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
|
||||||
const auto &input_args_info =
|
const auto &input_args_info =
|
||||||
|
@ -82,6 +82,7 @@ InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args, b
|
||||||
}
|
}
|
||||||
pipeline::CheckArgsValid(obj, args);
|
pipeline::CheckArgsValid(obj, args);
|
||||||
}
|
}
|
||||||
|
input_args_info->is_run_cell = py::isinstance<Cell>(obj);
|
||||||
input_args_info->cell_id = GetCellId(obj, args, input_args_info);
|
input_args_info->cell_id = GetCellId(obj, args, input_args_info);
|
||||||
MS_LOG(DEBUG) << "cell_id is " << obj_id << ", is grad top cell " << (is_grad_top_cell || is_high_order_top_cell);
|
MS_LOG(DEBUG) << "cell_id is " << obj_id << ", is grad top cell " << (is_grad_top_cell || is_high_order_top_cell);
|
||||||
return input_args_info;
|
return input_args_info;
|
||||||
|
@ -200,10 +201,10 @@ ForwardExecutorPtr GradExecutor::forward() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GradExecutor::GetCurCellOrder() const {
|
std::string GradExecutor::GetCurCellOrder() const {
|
||||||
if (input_args_info_stack_.empty()) {
|
if (cur_cell_id_.empty()) {
|
||||||
MS_LOG(EXCEPTION) << "The input_args_info_stack_ is empty!";
|
MS_LOG(EXCEPTION) << "The cur_cell_id_ is empty!";
|
||||||
}
|
}
|
||||||
return input_args_info_stack_.top()->cell_id + "_" + std::to_string(cell_order_);
|
return cur_cell_id_ + "_" + std::to_string(cell_order_);
|
||||||
}
|
}
|
||||||
|
|
||||||
TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
|
TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
|
||||||
|
@ -300,12 +301,18 @@ void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_i
|
||||||
auto graph_info_cg = std::make_shared<GraphInfo>();
|
auto graph_info_cg = std::make_shared<GraphInfo>();
|
||||||
top_cell()->SetGraphInfoMap(curr_g(), graph_info_cg);
|
top_cell()->SetGraphInfoMap(curr_g(), graph_info_cg);
|
||||||
HandleInputArgsForTopCell(input_args_info, false);
|
HandleInputArgsForTopCell(input_args_info, false);
|
||||||
|
top_cell()->set_need_compile_graph(true);
|
||||||
top_cell()->set_init_kpynative(true);
|
top_cell()->set_init_kpynative(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GradExecutor::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph) const {
|
||||||
|
top_cell()->set_need_compile_graph(need_compile_graph);
|
||||||
|
top_cell()->set_forward_already_run(forward_already_run);
|
||||||
|
}
|
||||||
|
|
||||||
void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) {
|
void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) {
|
||||||
const auto &input_args_info = GetInputArgsInfo(obj, args, input_args_info_stack_.empty(), is_high_order_top_cell());
|
const auto input_args_info = GetInputArgsInfo(obj, args);
|
||||||
PushInputArgsInfoStack(input_args_info);
|
PushInputArgsInfoStack(input_args_info);
|
||||||
|
|
||||||
if (input_args_info->has_custom_bprop) {
|
if (input_args_info->has_custom_bprop) {
|
||||||
|
@ -317,17 +324,21 @@ void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) {
|
||||||
}
|
}
|
||||||
input_args_info->grad_order = grad_order_;
|
input_args_info->grad_order = grad_order_;
|
||||||
// May be can async here
|
// May be can async here
|
||||||
if (enable_async_) {
|
NewGraphImpl(input_args_info);
|
||||||
AsyncNewGraphImpl(input_args_info);
|
}
|
||||||
} else {
|
|
||||||
NewGraphImpl(input_args_info);
|
InputArgsInfoPtr GradExecutor::GetInputArgsInfo(const py::object &obj, const py::args &args) {
|
||||||
}
|
auto input_args_info =
|
||||||
|
ParsePyArgsToInputArgsInfo(obj, args, input_args_info_stack_.empty(), is_high_order_top_cell());
|
||||||
|
input_args_info->use_dynamic_shape_process = use_dynamic_shape_process_;
|
||||||
|
return input_args_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GradExecutor::NewGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
void GradExecutor::NewGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
||||||
MS_EXCEPTION_IF_NULL(input_args_info);
|
MS_EXCEPTION_IF_NULL(input_args_info);
|
||||||
++cell_order_;
|
++cell_order_;
|
||||||
const auto &cell_id = input_args_info->cell_id;
|
const auto &cell_id = input_args_info->cell_id;
|
||||||
|
cur_cell_id_ = cell_id;
|
||||||
MS_LOG(DEBUG) << "NewGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id
|
MS_LOG(DEBUG) << "NewGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id
|
||||||
<< ", input args info ptr " << input_args_info.get();
|
<< ", input args info ptr " << input_args_info.get();
|
||||||
// Make top graph and init resource
|
// Make top graph and init resource
|
||||||
|
@ -357,12 +368,18 @@ void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) {
|
||||||
auto fg = std::make_shared<FuncGraph>();
|
auto fg = std::make_shared<FuncGraph>();
|
||||||
fg->debug_info()->set_name("pynative_forward_graph");
|
fg->debug_info()->set_name("pynative_forward_graph");
|
||||||
auto resource = std::make_shared<pipeline::Resource>();
|
auto resource = std::make_shared<pipeline::Resource>();
|
||||||
const auto &already_run_cell_id = input_args_info->cell_id + std::to_string(input_args_info->grad_order);
|
const auto &already_run_cell_id = GetAlreadyRunCellId(input_args_info->cell_id);
|
||||||
top_cell_ = std::make_shared<TopCellInfo>(input_args_info->grad_order, input_args_info->cell_id, already_run_cell_id,
|
top_cell_ = std::make_shared<TopCellInfo>(input_args_info->grad_order, input_args_info->obj_id,
|
||||||
resource, fg);
|
input_args_info->cell_id, already_run_cell_id, resource, fg);
|
||||||
top_cell_->set_forward_already_run(true);
|
top_cell_->set_forward_already_run(true);
|
||||||
|
top_cell_->set_is_run_cell(input_args_info->is_run_cell);
|
||||||
top_cell_->set_input_args_id(input_args_info->input_args_id);
|
top_cell_->set_input_args_id(input_args_info->input_args_id);
|
||||||
PushHighOrderGraphStack(top_cell_);
|
PushHighOrderGraphStack(top_cell_);
|
||||||
|
(void)top_cell_list_.emplace_back(top_cell_);
|
||||||
|
|
||||||
|
const auto &cell_id = input_args_info->obj_id.append("_").append(std::to_string(grad_order_));
|
||||||
|
is_cell_id_in_dynamic_detect_nodes_map_ =
|
||||||
|
(cell_id_with_dynamic_detect_nodes_.find(cell_id) != cell_id_with_dynamic_detect_nodes_.end());
|
||||||
MS_LOG(DEBUG) << "New top graph, fg ptr " << fg.get() << " resource ptr " << resource.get();
|
MS_LOG(DEBUG) << "New top graph, fg ptr " << fg.get() << " resource ptr " << resource.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -387,7 +404,11 @@ void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &
|
||||||
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
|
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
|
||||||
auto sens_v = ConvertOutputValueToTensor(v);
|
auto sens_v = ConvertOutputValueToTensor(v);
|
||||||
auto cloned_value = ShallowCopyTensorValue(sens_v);
|
auto cloned_value = ShallowCopyTensorValue(sens_v);
|
||||||
auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value);
|
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
|
||||||
|
AsyncUpdateOutputNodeOfTopCell(output_node, cloned_value);
|
||||||
|
} else {
|
||||||
|
auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, const py::args &args) {
|
void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, const py::args &args) {
|
||||||
|
@ -400,16 +421,13 @@ void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, c
|
||||||
GetCustomBpropPrim(obj, args, out, input_args_info);
|
GetCustomBpropPrim(obj, args, out, input_args_info);
|
||||||
}
|
}
|
||||||
input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out);
|
input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out);
|
||||||
|
input_args_info->use_dynamic_shape_process = use_dynamic_shape_process_;
|
||||||
PopInputArgsInfoStack();
|
PopInputArgsInfoStack();
|
||||||
if (input_args_info->is_grad_topest_cell) {
|
if (input_args_info->is_grad_topest_cell) {
|
||||||
set_grad_flag(false);
|
set_grad_flag(false);
|
||||||
}
|
}
|
||||||
// May be can async here
|
// May be can async here
|
||||||
if (enable_async_) {
|
EndGraphImpl(input_args_info);
|
||||||
AsyncEndGraphImpl(input_args_info);
|
|
||||||
} else {
|
|
||||||
EndGraphImpl(input_args_info);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
||||||
|
@ -453,6 +471,7 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
||||||
SetForwardLastNodeInfo(out_value, out_id);
|
SetForwardLastNodeInfo(out_value, out_id);
|
||||||
}
|
}
|
||||||
top_cell()->CheckSubCellHookChanged();
|
top_cell()->CheckSubCellHookChanged();
|
||||||
|
CheckNeedCompileGraph(input_args_info);
|
||||||
top_input_args_info_ = input_args_info;
|
top_input_args_info_ = input_args_info;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -478,7 +497,15 @@ void GradExecutor::DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info,
|
||||||
op_run_info->input_size = input_args_info->input_arg_value_vec.size();
|
op_run_info->input_size = input_args_info->input_arg_value_vec.size();
|
||||||
op_run_info->input_value_id = input_args_info->input_arg_id_vec;
|
op_run_info->input_value_id = input_args_info->input_arg_id_vec;
|
||||||
auto cnode = ConstructForwardGraph(op_run_info);
|
auto cnode = ConstructForwardGraph(op_run_info);
|
||||||
|
|
||||||
|
if (grad_is_running_ && !bprop_grad_stack_.top().second) {
|
||||||
|
MS_LOG(DEBUG) << "Custom bprop, no need do op grad";
|
||||||
|
return;
|
||||||
|
}
|
||||||
DoOpGrad(op_run_info, cnode, input_args_info->out_value);
|
DoOpGrad(op_run_info, cnode, input_args_info->out_value);
|
||||||
|
CheckGraphDynamic(cnode, top_cell()->op_index());
|
||||||
|
top_cell()->IncreaseOpIndex();
|
||||||
|
|
||||||
SaveOutputNodeMap(out_id, op_run_info, cnode);
|
SaveOutputNodeMap(out_id, op_run_info, cnode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -535,6 +562,56 @@ void GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &arg
|
||||||
input_args_info->custom_bprp_prim = fake_prim;
|
input_args_info->custom_bprp_prim = fake_prim;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info) {
|
||||||
|
const auto &new_top_cell = top_cell();
|
||||||
|
const auto &already_top_cell_id = new_top_cell->already_run_cell_id();
|
||||||
|
// Update top cell by current cell op info
|
||||||
|
if (already_run_top_cell_.find(already_top_cell_id) == already_run_top_cell_.end()) {
|
||||||
|
MS_LOG(DEBUG) << "Cell " << already_top_cell_id << " has never been ran, need compile graph";
|
||||||
|
already_run_top_cell_[already_top_cell_id] = new_top_cell;
|
||||||
|
pre_top_cell_ = top_cell();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has been ran";
|
||||||
|
auto pre_top_cell = already_run_top_cell_.at(already_top_cell_id);
|
||||||
|
MS_EXCEPTION_IF_NULL(pre_top_cell);
|
||||||
|
|
||||||
|
if (input_args_info->use_dynamic_shape_process || !input_args_info->is_run_cell) {
|
||||||
|
// Function need compile every time.
|
||||||
|
MS_LOG(DEBUG) << "The graph is dynamic, need to compile graph again";
|
||||||
|
EraseTopCellFromTopCellList(pre_top_cell);
|
||||||
|
{
|
||||||
|
py::gil_scoped_acquire acquire;
|
||||||
|
pre_top_cell->Clear();
|
||||||
|
}
|
||||||
|
already_run_top_cell_[already_top_cell_id] = new_top_cell;
|
||||||
|
pre_top_cell_ = nullptr;
|
||||||
|
} else {
|
||||||
|
MS_LOG(DEBUG) << "no need to compile graph again";
|
||||||
|
pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
|
||||||
|
// In high order situations, the internal top cell remains unchanged, but the external top cell has changed. Then
|
||||||
|
// the graph info of the internal top cell needs to be updated so that the external top cell can perceive it.
|
||||||
|
if (!input_args_info->is_grad_topest_cell) {
|
||||||
|
pre_top_cell->SetGraphInfoMap(pre_top_cell->fg(), new_top_cell->graph_info_map().at(new_top_cell->fg()));
|
||||||
|
}
|
||||||
|
pre_top_cell_ = pre_top_cell;
|
||||||
|
pre_top_cell->set_forward_already_run(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GradExecutor::EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell) {
|
||||||
|
MS_EXCEPTION_IF_NULL(top_cell);
|
||||||
|
auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
|
||||||
|
[&](const TopCellInfoPtr &elem) { return elem.get() == top_cell.get(); });
|
||||||
|
if (iter == top_cell_list_.end()) {
|
||||||
|
MS_LOG(WARNING) << "Can not find top cell " << top_cell.get() << " cell id " << top_cell->cell_id()
|
||||||
|
<< " from top cell list";
|
||||||
|
} else {
|
||||||
|
(void)top_cell_list_.erase(iter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
|
void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
|
||||||
const py::object &grad_position, const py::args &args) {
|
const py::object &grad_position, const py::args &args) {
|
||||||
{
|
{
|
||||||
|
@ -558,7 +635,19 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
|
||||||
(void)top_input_args_info_->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_v));
|
(void)top_input_args_info_->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_v));
|
||||||
top_input_args_info_->has_sens = true;
|
top_input_args_info_->has_sens = true;
|
||||||
}
|
}
|
||||||
|
if (pre_top_cell_ != nullptr) {
|
||||||
|
set_top_cell(pre_top_cell_);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!top_cell()->need_compile_graph()) {
|
||||||
|
MS_LOG(DEBUG) << "No need compile graph";
|
||||||
|
top_cell_list_.pop_back();
|
||||||
|
|
||||||
|
UpdateTopCellInfo(false, false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "Need compile graph";
|
||||||
|
top_cell()->set_grad_operation(grad_operation_);
|
||||||
SetBpropGraphJitLevel(obj);
|
SetBpropGraphJitLevel(obj);
|
||||||
bool weight_param_is_tuple = true;
|
bool weight_param_is_tuple = true;
|
||||||
auto w_args = GetWeightsArgs(weights, &weight_param_is_tuple);
|
auto w_args = GetWeightsArgs(weights, &weight_param_is_tuple);
|
||||||
|
@ -568,12 +657,22 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
|
||||||
GetGradGraph(grad_attr, w_args, p_args);
|
GetGradGraph(grad_attr, w_args, p_args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string GradExecutor::GetAlreadyRunCellId(const std::string &cell_id) const {
|
||||||
|
std::string already_run_cell_id(cell_id);
|
||||||
|
already_run_cell_id += std::to_string(grad_order_ == 0 ? 1 : grad_order_);
|
||||||
|
already_run_cell_id += "_" + grad_operation_;
|
||||||
|
MS_LOG(DEBUG) << "Get already run top cell id " << already_run_cell_id;
|
||||||
|
return already_run_cell_id;
|
||||||
|
}
|
||||||
|
|
||||||
void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector<AnfNodePtr> &w_args,
|
void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector<AnfNodePtr> &w_args,
|
||||||
const std::vector<size_t> &p_args) {
|
const std::vector<size_t> &p_args) {
|
||||||
// Get bprop graph of top cell
|
// Get bprop graph of top cell
|
||||||
auto bprop_graph = GetBpropGraph(grad_attr, w_args, p_args);
|
auto bprop_graph = GetBpropGraph(grad_attr, w_args, p_args);
|
||||||
MS_EXCEPTION_IF_NULL(bprop_graph);
|
MS_EXCEPTION_IF_NULL(bprop_graph);
|
||||||
bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
|
bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
|
||||||
|
bool use_dynamic_shape_process = (forward()->device_target() == kAscendDevice ? false : use_dynamic_shape_process_);
|
||||||
|
bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process);
|
||||||
MS_EXCEPTION_IF_NULL(top_input_args_info_);
|
MS_EXCEPTION_IF_NULL(top_input_args_info_);
|
||||||
bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id));
|
bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id));
|
||||||
auto resource = top_cell()->resource();
|
auto resource = top_cell()->resource();
|
||||||
|
@ -583,14 +682,13 @@ void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
manager->AddFuncGraph(bprop_graph, true);
|
manager->AddFuncGraph(bprop_graph, true);
|
||||||
PyNativeAlgo::Common::DumpGraphIR("launch_bprop_graph.ir", bprop_graph);
|
PyNativeAlgo::Common::DumpGraphIR("launch_bprop_graph.ir", bprop_graph);
|
||||||
if (backends_.find(top_input_args_info_->obj_id) == backends_.end()) {
|
SaveForwardTensorInfoInBpropGraph(resource);
|
||||||
backends_[top_input_args_info_->obj_id] = compile::CreateBackend();
|
resource->SetBackendAsync([]() { return compile::CreateBackend(); });
|
||||||
}
|
|
||||||
resource->SetBackendAsync([&]() { return backends_[top_input_args_info_->obj_id]; });
|
|
||||||
MS_LOG(DEBUG) << "Start task emit action";
|
MS_LOG(DEBUG) << "Start task emit action";
|
||||||
(void)TaskEmitAction(resource);
|
(void)TaskEmitAction(resource);
|
||||||
MS_LOG(DEBUG) << "Start execute action";
|
MS_LOG(DEBUG) << "Start execute action";
|
||||||
(void)ExecuteAction(resource);
|
(void)ExecuteAction(resource);
|
||||||
|
UpdateTopCellInfo(false, false);
|
||||||
resource->Clean();
|
resource->Clean();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -761,10 +859,18 @@ void GradExecutor::SetGradOrder(const std::string &cell_id) {
|
||||||
}
|
}
|
||||||
|
|
||||||
py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj,
|
py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj,
|
||||||
const py::args &args) {
|
const py::object &grad_hash_id, const py::args &args) {
|
||||||
auto cell_id = GetCellId(obj, args, nullptr);
|
auto cell_id = GetCellId(obj, args, nullptr);
|
||||||
|
|
||||||
// Check current cell grad order and erase it if in current top cell list
|
// Check current cell grad order and erase it if in current top cell list
|
||||||
SetGradOrder(cell_id);
|
SetGradOrder(cell_id);
|
||||||
|
// Include weight param size and required grad flag
|
||||||
|
std::string grad_hash_id_str;
|
||||||
|
if (!py::isinstance<py::none>(grad_hash_id)) {
|
||||||
|
grad_hash_id_str = std::string(py::str(grad_hash_id));
|
||||||
|
}
|
||||||
|
grad_operation_ = std::to_string(static_cast<int>(grad->get_all_)) +
|
||||||
|
std::to_string(static_cast<int>(grad->get_by_list_)) + grad_hash_id_str;
|
||||||
|
|
||||||
std::string input_args_id;
|
std::string input_args_id;
|
||||||
for (size_t i = 0; i < args.size(); ++i) {
|
for (size_t i = 0; i < args.size(); ++i) {
|
||||||
|
@ -774,8 +880,9 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con
|
||||||
// check whether need to run forward process
|
// check whether need to run forward process
|
||||||
bool forward_run = false;
|
bool forward_run = false;
|
||||||
if (input_args_info_stack_.empty() && top_cell_ != nullptr) {
|
if (input_args_info_stack_.empty() && top_cell_ != nullptr) {
|
||||||
cell_id += std::to_string(grad_order_ == 0 ? 1 : grad_order_);
|
const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
|
||||||
if (CanGetTopCell(cell_id)) {
|
auto find_top_cell = GetTopCell(check_already_run_cell_id);
|
||||||
|
if (find_top_cell != nullptr) {
|
||||||
MS_LOG(DEBUG) << "Find already run top cell";
|
MS_LOG(DEBUG) << "Find already run top cell";
|
||||||
forward_run = top_cell()->forward_already_run();
|
forward_run = top_cell()->forward_already_run();
|
||||||
bool input_args_changed = !top_cell()->input_args_id().empty() && top_cell()->input_args_id() != input_args_id;
|
bool input_args_changed = !top_cell()->input_args_id().empty() && top_cell()->input_args_id() != input_args_id;
|
||||||
|
@ -948,7 +1055,12 @@ void GradExecutor::ClearGradRes() {
|
||||||
if (top_cell_ != nullptr) {
|
if (top_cell_ != nullptr) {
|
||||||
top_cell_->ClearDeviceMemory();
|
top_cell_->ClearDeviceMemory();
|
||||||
}
|
}
|
||||||
top_cell_ = nullptr;
|
|
||||||
|
if (use_dynamic_shape_process_ ||
|
||||||
|
already_run_top_cell_.find(top_cell_->already_run_cell_id()) != already_run_top_cell_.end()) {
|
||||||
|
top_cell_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
DecreaseGradOrder();
|
DecreaseGradOrder();
|
||||||
ClearGlobalRes();
|
ClearGlobalRes();
|
||||||
}
|
}
|
||||||
|
@ -959,13 +1071,21 @@ void GradExecutor::ClearRes() {
|
||||||
grad_is_running_ = false;
|
grad_is_running_ = false;
|
||||||
need_renormalize_ = false;
|
need_renormalize_ = false;
|
||||||
eliminate_forward_ = true;
|
eliminate_forward_ = true;
|
||||||
|
use_dynamic_shape_process_ = false;
|
||||||
|
is_cell_id_in_dynamic_detect_nodes_map_ = false;
|
||||||
custom_bprop_cell_count_ = 0;
|
custom_bprop_cell_count_ = 0;
|
||||||
grad_order_ = 0;
|
grad_order_ = 0;
|
||||||
top_cell_ = nullptr;
|
top_cell_ = nullptr;
|
||||||
top_input_args_info_ = nullptr;
|
top_input_args_info_ = nullptr;
|
||||||
bprop_cell_list_.clear();
|
bprop_cell_list_.clear();
|
||||||
backends_.clear();
|
|
||||||
async_executor_->Reset();
|
async_executor_->Reset();
|
||||||
|
for (const auto &cell_ptr : top_cell_list_) {
|
||||||
|
MS_EXCEPTION_IF_NULL(cell_ptr);
|
||||||
|
cell_ptr->Clear();
|
||||||
|
}
|
||||||
|
top_cell_list_.clear();
|
||||||
|
already_run_top_cell_.clear();
|
||||||
|
cell_id_with_dynamic_detect_nodes_.clear();
|
||||||
std::stack<InputArgsInfoPtr>().swap(input_args_info_stack_);
|
std::stack<InputArgsInfoPtr>().swap(input_args_info_stack_);
|
||||||
std::stack<std::pair<std::string, bool>>().swap(bprop_grad_stack_);
|
std::stack<std::pair<std::string, bool>>().swap(bprop_grad_stack_);
|
||||||
std::stack<TopCellInfoPtr>().swap(high_order_stack_);
|
std::stack<TopCellInfoPtr>().swap(high_order_stack_);
|
||||||
|
@ -1072,6 +1192,8 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str
|
||||||
auto cnode = curr_g()->NewCNode(inputs);
|
auto cnode = curr_g()->NewCNode(inputs);
|
||||||
MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
|
MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
|
||||||
top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode, -1, false);
|
top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode, -1, false);
|
||||||
|
CheckGraphDynamic(cnode, top_cell()->op_index());
|
||||||
|
top_cell()->IncreaseOpIndex();
|
||||||
return cnode;
|
return cnode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1099,10 +1221,42 @@ AnfNodePtr GradExecutor::CreateTupleGetItemNode(const std::string &obj_id,
|
||||||
c_node->set_abstract(prim_abs);
|
c_node->set_abstract(prim_abs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
CheckGraphDynamic(c_node, top_cell()->op_index());
|
||||||
|
top_cell()->IncreaseOpIndex();
|
||||||
MS_LOG(DEBUG) << "Get input node " << c_node->ToString() << ", id " << obj_id;
|
MS_LOG(DEBUG) << "Get input node " << c_node->ToString() << ", id " << obj_id;
|
||||||
return c_node;
|
return c_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TopCellInfoPtr GradExecutor::GetTopCell(const std::string &already_run_cell_id) {
|
||||||
|
TopCellInfoPtr find_top_cell = nullptr;
|
||||||
|
for (const auto &top_cell : top_cell_list_) {
|
||||||
|
MS_EXCEPTION_IF_NULL(top_cell);
|
||||||
|
// Complete match, means run grad operation first
|
||||||
|
if (top_cell->already_run_cell_id() == already_run_cell_id) {
|
||||||
|
return top_cell;
|
||||||
|
}
|
||||||
|
// Partial match, means run forward first
|
||||||
|
if (already_run_cell_id.find(top_cell->already_run_cell_id()) != std::string::npos &&
|
||||||
|
top_cell->already_run_cell_id().back() == '_') {
|
||||||
|
find_top_cell = top_cell;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Same topcell info, but grad operation is not the same, construct backward graph again
|
||||||
|
if (find_top_cell != nullptr) {
|
||||||
|
if (!find_top_cell->grad_operation().empty() && find_top_cell->grad_operation() != grad_operation_) {
|
||||||
|
MS_LOG(DEBUG) << "Already exist grad operation " << find_top_cell->grad_operation() << " is different with new "
|
||||||
|
<< grad_operation_;
|
||||||
|
EraseTopCellFromTopCellList(find_top_cell);
|
||||||
|
(void)already_run_top_cell_.erase(find_top_cell->already_run_cell_id());
|
||||||
|
return nullptr;
|
||||||
|
} else {
|
||||||
|
return find_top_cell;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
void GradExecutor::SetHookChanged(const py::object &cell) const {
|
void GradExecutor::SetHookChanged(const py::object &cell) const {
|
||||||
if (top_cell_ == nullptr) {
|
if (top_cell_ == nullptr) {
|
||||||
return;
|
return;
|
||||||
|
@ -1118,24 +1272,19 @@ void GradExecutor::SetHookChanged(const py::object &cell) const {
|
||||||
|
|
||||||
void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const {
|
void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const {
|
||||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||||
if (!op_run_info->grad_flag) {
|
|
||||||
MS_LOG(DEBUG) << "Grad flag is false";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// Set forward output flag for release memory
|
|
||||||
PyNativeAlgo::Common::SetForwardOutputFlag(op_run_info->out_value);
|
|
||||||
|
|
||||||
// Const value no need do op grad
|
|
||||||
if (op_run_info->output_get_by_infer_value) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// Do op grad and save node info. If cell have custom bprop, no need do op grad. Otherwise, need do.
|
// Do op grad and save node info. If cell have custom bprop, no need do op grad. Otherwise, need do.
|
||||||
if (op_run_info->custom_bprop_cell_count <= 0) {
|
if (op_run_info->custom_bprop_cell_count <= 0) {
|
||||||
const auto &cnode = ConstructForwardGraph(op_run_info);
|
const auto &cnode = ConstructForwardGraph(op_run_info);
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
cnode->set_abstract(op_run_info->base_op_run_info.abstract);
|
cnode->set_abstract(op_run_info->base_op_run_info.abstract);
|
||||||
SaveOutputNodeMap(op_run_info->out_value_id, op_run_info, cnode);
|
SaveOutputNodeMap(op_run_info->out_value_id, op_run_info, cnode);
|
||||||
|
if (grad_is_running_ && !bprop_grad_stack_.top().second) {
|
||||||
|
MS_LOG(DEBUG) << "Custom bprop, no need do op grad";
|
||||||
|
return;
|
||||||
|
}
|
||||||
DoOpGrad(op_run_info, cnode, op_run_info->out_value);
|
DoOpGrad(op_run_info, cnode, op_run_info->out_value);
|
||||||
|
CheckGraphDynamic(cnode, top_cell()->op_index());
|
||||||
|
UpdateForwardTensorInfoInBpropGraph(op_run_info);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1163,10 +1312,6 @@ void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const FrontendOp
|
||||||
// Run ad grad for curr op and connect grad graph with previous op
|
// Run ad grad for curr op and connect grad graph with previous op
|
||||||
void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode,
|
void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode,
|
||||||
const ValuePtr &op_out) const {
|
const ValuePtr &op_out) const {
|
||||||
if (grad_is_running_ && !bprop_grad_stack_.top().second) {
|
|
||||||
MS_LOG(DEBUG) << "Custom bprop, no need do op grad";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||||
|
|
||||||
// to avoid out exist in tape bprop, avoid out be modified.
|
// to avoid out exist in tape bprop, avoid out be modified.
|
||||||
|
@ -1175,14 +1320,196 @@ void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNode
|
||||||
std::back_inserter(cloned_op_args),
|
std::back_inserter(cloned_op_args),
|
||||||
[](const ValuePtr &value) { return ShallowCopyTensorValue(value); });
|
[](const ValuePtr &value) { return ShallowCopyTensorValue(value); });
|
||||||
ValuePtr cloned_out = ShallowCopyTensorValue(op_out);
|
ValuePtr cloned_out = ShallowCopyTensorValue(op_out);
|
||||||
std::vector<tensor::TensorPtr> tensors;
|
|
||||||
TensorValueToTensor(cloned_out, &tensors);
|
|
||||||
for (auto tensor : tensors) {
|
|
||||||
tensor->set_is_forward_output(true);
|
|
||||||
}
|
|
||||||
if (!ad::GradPynativeOp(top_cell()->auto_grad_cell_ptr(), cnode, cloned_op_args, cloned_out)) {
|
if (!ad::GradPynativeOp(top_cell()->auto_grad_cell_ptr(), cnode, cloned_op_args, cloned_out)) {
|
||||||
MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_run_info->base_op_run_info.op_name;
|
MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_run_info->base_op_run_info.op_name;
|
||||||
}
|
}
|
||||||
|
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
|
||||||
|
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
|
||||||
|
AsyncGradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out);
|
||||||
|
} else {
|
||||||
|
GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GradExecutor::GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode,
|
||||||
|
const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const {
|
||||||
|
if (!ad::GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failed to run ad grad for op ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GradExecutor::AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode,
|
||||||
|
const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const {
|
||||||
|
const auto fn = [this, auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out]() {
|
||||||
|
this->GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out);
|
||||||
|
};
|
||||||
|
auto task = std::make_shared<BpropTask>(fn);
|
||||||
|
async_executor_->Push(task);
|
||||||
|
}
|
||||||
|
|
||||||
|
void GradExecutor::AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &cloned_value) const {
|
||||||
|
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
|
||||||
|
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
|
||||||
|
const auto fn = [auto_grad_cell_ptr, output_node, cloned_value]() {
|
||||||
|
auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value);
|
||||||
|
};
|
||||||
|
auto task = std::make_shared<BpropTask>(fn);
|
||||||
|
async_executor_->Push(task);
|
||||||
|
}
|
||||||
|
|
||||||
|
void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||||
|
if (op_run_info->base_op_run_info.use_dynamic_shape_process) {
|
||||||
|
MS_LOG(DEBUG) << "Get dynamic shape process";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
top_cell()->GetOpInfo(op_run_info);
|
||||||
|
MS_LOG(DEBUG) << "Current op info: " << op_run_info->op_info;
|
||||||
|
|
||||||
|
std::vector<tensor::TensorPtr> op_output_tensors;
|
||||||
|
// Get output tensors
|
||||||
|
TensorValueToTensor(op_run_info->out_value, &op_output_tensors);
|
||||||
|
// Save all tensors info of current op
|
||||||
|
top_cell()->set_opinfo_with_tensor_id(op_run_info->op_info, op_output_tensors);
|
||||||
|
|
||||||
|
// First run top cell
|
||||||
|
if (already_run_top_cell_.find(top_cell_->already_run_cell_id()) == already_run_top_cell_.end()) {
|
||||||
|
MS_LOG(DEBUG) << "Top cell " << top_cell_->cell_id() << " run firstly";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Non-first run
|
||||||
|
const auto &pre_top_cell = already_run_top_cell_.at(top_cell_->already_run_cell_id());
|
||||||
|
MS_EXCEPTION_IF_NULL(pre_top_cell);
|
||||||
|
if (pre_top_cell->op_info_with_tensor_id().find(op_run_info->op_info) ==
|
||||||
|
pre_top_cell->op_info_with_tensor_id().end()) {
|
||||||
|
MS_LOG(DEBUG) << "Can not find op info " << op_run_info->op_info << " in op info with tensor id map. Top cell "
|
||||||
|
<< top_cell_->cell_id();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update new output tensor info in bprop graph
|
||||||
|
const auto &pre_op_tensor_id = pre_top_cell->op_info_with_tensor_id().at(op_run_info->op_info);
|
||||||
|
if (pre_op_tensor_id.size() != op_output_tensors.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The size of op pre output tensor size: " << pre_op_tensor_id.size()
|
||||||
|
<< " is not equal to current " << op_output_tensors.size();
|
||||||
|
}
|
||||||
|
// For value node tensor in the bprop graph, take its id for tensor, and save in tensor_id_with_tensor_object;
|
||||||
|
// And then take the output of the op and find out if the output used by tensor_id_with_tensor_object,
|
||||||
|
// if there is a tensor need to replace it.
|
||||||
|
const auto &pre_tensor_id_with_tensor_object = pre_top_cell->tensor_id_with_tensor_object();
|
||||||
|
for (size_t i = 0; i < pre_op_tensor_id.size(); ++i) {
|
||||||
|
auto pre_id = pre_op_tensor_id[i];
|
||||||
|
if (pre_tensor_id_with_tensor_object.find(pre_id) == pre_tensor_id_with_tensor_object.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Based on the output size of the op is fixed, so can use index.
|
||||||
|
const auto &new_tensor = op_output_tensors[i];
|
||||||
|
const auto &pre_tensor_object = pre_tensor_id_with_tensor_object.at(pre_id);
|
||||||
|
UpdatePreTensorInfo(new_tensor, pre_tensor_object);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GradExecutor::UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor,
|
||||||
|
const std::vector<tensor::TensorPtr> &pre_tensors) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(new_tensor);
|
||||||
|
if (pre_tensors.empty() || new_tensor->device_address() == nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "The number of pre tensors is zero or the device address of new tensor is nullptr.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const auto &device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||||
|
for (auto &pre_tensor : pre_tensors) {
|
||||||
|
MS_EXCEPTION_IF_NULL(pre_tensor);
|
||||||
|
MS_LOG(DEBUG) << "Replace Old tensor id " << pre_tensor->id() << " device_address: " << pre_tensor->device_address()
|
||||||
|
<< " shape and type " << pre_tensor->GetShapeAndDataTypeInfo() << " with New tensor id "
|
||||||
|
<< new_tensor->id() << " device_address " << new_tensor->device_address() << " shape and dtype "
|
||||||
|
<< new_tensor->GetShapeAndDataTypeInfo();
|
||||||
|
(void)pre_tensor->set_shape(new_tensor->shape());
|
||||||
|
(void)pre_tensor->set_data_type(new_tensor->data_type());
|
||||||
|
auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
|
||||||
|
MS_EXCEPTION_IF_NULL(device_address);
|
||||||
|
if (device_target != kCPUDevice && device_address->GetDeviceType() != device::DeviceType::kCPU) {
|
||||||
|
pre_tensor->set_device_address(new_tensor->device_address());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (const auto &item : PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->mindrt_backend()) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item.second);
|
||||||
|
item.second->WaitTaskFinish();
|
||||||
|
}
|
||||||
|
// Replace data in device address when run in CPU device.
|
||||||
|
if (pre_tensor->device_address() != nullptr) {
|
||||||
|
// If tensor is dynamic shape, Just replace device address.
|
||||||
|
if (PyNativeAlgo::Common::ValueHasDynamicShape(pre_tensor)) {
|
||||||
|
pre_tensor->set_device_address(new_tensor->device_address());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(pre_tensor->device_address());
|
||||||
|
MS_EXCEPTION_IF_NULL(old_device_address);
|
||||||
|
auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
|
||||||
|
MS_EXCEPTION_IF_NULL(new_device_address);
|
||||||
|
|
||||||
|
// CPU host tensor data_c is different from device address if the address is from mem_pool.
|
||||||
|
if (new_device_address->from_mem_pool()) {
|
||||||
|
pre_tensor->set_device_address(new_device_address);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto old_ptr = old_device_address->GetMutablePtr();
|
||||||
|
MS_EXCEPTION_IF_NULL(old_ptr);
|
||||||
|
auto new_ptr = new_device_address->GetPtr();
|
||||||
|
MS_EXCEPTION_IF_NULL(new_ptr);
|
||||||
|
MS_EXCEPTION_IF_CHECK_FAIL(old_device_address->GetSize() == new_device_address->GetSize(), "Size not equal");
|
||||||
|
if (old_device_address->GetSize() < SECUREC_MEM_MAX_LEN) {
|
||||||
|
auto ret_code = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize());
|
||||||
|
MS_EXCEPTION_IF_CHECK_FAIL(ret_code == EOK, "Memory copy failed, ret code: " + std::to_string(ret_code));
|
||||||
|
} else {
|
||||||
|
auto ret_code = std::memcpy(old_ptr, new_ptr, old_device_address->GetSize());
|
||||||
|
MS_EXCEPTION_IF_CHECK_FAIL(ret_code == old_ptr, "Memory copy failed");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
pre_tensor->set_device_address(device_address);
|
||||||
|
pre_tensor->data_sync();
|
||||||
|
pre_tensor->set_device_address(nullptr);
|
||||||
|
pre_tensor->set_sync_status(kNeedSyncHostToDevice);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const {
|
||||||
|
if (use_dynamic_shape_process_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Get all tensors id of forward op
|
||||||
|
mindspore::HashSet<std::string> forward_op_tensor_id;
|
||||||
|
const auto &op_info_with_tensor_id = top_cell()->op_info_with_tensor_id();
|
||||||
|
for (const auto &record : op_info_with_tensor_id) {
|
||||||
|
(void)std::for_each(
|
||||||
|
record.second.begin(), record.second.end(),
|
||||||
|
[&forward_op_tensor_id](const std::string &tensor_id) { (void)forward_op_tensor_id.emplace(tensor_id); });
|
||||||
|
}
|
||||||
|
// Get all tensors obj in value node of bprop graph
|
||||||
|
MS_EXCEPTION_IF_NULL(resource);
|
||||||
|
const auto &bprop_graph = resource->func_graph();
|
||||||
|
MS_EXCEPTION_IF_NULL(bprop_graph);
|
||||||
|
const auto &value_node_list = bprop_graph->value_nodes();
|
||||||
|
std::vector<tensor::TensorPtr> tensors_in_bprop_graph;
|
||||||
|
for (const auto &elem : value_node_list) {
|
||||||
|
auto value_node = elem.first->cast<ValueNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
|
TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save tensor in value node of bprop graph
|
||||||
|
for (const auto &tensor : tensors_in_bprop_graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
|
if (forward_op_tensor_id.find(tensor->id()) == forward_op_tensor_id.end() || tensor->device_address() == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
tensor->set_is_forward_output(true);
|
||||||
|
top_cell()->set_tensor_id_with_tensor_object(tensor->id(), tensor);
|
||||||
|
MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
|
||||||
|
<< " device address: " << tensor->device_address() << " shape and dtype "
|
||||||
|
<< tensor->GetShapeAndDataTypeInfo();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr GradExecutor::GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const {
|
AnfNodePtr GradExecutor::GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const {
|
||||||
|
@ -1241,6 +1568,7 @@ CNodePtr GradExecutor::ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_
|
||||||
if (IsPrimitiveCNode(cnode, prim::kPrimCellBackwardHook)) {
|
if (IsPrimitiveCNode(cnode, prim::kPrimCellBackwardHook)) {
|
||||||
top_cell()->RecordCellBackwardHookOp(GetCurCellOrder(), cnode);
|
top_cell()->RecordCellBackwardHookOp(GetCurCellOrder(), cnode);
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "Make CNode for " << op_run_info->base_op_run_info.op_name << ", new cnode is "
|
MS_LOG(DEBUG) << "Make CNode for " << op_run_info->base_op_run_info.op_name << ", new cnode is "
|
||||||
<< cnode->DebugString();
|
<< cnode->DebugString();
|
||||||
return cnode;
|
return cnode;
|
||||||
|
@ -1260,5 +1588,235 @@ void GradExecutor::SetBpropGraphJitLevel(const py::object &obj) const {
|
||||||
MS_EXCEPTION_IF_NULL(graph_executor);
|
MS_EXCEPTION_IF_NULL(graph_executor);
|
||||||
graph_executor->SetJitConfig(jit_config_dict);
|
graph_executor->SetJitConfig(jit_config_dict);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t &node_idx,
                                                        bool is_ms_function_node,
                                                        const std::string &graph_phase) const {
  MS_EXCEPTION_IF_NULL(cnode);
  auto node_info = std::make_shared<DynamicDetectNodeInfo>();
  if (!is_ms_function_node) {
    node_info->prim = GetCNodePrimitive(cnode);
    for (size_t i = 1; i < cnode->inputs().size(); i++) {
      const auto &input_node = cnode->input(i);
      MS_EXCEPTION_IF_NULL(input_node);

      if (input_node->isa<ValueNode>()) {
        node_info->input_values[i] = GetValueNode(input_node);
      } else if (input_node->isa<CNode>()) {
        const auto &node_abs = input_node->abstract();
        auto op_index = top_cell()->get_op_index_by_cnode_hash(input_node->hash());
        node_info->input_cnode_info[i] = std::make_pair(op_index, node_abs);
      } else {
        if (!input_node->isa<Parameter>()) {
          MS_LOG(EXCEPTION) << "input_node:" << input_node->fullname_with_scope()
                            << " is none of value node, cnode and parameter.";
        }
        const auto &param = input_node->cast<ParameterPtr>();
        MS_EXCEPTION_IF_NULL(param);
        node_info->input_param_infos[i] = param->param_info();
      }
    }
    node_info->output_abs = cnode->abstract();
  } else {
    node_info->is_graph_node = true;
    node_info->graph_phase = graph_phase;
  }
  top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx);
  const auto &cell_id = top_cell()->c_cell_id() + "_" + std::to_string(top_cell()->grad_order());
  (void)cell_id_with_dynamic_detect_nodes_[cell_id].emplace_back(node_info);
}

bool IsAbsDifferent(const AbstractBasePtr &old_abs, const AbstractBasePtr &new_abs) {
  if (old_abs == new_abs) {
    return false;
  }

  if (old_abs == nullptr || new_abs == nullptr) {
    MS_LOG(DEBUG) << "graph is dynamic, old_abs is different from new_abs";
    return true;
  }

  if (!common::IsEqual(old_abs->BuildType(), new_abs->BuildType()) ||
      !common::IsEqual(old_abs->BuildShape(), new_abs->BuildShape())) {
    MS_LOG(DEBUG) << "graph is dynamic, old_abs is different from new_abs, old abs:" << old_abs->ToString()
                  << " new abs:" << new_abs->ToString();
    return true;
  }
  return false;
}

bool IsValuePtrEqual(const ValuePtr &v1, const ValuePtr &v2) {
  if (v1 == v2) {
    return true;
  }
  if (v1 == nullptr || v2 == nullptr) {
    return false;
  }

  if (v1->isa<tensor::Tensor>() && v2->isa<tensor::Tensor>()) {
    return v1->cast<tensor::TensorPtr>()->ValueEqual(*(v2->cast<tensor::TensorPtr>()));
  }
  return *v1 == *v2;
}

bool IsParamInfoEqual(const ParamInfoPtr &p1, const ParamInfoPtr &p2) {
  if (p1 == p2) {
    return true;
  }
  if (p1 == nullptr || p2 == nullptr) {
    return false;
  }

  return p1->key() == p2->key();
}

bool GradExecutor::IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info,
                                        const std::vector<AnfNodePtr> &new_anf_inputs) const {
  MS_EXCEPTION_IF_NULL(old_node_info);

  auto old_input_size = old_node_info->input_cnode_info.size() + old_node_info->input_values.size() +
                        old_node_info->input_param_infos.size();
  if (old_input_size != new_anf_inputs.size() - 1) {
    MS_LOG(DEBUG) << "graph is dynamic, old input size:" << old_input_size
                  << " new input_infos:" << (new_anf_inputs.size() - 1);
    return true;
  }

  for (size_t i = 1; i < new_anf_inputs.size(); i++) {
    const auto &new_anf_input = new_anf_inputs[i];
    MS_EXCEPTION_IF_NULL(new_anf_input);
    if (new_anf_input->isa<ValueNode>()) {
      const auto &value_iter = old_node_info->input_values.find(i);
      if (value_iter == old_node_info->input_values.end()) {
        MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a value, old input is not a value.";
        return true;
      }

      if (!IsValuePtrEqual(value_iter->second, GetValueNode(new_anf_input))) {
        MS_LOG(DEBUG) << "The " << i << "th input, value is different.";
        return true;
      }
    } else if (new_anf_input->isa<CNode>()) {
      // Compare cnode abstract.
      const auto &node_iter = old_node_info->input_cnode_info.find(i);
      if (node_iter == old_node_info->input_cnode_info.end()) {
        MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a cnode, old input is not a cnode.";
        return true;
      }

      size_t old_op_index = 0;
      AbstractBasePtr old_abs = nullptr;
      std::tie(old_op_index, old_abs) = node_iter->second;
      if (IsAbsDifferent(old_abs, new_anf_input->abstract())) {
        MS_LOG(DEBUG) << "The " << i << "th input, abs is different.";
        return true;
      }

      // Compare cnode edge.
      if (old_op_index != top_cell()->get_op_index_by_cnode_hash(new_anf_input->hash())) {
        MS_LOG(DEBUG) << "The " << i << "th input, op_index is different, old op_index:" << old_op_index
                      << " new op_index:" << top_cell()->get_op_index_by_cnode_hash(new_anf_input->hash());
        return true;
      }
    } else {
      // Compare parameter.
      if (!new_anf_input->isa<Parameter>()) {
        MS_LOG(EXCEPTION) << "new_anf_input:" << new_anf_input->fullname_with_scope()
                          << " is none of value node, cnode and parameter.";
      }

      const auto &node_iter = old_node_info->input_param_infos.find(i);
      if (node_iter == old_node_info->input_param_infos.end()) {
        MS_LOG(DEBUG) << "The " << i
                      << "th input is different, cur input is a parameter, old input is not a parameter.";
        return true;
      }

      const auto &param = new_anf_input->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param);
      if (!IsParamInfoEqual(node_iter->second, param->param_info())) {
        MS_LOG(DEBUG) << "The " << i << "th input, param info is different.";
        return true;
      }
    }
  }

  return false;
}

bool GradExecutor::IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info,
                                                 const CNodePtr &new_cnode, bool is_ms_function_node,
                                                 const std::string &graph_phase) const {
  MS_EXCEPTION_IF_NULL(new_cnode);
  MS_EXCEPTION_IF_NULL(old_node_info);

  // 1.Detect ms_function phase
  if (is_ms_function_node != old_node_info->is_graph_node ||
      (is_ms_function_node && graph_phase != old_node_info->graph_phase)) {
    MS_LOG(DEBUG) << "graph is dynamic, old is_graph_node:" << old_node_info->is_graph_node
                  << " new is_graph_node:" << is_ms_function_node << " old graph_phase:" << old_node_info->graph_phase
                  << " new graph_phase:" << graph_phase;
    return true;
  }

  // 2.Detect cnode prim
  auto new_prim = GetCNodePrimitive(new_cnode);
  if (!common::IsEqual(new_prim, old_node_info->prim)) {
    MS_LOG(DEBUG) << "graph is dynamic, old prim:" << (old_node_info->prim == nullptr ? 0 : old_node_info->prim->name())
                  << " new prim:" << (new_prim == nullptr ? 0 : new_prim->name());
    return true;
  }

  // 3.Detect output abs
  if (IsAbsDifferent(old_node_info->output_abs, new_cnode->abstract())) {
    MS_LOG(DEBUG) << "graph is dynamic, output_abs is different";
    return true;
  }

  // 4.Detect inputs
  return IsCnodeInputsDynamic(old_node_info, new_cnode->inputs());
}

bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
                                  const std::string &graph_phase) const {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!is_cell_id_in_dynamic_detect_nodes_map_) {
    SaveDynamicDetectNodeInfoInFirstTime(cnode, node_idx, is_ms_function_node, graph_phase);
    // The net is regarded as a static net by default in the first time.
    return false;
  }

  const auto &cell_id = top_cell()->c_cell_id() + "_" + std::to_string(top_cell()->grad_order());
  const auto &dynamic_nodes = cell_id_with_dynamic_detect_nodes_[cell_id];
  if (node_idx >= dynamic_nodes.size()) {
    MS_LOG(DEBUG) << "old dynamic_nodes size:" << dynamic_nodes.size() << " cur node_idx is:" << node_idx
                  << ", graph is dynamic.";
    return true;
  }

  if (IsDynamicDetectNodeInfoChange(dynamic_nodes[node_idx], cnode, is_ms_function_node, graph_phase)) {
    MS_LOG(DEBUG) << "graph is dynamic, node_idx:" << node_idx
                  << " is different, cnode:" << cnode->fullname_with_scope();
    return true;
  }
  top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx);

  return false;
}

void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
                                     const std::string &graph_phase) const {
  if (!top_cell()->is_run_cell() || use_dynamic_shape_process_) {
    return;
  }

  use_dynamic_shape_process_ = IsGraphDynamic(cnode, node_idx, is_ms_function_node, graph_phase);
  if (use_dynamic_shape_process_) {
    MS_LOG(DEBUG) << "cnode:" << cnode->fullname_with_scope() << ",node_idx:" << node_idx
                  << ",is_ms_function_node:" << is_ms_function_node << ",graph_phase:" << graph_phase
                  << ",use_dynamic_shape_process_:" << use_dynamic_shape_process_;
    cell_id_with_dynamic_detect_nodes_.clear();
  }
}
} // namespace pynative
} // namespace mindspore

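Taken together, the functions above implement a record-then-compare scheme: the first run of a top cell records, per op index, the primitive, the output abstract, and the kind and origin of every input; later runs compare each op against the recorded entry and switch the whole cell to dynamic (single-op) execution on the first mismatch. A minimal standalone sketch of that pattern follows; all names and types in it are hypothetical and it is not the MindSpore implementation above.

// Sketch only: simplified record-then-compare structure detection.
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

struct NodeSignature {
  std::string prim_name;                      // operator name
  std::string output_abs;                     // serialized output type/shape
  std::vector<std::size_t> input_op_indexes;  // which earlier ops feed this one
};

class DynamicStructureDetector {
 public:
  // Returns true once the op stream diverges from what the first run recorded.
  bool Check(const std::string &cell_id, std::size_t op_idx, const NodeSignature &sig) {
    auto &recorded = recorded_ops_[cell_id];
    if (first_run_) {
      recorded.push_back(sig);
      return false;  // the first run is treated as static by definition
    }
    if (op_idx >= recorded.size()) {
      return true;   // more ops than the first run produced: structure changed
    }
    const NodeSignature &old_sig = recorded[op_idx];
    return old_sig.prim_name != sig.prim_name || old_sig.output_abs != sig.output_abs ||
           old_sig.input_op_indexes != sig.input_op_indexes;
  }
  void MarkFirstRunDone() { first_run_ = false; }

 private:
  bool first_run_{true};
  std::unordered_map<std::string, std::vector<NodeSignature>> recorded_ops_;
};
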
@ -36,6 +36,17 @@ class ForwardExecutor;
using ForwardExecutorPtr = std::shared_ptr<ForwardExecutor>;
using ForwardExecutorWeakPtr = std::weak_ptr<ForwardExecutor>;

struct DynamicDetectNodeInfo {
  PrimitivePtr prim{nullptr};
  AbstractBasePtr output_abs{nullptr};
  bool is_graph_node{false};
  std::string graph_phase;
  mindspore::HashMap<size_t, std::pair<size_t, AbstractBasePtr>> input_cnode_info;
  mindspore::HashMap<size_t, ValuePtr> input_values;
  mindspore::HashMap<size_t, ParamInfoPtr> input_param_infos;
};
using DynamicDetectNodeInfoPtr = std::shared_ptr<DynamicDetectNodeInfo>;

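As filled in by SaveDynamicDetectNodeInfoInFirstTime above, prim and output_abs capture a node's operator and output abstract, is_graph_node and graph_phase mark ms_function graph nodes, and the three input_* maps record, keyed by input position, whether each input was a constant value, the output of an earlier cnode (its op index plus abstract), or a parameter (its ParamInfo).
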
class GradExecutor {
 public:
  GradExecutor() = default;

@ -43,8 +54,7 @@ class GradExecutor {
  explicit GradExecutor(const ForwardExecutorPtr &forward_executor = nullptr)
      : forward_executor_(ForwardExecutorWeakPtr(forward_executor)),
        ms_function_(std::make_shared<MsFunction>()),
        async_executor_(std::make_unique<AsyncQueue>()),
        async_executor_(std::make_shared<AsyncQueue>()) {}
        enable_async_(std::getenv("ENABLE_ASYNC")) {}

  std::function<void(const py::object &, const py::args &)> InitGraph = [this](auto &&PH1, auto &&PH2) {
    NewGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));

@ -69,6 +79,10 @@ class GradExecutor {
    MS_EXCEPTION_IF_NULL(ms_function_);
    return ms_function_;
  }
  inline void set_use_dynamic_shape_process(bool use_dynamic_shape_process) {
    use_dynamic_shape_process_ = use_dynamic_shape_process;
  }

  inline bool need_renormalize() const { return need_renormalize_; }
  inline void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
  inline bool grad_flag() const { return grad_flag_; }
@ -77,12 +91,16 @@ class GradExecutor {
  inline bool eliminate_forward() const { return eliminate_forward_; }
  inline void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; }
  inline size_t custom_bprop_cell_count() const { return custom_bprop_cell_count_; }
  inline bool use_dynamic_shape_process() const { return use_dynamic_shape_process_; }
  inline std::shared_ptr<AsyncQueue> async_executor() const { return async_executor_; }
  void SetHookChanged(const py::object &cell) const;
  void GradNetInner(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
                    const py::object &grad_position, const py::args &args);
  py::object RunGradGraph();
  CNodePtr ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const;
  py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args);
  py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id,
                             const py::args &args);
  TopCellInfoPtr GetTopCell(const std::string &already_run_cell_id);
  void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
  void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
  void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
@ -90,25 +108,34 @@ class GradExecutor {
  AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
  void AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info);
  AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const;
  void UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const;
  void UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor,
                           const std::vector<tensor::TensorPtr> &pre_tensors) const;
  void ClearRes();
  void WorkerJoin() { async_executor_->WorkerJoin(); }

  void CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node = false,
                         const std::string &graph_phase = "") const;

 private:
  ForwardExecutorPtr forward() const;
  inline FuncGraphPtr curr_g() const { return top_cell()->fg(); }
  inline void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell) { high_order_stack_.push(top_cell); }
  inline bool CanGetTopCell(const string &already_run_cell_id) {
    return already_run_cell_id.find(top_cell()->already_run_cell_id()) != std::string::npos;
  }
  std::string GetCurCellOrder() const;
  void SetGradOrder(const std::string &cell_id);
  void SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
                         const CNodePtr &cnode) const;
  void DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode, const ValuePtr &op_out) const;
  void GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode,
                      const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const;
  void AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode,
                           const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const;
  void AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &cloned_value) const;
  AnfNodePtr GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const;
  void SetBpropGraphJitLevel(const py::object &obj) const;
  void ClearGlobalRes();
  void ClearGradRes();
  std::string GetAlreadyRunCellId(const std::string &cell_id) const;

  // Higher derivative
  inline bool IsNestedGrad() const { return grad_order_ > 1; }
@ -121,6 +148,7 @@ class GradExecutor {
  inline bool is_high_order_top_cell() const {
    return !input_args_info_stack_.empty() && IsNestedGrad() && top_cell()->grad_order() != grad_order_;
  }

  void SwitchTopCell();
  void DoParameterReplace(const FuncGraphPtr &first_grad_fg, const std::vector<ValuePtr> &forward_args,
                          std::vector<AnfNodePtr> *inputs, ValuePtrList *weights_args);

@ -132,15 +160,20 @@ class GradExecutor {
  void HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_info, bool is_bprop_top) const;
  void InitResourceAndDfBuilder(const InputArgsInfoPtr &cell_info);
  void MakeNewTopGraph(const InputArgsInfoPtr &input_args_info);
  void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph) const;

  // Manage resource when run grad process.
  bool IsBpropGraph(const std::string &cell_id) const;
  void NewGraphInner(const py::object &obj, const py::args &args);
  InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args);
  void NewGraphImpl(const InputArgsInfoPtr &input_args_info);
  void AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info);
  void SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const;
  void GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out,
                          const InputArgsInfoPtr &input_args_info);
  void DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info, const std::string &out_id);
  void CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info);
  void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell);
  void GetGradGraph(const ad::GradAttr &grad_attr, const std::vector<AnfNodePtr> &w_args,
                    const std::vector<size_t> &p_args);
  FuncGraphPtr GetBpropGraph(const ad::GradAttr &grad_attr, const vector<AnfNodePtr> &w_args,
@ -151,22 +184,38 @@ class GradExecutor {
                             const abstract::AbstractBasePtr &param_tensor_abs, const std::string &input_shape);
  void UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args, const FuncGraphPtr &bprop_graph, bool has_sens);
  std::vector<size_t> GetGradPositionArgs(const py::object &grad_position, bool get_by_position) const;
  void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const;
  // Manage resource for construct forward graph.
  AnfNodePtr GetOutputNodeAsInput(const std::string &obj_id) const;
  AnfNodePtr GetValueSequenceInput(const ValuePtr &v, const std::string &obj_id) const;
  AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id,
                                    const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const;

  void SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
                                            const std::string &graph_phase) const;
  bool IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
                      const std::string &graph_phase) const;
  bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info,
                            const std::vector<AnfNodePtr> &new_anf_inputs) const;
  bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, const CNodePtr &new_cnode,
                                     bool is_ms_function_node, const std::string &graph_phase) const;
  bool grad_flag_{false};
  bool grad_is_running_{false};
  bool need_renormalize_{false};
  bool eliminate_forward_{true};
  mutable bool use_dynamic_shape_process_{false};
  mutable bool is_cell_id_in_dynamic_detect_nodes_map_{false};
  int custom_bprop_cell_count_{0};

  // Used in sub thread
  size_t cell_order_{0};
  std::string cur_cell_id_{""};

  // If grad_order=1, indicate first derivative; grad_order=2, indicate second derivative; ...
  size_t grad_order_{0};
  std::string grad_operation_;
  TopCellInfoPtr top_cell_{nullptr};
  TopCellInfoPtr pre_top_cell_{nullptr};
  InputArgsInfoPtr top_input_args_info_{nullptr};
  // Records every cell info for share, regardless of whether need construct grad graph
  std::stack<InputArgsInfoPtr> input_args_info_stack_;
@ -175,11 +224,13 @@ class GradExecutor {
  std::vector<std::string> bprop_cell_list_;
  // For high grad order
  std::stack<TopCellInfoPtr> high_order_stack_;
  std::vector<TopCellInfoPtr> top_cell_list_;
  // Record all top cells which have been run
  mindspore::HashMap<std::string, TopCellInfoPtr> already_run_top_cell_;
  ForwardExecutorWeakPtr forward_executor_;
  MsFunctionPtr ms_function_;
  std::unique_ptr<AsyncQueue> async_executor_;
  std::shared_ptr<AsyncQueue> async_executor_;
  std::map<std::string, compile::BackendPtr> backends_;
  mutable mindspore::HashMap<std::string, std::vector<DynamicDetectNodeInfoPtr>> cell_id_with_dynamic_detect_nodes_;
  bool enable_async_ = false;
};
} // namespace pynative
} // namespace mindspore

@ -19,6 +19,8 @@
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/parallel_context.h"
#include "ir/func_graph_cloner.h"
#include "runtime/pynative/async/async_queue.h"
#include "pipeline/pynative/grad/bprop_task.h"

namespace mindspore {
namespace pynative {

@ -151,6 +153,35 @@ void MsFunction::ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, co
  RunReplace(added_make_tuple, total_output_tensors, grad_graph);
}

void MsFunction::UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const string &op_info,
                                                const ValuePtr &new_forward_value) const {
  if (grad_executor->use_dynamic_shape_process()) {
    MS_LOG(DEBUG) << "Get dynamic shape process";
    return;
  }
  MS_EXCEPTION_IF_NULL(new_forward_value);
  MS_LOG(DEBUG) << "Ms func graph has already run before. The graph phase is: " << graph_phase_;
  MS_LOG(DEBUG) << "The output values of added forward nodes are: " << new_forward_value->ToString();
  std::vector<tensor::TensorPtr> new_tensors;
  TensorValueToTensor(new_forward_value, &new_tensors);
  if (new_tensors.empty()) {
    MS_LOG(DEBUG) << "The size of added forward tensors is zero, no need to update.";
    return;
  }
  MS_EXCEPTION_IF_NULL(grad_executor);
  const auto &top_cell = grad_executor->top_cell();
  const auto &old_tensors = top_cell->op_info_with_ms_func_forward_tensors().at(op_info);
  if (old_tensors.size() != new_tensors.size()) {
    MS_LOG(EXCEPTION) << "The size of old tensors is: " << old_tensors.size()
                      << ", but the size of new tensors is: " << new_tensors.size()
                      << ", the current op info is: " << op_info;
  }
  for (size_t i = 0; i < new_tensors.size(); ++i) {
    grad_executor->UpdatePreTensorInfo(new_tensors[i], {old_tensors[i]});
    old_tensors[i]->set_sync_status(kNeedSyncDeviceToHost);
  }
}

void MsFunction::GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, AnfNodePtrList *input_nodes,
                                  const GradExecutor *grad_executor) const {
  MS_EXCEPTION_IF_NULL(op_run_info);

@ -213,6 +244,7 @@ void MsFunction::GetWeightsNode(const FrontendOpRunInfoPtr &op_run_info, const G

void MsFunction::MakeCNodeForMsFunction(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                                        const FuncGraphPtr &ms_func_graph, CNodePtr *ms_function_cnode) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  // Get input node info of ms_function
  std::vector<AnfNodePtr> input_nodes{NewValueNode(ms_func_graph)};
  MS_EXCEPTION_IF_NULL(grad_executor);

@ -222,6 +254,7 @@ void MsFunction::MakeCNodeForMsFunction(const FrontendOpRunInfoPtr &op_run_info,
  // Make a CNode which includes ms_function fprop graph and inputs node
  MS_EXCEPTION_IF_NULL(ms_function_cnode);
  *ms_function_cnode = grad_executor->top_cell()->fg()->NewCNode(input_nodes);

  MS_LOG(DEBUG) << "Make ms function forward CNode: " << (*ms_function_cnode)->DebugString();
}

@ -242,6 +275,10 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run
  MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
  auto grad_param =
    std::make_shared<ad::GradParam>(ms_function_cnode, op_run_info->input_value, op_run_info->out_value, grad_graph);
  {
    py::gil_scoped_release gil_release;
    grad_executor->async_executor()->Wait();
  }
  if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) {
    MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: "
                      << ms_function_cnode->DebugString();

@ -250,21 +287,55 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run
  return ms_function_cnode;
}

void MsFunction::AsyncKPynativeWithFProp(const GradExecutor *grad_executor,
                                         const ad::AutoGradCellImplPtr &auto_grad_cell_ptr,
                                         const ad::GradParamPtr &grad_param) const {
  MS_EXCEPTION_IF_NULL(grad_executor);

  const auto fn = [this, grad_param, auto_grad_cell_ptr]() {
    MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
    if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) {
      MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode";
    }
  };
  auto task = std::make_shared<BpropTask>(fn);
  grad_executor->async_executor()->Push(task);
}

void MsFunction::AsyncGradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                                          const ValuePtr &added_out_v, const FuncGraphPtr &ms_func_graph,
                                          const FuncGraphPtr &grad_graph) const {
  const auto fn = [this, op_run_info, grad_executor, added_out_v, ms_func_graph, grad_graph]() {
    this->GradMsFunctionInner(op_run_info, grad_executor, added_out_v, ms_func_graph, grad_graph);
  };
  auto task = std::make_shared<BpropTask>(fn);
  grad_executor->async_executor()->Push(task);
}

void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                                     const ValuePtr &added_out_v, const FuncGraphPtr &ms_func_graph,
                                     const FuncGraphPtr &grad_graph) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(grad_executor);
  MS_LOG(DEBUG) << "ms_function actual output value: " << op_run_info->out_value->ToString();
  if (!grad_executor->grad_flag()) {
  // Step 1: Update actual output tensors used in grad graph.
    MS_LOG(EXCEPTION) << "The flag of need construct graph is False.";
  MS_EXCEPTION_IF_NULL(op_run_info->out_value);
  MS_LOG(DEBUG) << "ms_function actual output value: " << op_run_info->out_value->ToString();
  // The output of ms_function may be used in subsequent PyNative process
  grad_executor->UpdateForwardTensorInfoInBpropGraph(op_run_info);

  // Step 2: Update output tensors of added forward nodes, which are added to return node of ms_function func graph.
  if (grad_executor->top_cell()->op_info_with_ms_func_forward_tensors().find(op_run_info->op_info) !=
      grad_executor->top_cell()->op_info_with_ms_func_forward_tensors().end()) {
    UpdateMsFunctionForwardTensors(grad_executor, op_run_info->op_info, added_out_v);
  }
  // Update actual output tensors used in grad graph.
  ReplaceNewTensorsInGradGraph(grad_executor->top_cell(), added_out_v, ms_func_graph, grad_graph);

  // Clone new ms_function func graph and grad graph.
  auto new_ms_func_graph = BasicClone(ms_func_graph);
  auto new_grad_graph = BasicClone(grad_graph, true);

  auto new_make_tuple = new_ms_func_graph->output()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(new_make_tuple);
  new_ms_func_graph->set_output(new_make_tuple->input(1));

@ -273,6 +344,11 @@ void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, co
  const auto &ms_function_cnode =
    MakeAdjointForMsFunction(op_run_info, grad_executor, new_ms_func_graph, new_grad_graph);
  ms_function_cnode->set_abstract(new_ms_func_graph->output()->abstract()->Broaden());

  auto grad_exec_ptr = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
  MS_EXCEPTION_IF_NULL(grad_exec_ptr);
  grad_exec_ptr->CheckGraphDynamic(ms_function_cnode, op_run_info->op_index, true,
                                   op_run_info->base_op_run_info.op_name);
}

void MsFunction::SetMsFuncGraphParameters(const FuncGraphPtr &ms_func_graph) {

@ -316,6 +392,9 @@ py::object MsFunction::GradMsFunction(const py::object &out, const py::args &arg
  const auto &op_run_info = GetOpRunInfo(out, args, graph_phase_, &added_out_v);
  FuncGraphPtr grad_graph = executor->GetGradGraph(graph_phase_);
  PyNativeAlgo::Common::DumpGraphIR("ms_func_forward_graph.ir", ms_func_graph);
  if (!grad_executor->grad_flag()) {
    MS_LOG(EXCEPTION) << "The flag of need construct graph is False.";
  }
  GradMsFunctionInner(op_run_info, grad_executor.get(), added_out_v, ms_func_graph, grad_graph);
  SetMsFuncGraphParameters(ms_func_graph);
  graph_phase_.clear();
@ -42,11 +42,18 @@ class MsFunction {
  void GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                           const ValuePtr &added_out_v, const FuncGraphPtr &ms_func_graph,
                           const FuncGraphPtr &grad_graph) const;
  void AsyncGradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
                                const ValuePtr &added_out_v, const FuncGraphPtr &ms_func_graph,
                                const FuncGraphPtr &grad_graph) const;
  void AsyncKPynativeWithFProp(const GradExecutor *grad_executor, const ad::AutoGradCellImplPtr &auto_grad_cell_ptr,
                               const ad::GradParamPtr &grad_param) const;
  // Update device address of value node in grad graph by forward tensors.
  void RunReplace(const CNodePtr &added_make_tuple, const std::vector<tensor::TensorPtr> &total_output_tensors,
                  const FuncGraphPtr &grad_graph) const;
  void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const ValuePtr &added_out,
                                    const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const;
  void UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const string &op_info,
                                      const ValuePtr &new_forward_value) const;
  // Make CNode for ms_function forward graph.
  void GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, AnfNodePtrList *input_nodes,
                        const GradExecutor *grad_executor) const;
@ -57,6 +57,40 @@ void TopCellInfo::RecordCellBackwardHookOp(const std::string &cell_order, const
  }
}

void TopCellInfo::GetOpInfo(const FrontendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  std::string input_args_info;
  // Record input args info (weight or data)
  // self.p = Parameter();
  // def construct(x, y)
  //   if y:
  //     x = x + x
  //   else:
  //     x = x + self.p
  //   return x
  for (size_t i = 0; i < op_run_info->base_op_run_info.input_tensor.size(); i++) {
    const auto &t = op_run_info->base_op_run_info.input_tensor[i];
    MS_EXCEPTION_IF_NULL(t);
    if (t->is_parameter() && t->param_info() != nullptr && t->param_info()->requires_grad()) {
      input_args_info += "w";
    } else {
      input_args_info += "d";
    }
  }
  // Record op name and index
  op_run_info->op_info.clear();
  op_run_info->op_info +=
    op_run_info->base_op_run_info.op_name + "-" + std::to_string(op_index_) + "-" + input_args_info;
  const auto &out_abs = op_run_info->base_op_run_info.abstract;
  auto shape = out_abs->BuildShape();
  MS_EXCEPTION_IF_NULL(shape);
  if (!shape->isa<abstract::NoShape>() && !shape->IsDimZero()) {
    op_run_info->op_info += "-" + shape->ToString();
  }
  op_run_info->op_index = op_index_;
  ++op_index_;
}

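For reference, the op_info key assembled above takes the form <op_name>-<op_index>-<w/d flags>[-<output shape>]; for example, an op named MatMul at op index 3 fed one trainable-parameter input and one data input might produce something like "MatMul-3-wd-(16, 32)". The exact shape suffix depends on Shape::ToString(), so this example is illustrative only.
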
void TopCellInfo::ClearDeviceMemory() const {
  MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
  auto ms_context = MsContext::GetInstance();

@ -154,11 +188,40 @@ void TopCellInfo::SetNestedMultipleOutputToGraphInfoMap(const string &id, const
  }
}

void TopCellInfo::Clear() {
  MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
  hook_changed_ = false;
  ms_function_flag_ = false;
  is_init_kpynative_ = false;
  need_compile_graph_ = false;
  forward_already_run_ = false;
  op_index_ = 0;
  resource_ = nullptr;
  fg_ = nullptr;
  graph_info_map_.clear();
  op_info_with_tensor_id_.clear();
  tensor_id_with_tensor_object_.clear();
  op_info_with_ms_func_forward_tensors_.clear();
  cnode_hash_with_op_index_.clear();
}

void TopCellInfo::SetUnpackOutputToGraphInfoMap(const std::string &id, const AnfNodePtr &node,
                                                const std::vector<int64_t> &index) const {
  auto &graph_info = graph_info_map().at(fg());
  MS_EXCEPTION_IF_NULL(graph_info);
  graph_info->node_map[id] = std::make_pair(node, index);
}

void TopCellInfo::set_opinfo_with_tensor_id(const std::string &op_info,
                                            const std::vector<tensor::TensorPtr> &op_out_tensors) {
  if (op_info_with_tensor_id_.find(op_info) != op_info_with_tensor_id_.end()) {
    MS_LOG(EXCEPTION) << "Top cell: " << cell_id_ << " records op info with tensor id, but get op info " << op_info
                      << " in op_info_with_tensor_id map";
  }
  // Record the relationship between the forward op and its output tensor id
  (void)std::for_each(op_out_tensors.begin(), op_out_tensors.end(), [this, &op_info](const tensor::TensorPtr &tensor) {
    (void)op_info_with_tensor_id_[op_info].emplace_back(tensor->id());
  });
}
} // namespace pynative
} // namespace mindspore

@ -42,6 +42,9 @@ namespace mindspore {
namespace pynative {
namespace py = pybind11;
class GradExecutor;
using OpInfoWithTensorId = mindspore::HashMap<std::string, std::vector<std::string>>;
using TensorIdWithTensorObject = mindspore::HashMap<std::string, std::vector<tensor::TensorPtr>>;
using OpInfoWithMsFuncForwardTensors = mindspore::HashMap<std::string, std::vector<tensor::TensorPtr>>;
using CellIdWithBackwardHookOp = mindspore::HashMap<std::string, std::vector<AnfNodePtr>>;

struct GraphInfo {

@ -55,9 +58,10 @@ using GraphInfoPtr = std::shared_ptr<GraphInfo>;
class TopCellInfo {
 public:
  ~TopCellInfo() = default;
  TopCellInfo(size_t grad_order, std::string cellid, std::string already_run_cell_id, pipeline::ResourcePtr r,
  TopCellInfo(size_t grad_order, std::string c_cell_id, std::string cellid, std::string already_run_cell_id,
              FuncGraphPtr fg)
              pipeline::ResourcePtr r, FuncGraphPtr fg)
      : grad_order_(grad_order),
        c_cell_id_(std::move(c_cell_id)),
        cell_id_(std::move(cellid)),
        already_run_cell_id_(std::move(already_run_cell_id)),
        resource_(std::move(r)),
@ -70,11 +74,14 @@ class TopCellInfo {
  inline void set_sub_cell_hook_changed(const std::string &sub_cell) { (void)sub_cell_hook_changed_.emplace(sub_cell); }
  inline const CellIdWithBackwardHookOp &cell_backward_hook_op() const { return cell_backward_hook_op_; }
  void RecordCellBackwardHookOp(const std::string &cell_order, const AnfNodePtr &hook_op);
  void GetOpInfo(const FrontendOpRunInfoPtr &op_run_info);
  inline void ClearCellHookOp() { cell_backward_hook_op_.clear(); }
  inline bool ms_function_flag() const { return ms_function_flag_; }
  inline void set_ms_function_flag(bool ms_function_flag) { ms_function_flag_ = ms_function_flag; }
  inline bool forward_already_run() const { return forward_already_run_; }
  inline void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
  inline bool need_compile_graph() const { return need_compile_graph_; }
  inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
  inline pipeline::ResourcePtr resource() const { return resource_; }
  inline FuncGraphPtr fg() const {
    MS_EXCEPTION_IF_NULL(fg_);
@ -82,18 +89,51 @@ class TopCellInfo {
  }
  inline void set_fg(const FuncGraphPtr &fg) { fg_ = fg; }
  inline const std::string &cell_id() const { return cell_id_; }
  inline const std::string &c_cell_id() const { return c_cell_id_; }
  inline const std::string &already_run_cell_id() const { return already_run_cell_id_; }
  inline void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
  inline const std::string &input_args_id() const { return input_args_id_; }
  const std::string &grad_operation() const { return grad_operation_; }
  void set_grad_operation(const std::string &grad_operation) { grad_operation_ = grad_operation; }
  inline void CheckSubCellHookChanged() { sub_cell_hook_changed_.clear(); }
  inline void SetGraphInfoMap(const FuncGraphPtr &fg, const GraphInfoPtr &graph_info) {
    graph_info_map_[fg] = graph_info;
  }
  inline void set_is_run_cell(bool is_run_cell) { is_run_cell_ = is_run_cell; }
  inline bool is_run_cell() { return is_run_cell_; }
  inline const OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() const { return graph_info_map_; }
  inline ad::AutoGradCellImplPtr auto_grad_cell_ptr() const { return auto_grad_cell_ptr_; }
  inline ad::AutoGradCellImplPtr auto_grad_cell_ptr() const {
    MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr_);
    return auto_grad_cell_ptr_;
  }
  void set_auto_grad_cell_ptr(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr) {
    auto_grad_cell_ptr_ = auto_grad_cell_ptr;
  }
  inline const OpInfoWithTensorId &op_info_with_tensor_id() const { return op_info_with_tensor_id_; }
  void set_opinfo_with_tensor_id(const std::string &op_info, const std::vector<tensor::TensorPtr> &op_out_tensors);
  inline const TensorIdWithTensorObject &tensor_id_with_tensor_object() const { return tensor_id_with_tensor_object_; }
  inline void set_tensor_id_with_tensor_object(const std::string &id, const tensor::TensorPtr &tensor) {
    (void)tensor_id_with_tensor_object_[id].emplace_back(tensor);
  }
  inline const OpInfoWithMsFuncForwardTensors &op_info_with_ms_func_forward_tensors() const {
    return op_info_with_ms_func_forward_tensors_;
  }
  inline size_t op_index() const { return op_index_; }
  inline void IncreaseOpIndex() { op_index_++; }

  inline void set_cnode_hash_with_op_index(const size_t &node_hash, const size_t &op_index) {
    cnode_hash_with_op_index_[node_hash] = op_index;
  }
  inline size_t get_op_index_by_cnode_hash(const size_t &node_hash) {
    auto iter = cnode_hash_with_op_index_.find(node_hash);
    if (iter == cnode_hash_with_op_index_.end()) {
      MS_LOG(EXCEPTION) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_";
    }
    return iter->second;
  }

  void Clear();

  void DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id);
  void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param, bool is_weight = false) const;
  void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1,
@ -111,7 +151,11 @@ class TopCellInfo {
  bool ms_function_flag_{false};
  bool is_init_kpynative_{false};
  bool forward_already_run_{false};
  bool need_compile_graph_{false};
  bool is_run_cell_{false};
  size_t op_index_{0};
  size_t grad_order_{0};
  std::string c_cell_id_;
  std::string cell_id_;
  std::string already_run_cell_id_;
  std::string input_args_id_;

@ -126,6 +170,10 @@ class TopCellInfo {
  // Record backward hook ops for each cell object.
  // Each cell object has two backward hook ops.
  CellIdWithBackwardHookOp cell_backward_hook_op_;
  OpInfoWithTensorId op_info_with_tensor_id_;
  TensorIdWithTensorObject tensor_id_with_tensor_object_;
  OpInfoWithMsFuncForwardTensors op_info_with_ms_func_forward_tensors_;
  mindspore::HashMap<size_t, size_t> cnode_hash_with_op_index_;
};
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;
} // namespace pynative

@ -139,6 +139,7 @@ void PyNativeExecutor::ClearRes() const {
void PyNativeExecutor::Init() {
  MS_LOG(DEBUG) << "Init PyNativeExecutor";
  forward_executor_ = std::make_shared<ForwardExecutor>();
  forward_executor_->Init();
  grad_executor_ = std::make_shared<GradExecutor>(forward_executor_);
  forward_executor_->set_grad_executor(grad_executor_);
}

@ -161,8 +162,8 @@ bool PyNativeExecutor::grad_flag() const { return grad_executor()->grad_flag();
void PyNativeExecutor::set_grad_flag(bool flag) const { grad_executor()->set_grad_flag(flag); }

py::object PyNativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj,
                                             const py::args &args) const {
                                             const py::object &grad_hash_id, const py::args &args) const {
  return grad_executor()->CheckAlreadyRun(grad, obj, args);
  return grad_executor()->CheckAlreadyRun(grad, obj, grad_hash_id, args);
}

void PyNativeExecutor::NewGraph(const py::object &obj, const py::args &args) const {

@ -187,7 +188,10 @@ void PyNativeExecutor::EndGraph(const py::object &obj, const py::object &out, co
  forward_executor()->ProcessAfterEndGraph(obj, is_cell);
}

py::object PyNativeExecutor::Run() const { return PyNativeExecutorTry(grad_executor()->RunGraph); }
py::object PyNativeExecutor::Run() const {
  const auto &ret = PyNativeExecutorTry(grad_executor()->RunGraph);
  return ret;
}

void PyNativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::object &grad_position, const py::args &args) const {
@@ -195,13 +199,22 @@ void PyNativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::obj
 }

 py::object PyNativeExecutor::GradMsFunction(const py::object &out, const py::args &args) const {
-  return grad_executor()->ms_function()->GradMsFunction(out, args);
+  const auto &ret = grad_executor()->ms_function()->GradMsFunction(out, args);
+  return ret;
 }

 void PyNativeExecutor::SetLazyBuild(bool enable) const { forward_executor()->set_lazy_build(enable); }

 bool PyNativeExecutor::IsFirstCell() const { return forward_executor()->IsFirstCell(); }

+void PyNativeExecutor::SetMsFunctionCompileStatus(bool is_compiling) const {
+  forward_executor()->set_is_ms_function_compiling(is_compiling);
+}
+
+void PyNativeExecutor::SetDynamicInput(const py::object &cell, const py::args &args) const {
+  grad_executor()->set_use_dynamic_shape_process(true);
+}
+
 void RegPyNativeExecutor(const py::module *m) {
   (void)py::class_<PyNativeExecutor, std::shared_ptr<PyNativeExecutor>>(*m, "PyNativeExecutor_")
     .def_static("get_instance", &PyNativeExecutor::GetInstance, "PyNativeExecutor get_instance.")
@@ -220,10 +233,13 @@ void RegPyNativeExecutor(const py::module *m) {
     .def("set_hook_changed", &PyNativeExecutor::SetHookChanged, "set pynative hook changed")
     .def("set_grad_flag", &PyNativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
          "Executor set grad flag.")
+    .def("set_dynamic_input", &PyNativeExecutor::SetDynamicInput, "set dynamic input")
     .def("set_py_exe_path", &PyNativeExecutor::set_py_exe_path, py::arg("py_exe_path") = py::str(""),
          "set python executable path.")
     .def("set_kernel_build_server_dir", &PyNativeExecutor::set_kernel_build_server_dir,
          py::arg("kernel_build_server_dir") = py::str(""), "set kernel build server directory path.")
+    .def("set_ms_function_compile_status", &PyNativeExecutor::SetMsFunctionCompileStatus,
+         "set ms_funciton compile status.")
     .def("real_run_op", &PyNativeExecutor::RealRunOp, "Run op pynatively.")
     .def("constant_folding", &PyNativeExecutor::CallConstantFolding, "Call Constant Folding Primitive");
 }
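For orientation, a minimal sketch of how the bindings registered above are reached from Python. The `mindspore._c_expression` module path and the direct singleton access are assumptions made for illustration; normal code goes through the `_PyNativeExecutor` wrapper shown in the later hunks.

```python
# Sketch only (assumed module path): the names bound above become methods on the
# PyNativeExecutor_ singleton that the C++ side exposes to Python.
from mindspore._c_expression import PyNativeExecutor_

executor = PyNativeExecutor_.get_instance()
# Bracket an ms_function compilation so ops launched while compiling are
# attributed to compilation rather than to ordinary PyNative execution.
executor.set_ms_function_compile_status(True)
executor.set_ms_function_compile_status(False)
```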
@@ -65,13 +65,17 @@ class PyNativeExecutor : public std::enable_shared_from_this<PyNativeExecutor> {
   void GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                const py::object &grad_position, const py::args &args) const;
   py::object GradMsFunction(const py::object &out, const py::args &args) const;
-  py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args) const;
+  void SetDynamicInput(const py::object &cell, const py::args &args) const;
+
+  py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id,
+                             const py::args &args) const;
   void ClearRes() const;
   // Sync stream
   void Sync() const;
   void SetLazyBuild(bool enable) const;
   bool IsFirstCell() const;
   void WorkerJoin() { grad_executor_->WorkerJoin(); }
+  void SetMsFunctionCompileStatus(bool is_compiling) const;

  private:
   PyNativeExecutor() = default;
@@ -602,13 +602,16 @@ TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel,
 }

 void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info,
+                                                   bool use_dynamic_shape_process,
                                                    session::BackendOpRunInfoPtr *op_run_info, GraphInfo *graph_info,
                                                    const GraphOutputInfo *const graph_output_info) {
   MS_EXCEPTION_IF_NULL(session_);
   MS_EXCEPTION_IF_NULL(graph_info);
   *op_run_info = session_->GetSingleOpRunInfo(kernel, *graph_info, tensor_info, graph_output_info);
   session_->GetSingleOpGraphInfo(kernel, tensor_info, graph_info, *op_run_info);
+  MS_EXCEPTION_IF_NULL(*op_run_info);
   (*op_run_info)->base_op_run_info.graph_info = *graph_info;
+  (*op_run_info)->base_op_run_info.use_dynamic_shape_process = use_dynamic_shape_process;
 }

 void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const {
@@ -130,8 +130,8 @@ class GraphCompiler {

   // Get OpRunInfo and GraphInfo for single op compile and run.
   void GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info,
-                                      session::BackendOpRunInfoPtr *op_run_info, GraphInfo *graph_info,
-                                      const GraphOutputInfo *const graph_output_info);
+                                      bool use_dynamic_shape_process, session::BackendOpRunInfoPtr *op_run_info,
+                                      GraphInfo *graph_info, const GraphOutputInfo *const graph_output_info);

   // Calculate ref count of PyNative back propagation operators.
   void CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const;
@@ -296,7 +296,9 @@ class _MindsporeFunctionExecutor:
         args_list = args
         if self.obj is not None:
             args_list = args_list[1:]
+        _pynative_executor.set_ms_function_compile_status(True)
         phase = self.compile(args_list, self.fn.__name__)
+        _pynative_executor.set_ms_function_compile_status(False)
         if context.get_context("precompile_only"):
             return None
         new_inputs = self._generate_run_args(args_list)
@@ -428,6 +430,7 @@ class _MindsporeFunctionExecutor:
             self.input_signature.append(args_list[-1])
             Validator.check_dynamic_shape(self.input_signature, args_list)
         compile_args = tuple(self.input_signature)
+        _pynative_executor.set_dynamic_input(self.obj, *compile_args)
         return compile_args

     def _generate_run_args(self, args_list):
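The change above registers the dynamic compile arguments with the PyNative executor whenever an `input_signature` is supplied. A minimal sketch of the user-facing path, assuming dynamic-shape placeholder Tensors are accepted as `input_signature` on this branch (the `Validator.check_dynamic_shape` call suggests they are); the function body and shapes are illustrative:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops, ms_function, context

context.set_context(mode=context.PYNATIVE_MODE)

# A Tensor with a None dimension declares a dynamic axis for compilation.
dyn_x = Tensor(shape=[None, 16], dtype=ms.float32)

@ms_function(input_signature=dyn_x)
def act(x):
    return ops.ReLU()(x)

# Different batch sizes are intended to reuse the same dynamic-shape compilation;
# _generate_compile_args now also calls set_dynamic_input under the hood.
print(act(Tensor(np.ones((4, 16), np.float32))).shape)
print(act(Tensor(np.ones((8, 16), np.float32))).shape)
```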
@@ -1012,7 +1015,7 @@ class _PyNativeExecutor:
         """
         self._executor.end_graph(obj, output, *args, *(kwargs.values()))

-    def check_run(self, grad, obj, *args, **kwargs):
+    def check_run(self, grad, obj, grad_hash_id, *args, **kwargs):
         """
         Whether the forward graph need to construct.

@@ -1026,7 +1029,7 @@ class _PyNativeExecutor:
         Return:
             bool, specifies whether the forward graph need to construct.
         """
-        return self._executor.check_run(grad, obj, *args, *(kwargs.values()))
+        return self._executor.check_run(grad, obj, grad_hash_id, *args, *(kwargs.values()))

     def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
         """
@@ -1122,6 +1125,30 @@ class _PyNativeExecutor:
         """
         self._executor.set_grad_flag(flag)

+    def set_ms_function_compile_status(self, status):
+        """
+        Set ms_function is compiling
+
+        Args:
+            status(bool): ms_function compile status
+        Return:
+            None.
+        """
+        self._executor.set_ms_function_compile_status(status)
+
+    def set_dynamic_input(self, obj, *args):
+        """
+        Set dynamic shape tensor of input arguments.
+
+        Args:
+            obj (Function/Cell): The function or cell instance.
+            args (tuple): Function or cell dynamic input arguments.
+
+        Return:
+            None.
+        """
+        self._executor.set_dynamic_input(obj, *args)
+
     def is_first_cell(self):
         """
         The flag of first cell instance.
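A small usage sketch of the two wrapper methods added above, mirroring the call sites introduced elsewhere in this change (`api.py` brackets compilation, `cell.py` registers dynamic inputs). This is an internal API; the network and placeholder tensor are illustrative:

```python
import mindspore as ms
from mindspore import nn, Tensor
from mindspore.common.api import _pynative_executor

net = nn.ReLU()
dyn_x = Tensor(shape=[None, 8], dtype=ms.float32)  # dynamic batch axis

# As Cell.set_inputs does in PyNative mode: mark the cell's inputs as dynamic.
_pynative_executor.set_dynamic_input(net, dyn_x)

# As _MindsporeFunctionExecutor.__call__ does around compile().
_pynative_executor.set_ms_function_compile_status(True)
_pynative_executor.set_ms_function_compile_status(False)
```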
@@ -891,6 +891,8 @@ class Cell(Cell_):
         self._check_construct_args(*inputs)
         if self._dynamic_shape_inputs:
             ds.config.set_dynamic_shape(True)
+            if context._get_mode() == context.PYNATIVE_MODE:
+                _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)

     def get_inputs(self):
         """
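In PyNative mode the change above makes `set_inputs` notify the executor that the cell runs with dynamic-shape inputs. A minimal sketch, assuming the `Cell.set_inputs` dynamic-shape API; the Dense layer and shapes are illustrative:

```python
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

net = nn.Dense(16, 4)

# Declare a dynamic batch dimension; on this branch set_inputs now also
# calls _pynative_executor.set_dynamic_input for the cell.
dyn_x = Tensor(shape=[None, 16], dtype=ms.float32)
net.set_inputs(dyn_x)

out = net(Tensor(np.ones((4, 16), np.float32)))
print(out.shape)  # (4, 4)
```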
@@ -392,14 +392,14 @@ class GradOperation(GradOperation_):
                 new_kwargs = kwargs.copy()
                 new_kwargs.pop('sens')
         if isinstance(fn, (FunctionType, MethodType)):
-            if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
+            if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs):
                 _pynative_executor.set_grad_flag(True)
                 _pynative_executor.new_graph(fn, *args, **new_kwargs)
                 output = fn(*args, **new_kwargs)
                 _pynative_executor.end_graph(fn, output, *args, **new_kwargs)
         else:
             # Check if fn have run already
-            if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
+            if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs):
                 fn.set_grad()
                 fn(*args, **new_kwargs)
                 fn.set_grad(False)
@@ -465,6 +465,7 @@ class _Grad(GradOperation_):
         self.pynative_ = False
         self.grad_position = None
         self.weights_id = None
+        self.grad_hash_id = None

     def __call__(self, fn, weights=None, grad_position=0):
         weights_id = _get_grad_weights_id(weights)
@@ -537,6 +538,7 @@ class _Grad(GradOperation_):
             self.fn = fn
             self.grad_position = grad_position
             self.weights_id = weights_id
+            self.grad_hash_id = (grad_position, weights_id)
         return self.grad_fn

     def _pynative_forward_run(self, fn, grad, args, kwargs):
@@ -550,7 +552,7 @@ class _Grad(GradOperation_):
             else:
                 args = args[:-1]
         if isinstance(fn, (FunctionType, MethodType)):
-            if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
+            if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
                 _pynative_executor.set_grad_flag(True)
                 _pynative_executor.new_graph(fn, *args, **new_kwargs)
                 outputs = fn(*args, **new_kwargs)
@@ -558,7 +560,7 @@ class _Grad(GradOperation_):
             return outputs
         else:
             # Check if fn has run already.
-            if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
+            if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
                 fn.set_grad()
                 outputs = fn(*args, **new_kwargs)
                 fn.set_grad(False)
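The hash passed to `check_run` now includes `grad_position` together with the weights id, so requesting gradients of the same network at a different position is no longer mistaken for a forward graph that has already run. A short sketch of the behaviour this protects, assuming the functional `mindspore.ops.grad` interface (which dispatches through `_Grad`); the toy function is illustrative:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops, context

context.set_context(mode=context.PYNATIVE_MODE)

def net(x, y):
    return x * x + y * y

x = Tensor(np.array([2.0], np.float32))
y = Tensor(np.array([3.0], np.float32))

# Two grad transforms over the same function, differing only in grad_position.
# With grad_hash_id = (grad_position, weights_id), each gets its own
# already-run record instead of sharing one.
gx = ops.grad(net, grad_position=0)(x, y)  # d/dx -> 4.0
gy = ops.grad(net, grad_position=1)(x, y)  # d/dy -> 6.0
print(gx, gy)
```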