!46108 Add pynative front dynamic detect function from master

Merge pull request !46108 from wanghenchang/front-dynamic-detect-master1
i-robot 2022-11-29 01:59:04 +00:00 committed by Gitee
commit c189903d04
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
21 changed files with 978 additions and 89 deletions
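
At a high level, the diffs below thread one decision through the stack: the PyNative frontend detects whether the executed net is dynamic, GradExecutor::GetGradGraph tags the bprop graph with kFlagUseDynamicShapeProcess (skipping Ascend), and MindRTBackend::RunGraphByCondition routes flagged graphs to RunGraphBySingleOp. The following is a hypothetical, self-contained C++ sketch of just that flag flow; FuncGraph, TagBpropGraph, and RunGraph here are simplified stand-ins, not the real MindSpore classes.

// Hypothetical, self-contained model of the flag flow added by this PR; the real
// MindSpore types (FuncGraph, GradExecutor, MindRTBackend) are far richer.
#include <iostream>
#include <map>
#include <string>

constexpr const char *kFlagUseDynamicShapeProcess = "use_dynamic_shape_process";

struct FuncGraph {
  std::map<std::string, bool> flags;
  void set_flag(const std::string &flag, bool value) { flags[flag] = value; }
  bool has_flag(const std::string &flag) const {
    auto it = flags.find(flag);
    return it != flags.end() && it->second;
  }
};

// Frontend side: once per-op detection decides the net is dynamic, the bprop
// graph is tagged (mirrors GradExecutor::GetGradGraph in this PR, which also
// leaves the flag off on Ascend).
void TagBpropGraph(FuncGraph *bprop_graph, bool use_dynamic_shape_process) {
  bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process);
}

// Backend side: the flag routes execution into single-op mode (mirrors
// MindRTBackend::RunGraphByCondition / RunGraphBySingleOp).
void RunGraph(const FuncGraph &root_graph) {
  if (root_graph.has_flag(kFlagUseDynamicShapeProcess)) {
    std::cout << "RunGraphBySingleOp (dynamic shape process)" << std::endl;
  } else {
    std::cout << "RunGraphByActors (static execution)" << std::endl;
  }
}

int main() {
  FuncGraph bprop_graph;
  TagBpropGraph(&bprop_graph, /*use_dynamic_shape_process=*/true);
  RunGraph(bprop_graph);  // prints: RunGraphBySingleOp (dynamic shape process)
  return 0;
}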

View File

@ -688,6 +688,7 @@ void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_i
graph_compiler_->CalculateForwardOpOutputCount(graph, inputs[graph_index], &forward_op_output_tensor_id_);
}
bool use_dynamic_shape_process = root_graph_->has_flag(kFlagUseDynamicShapeProcess);
py::gil_scoped_release release;
for (const auto &kernel : graph->execution_order()) {
InputTensorInfo input_tensor_info;
@ -714,9 +715,8 @@ void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_i
GraphInfo graph_info;
graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
&input_tensor_info);
graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info, &op_run_info, &graph_info,
&graph_output_info);
bool use_dynamic_shape_process = op_run_info->base_op_run_info.use_dynamic_shape_process;
graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info, use_dynamic_shape_process,
&op_run_info, &graph_info, &graph_output_info);
if (use_dynamic_shape_process) {
RunOpDynamic(op_run_info, &op_outputs);
} else {
@ -751,7 +751,8 @@ void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const Graph
}
if (contain_cut_graph || root_graph_->has_flag(kFlagIsDynamicStructure) ||
(enable_backend_dynamic_detect_ && root_graph_->has_flag(kFlagIsPynativeBpropGraph) && is_dynamic)) {
(enable_backend_dynamic_detect_ && root_graph_->has_flag(kFlagIsPynativeBpropGraph) && is_dynamic) ||
root_graph_->has_flag(kFlagUseDynamicShapeProcess)) {
RunGraphBySingleOp(graph_compiler_info, args, outputs);
} else {
RunGraphByActors(actor_info, graph_compiler_info, args, outputs);

View File

@ -149,7 +149,7 @@ class BACKEND_EXPORT MindRTBackendBase : public Backend {
// Save the mapping between cell id and actor info.
mindspore::HashMap<std::string, ActorInfo> graph_actor_infos_;
bool enable_backend_dynamic_detect_{true};
bool enable_backend_dynamic_detect_{false};
FuncGraphPtr root_graph_;
GraphPartitionPtr graph_partition_;
std::shared_ptr<GraphCompiler> graph_compiler_;

View File

@ -46,11 +46,11 @@ struct GradParam {
: cnode(cnode), op_args(op_args), out(out), fprop_fg(std::move(fprop_fg)) {}
// Primal CNode created by the op forward process
const CNodePtr &cnode;
const CNodePtr cnode;
// Input value for cnode
const ValuePtrList &op_args;
const ValuePtrList op_args;
// Output of op
const ValuePtr &out;
const ValuePtr out;
// Bprop func graph
const FuncGraphPtr fprop_fg;
// High order used this, which

View File

@ -907,6 +907,7 @@ constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph";
constexpr auto kFlagPyNativeRunInGraph = "pynative_run_in_graph";
constexpr auto kFlagNeedRenormalize = "need_renormalize";
constexpr auto kFlagEnableZeroCopyInGraph = "enable_zero_copy_in_graph";
constexpr auto kFlagUseDynamicShapeProcess = "use_dynamic_shape_process";
// TODO(dsj): for ms_function running in graph_mode. Should be deleted later
constexpr auto kAttrMSFunction = "ms_function_graph";

View File

@ -59,6 +59,7 @@ struct FrontendOpRunInfo {
bool grad_flag = false;
bool output_get_by_infer_value = false;
int mix_type{0};
size_t op_index = 0;
size_t input_size = 0;
size_t custom_bprop_cell_count = 0;
PrimitivePyPtr op_prim{nullptr};
@ -88,6 +89,8 @@ struct InputArgsInfo {
size_t input_size;
std::string obj_id;
bool has_sens{false};
bool is_run_cell{false};
bool use_dynamic_shape_process = false;
PrimitivePyPtr custom_bprp_prim{nullptr};
ValuePtr out_value{nullptr};
std::string cell_id;

View File

@ -274,6 +274,7 @@ ValuePtr CastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, cons
cast_run_info->base_op_run_info.next_op_name = op_name;
cast_run_info->base_op_run_info.next_input_index = index;
cast_run_info->base_op_run_info.lazy_build = op_run_info->base_op_run_info.lazy_build;
cast_run_info->base_op_run_info.use_dynamic_shape_process = op_run_info->base_op_run_info.use_dynamic_shape_process;
(void)cast_run_info->input_value.emplace_back(v);
(void)cast_run_info->input_value.emplace_back(GetDstType(type_id));
cast_run_info->input_size = input_size;

View File

@ -183,10 +183,21 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
if (!op_run_info->output_get_by_infer_value) {
GetOutput(op_run_info);
}
if (!op_run_info->grad_flag) {
MS_LOG(DEBUG) << "Grad flag is false";
return;
}
// Set the forward output flag for releasing memory.
// Because the tensor address may change, it should be set in the main thread to ensure consistency.
PyNativeAlgo::Common::SetForwardOutputFlag(op_run_info->out_value);
// Const value does not need op grad
if (op_run_info->output_get_by_infer_value) {
return;
}
// 4. Do op grad and record op info
if (enable_async_) {
grad()->AsyncProcessOpGradInfo(op_run_info);
} else {
if (!is_ms_function_compiling_) {
grad()->ProcessOpGradInfo(op_run_info);
}
}
@ -199,10 +210,13 @@ FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args) co
// Used for async run
op_run_info->grad_flag = grad()->grad_flag();
op_run_info->custom_bprop_cell_count = grad()->custom_bprop_cell_count();
op_run_info->base_op_run_info.use_dynamic_shape_process =
(device_target_ == kAscendDevice ? false : grad()->use_dynamic_shape_process());
op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
op_run_info->base_op_run_info.lazy_build = lazy_build_;
PyNativeAlgo::PyParser::SetPrim(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_PRIM)]);
PyNativeAlgo::PyParser::ParseOpInputByPythonObj(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_INPUTS)]);
(void)op_run_prim_py_list_.emplace_back(op_run_info->op_prim);
return op_run_info;
}
@ -412,6 +426,7 @@ void ForwardExecutor::Sync() {
MS_EXCEPTION_IF_NULL(item.second);
item.second->SyncStream();
}
op_run_prim_py_list_.clear();
}
ValuePtr ForwardExecutor::RunOpInMs(const FrontendOpRunInfoPtr &op_run_info) {
@ -466,6 +481,7 @@ void ForwardExecutor::ClearRes() {
infer_operation()->ClearConstFlagPrimCache();
std::stack<CellPtr>().swap(forward_cell_stack_);
mindrt_backends_.clear();
op_run_prim_py_list_.clear();
}
} // namespace pynative
} // namespace mindspore

View File

@ -22,6 +22,7 @@
#include <map>
#include <utility>
#include <stack>
#include <vector>
#include "pipeline/pynative/forward/do_cast.h"
#include "pipeline/pynative/forward/do_infer.h"
#include "backend/graph_compiler/backend.h"
@ -71,6 +72,10 @@ class ForwardExecutor {
MS_EXCEPTION_IF_NULL(infer_operation_);
return infer_operation_;
}
inline void set_is_ms_function_compiling(bool is_ms_function_compiling) {
is_ms_function_compiling_ = is_ms_function_compiling;
}
inline std::string device_target() { return device_target_; }
private:
GradExecutorPtr grad() const;
@ -94,6 +99,7 @@ class ForwardExecutor {
private:
bool init_{false};
bool lazy_build_{false};
bool is_ms_function_compiling_{false};
uint32_t device_id_{0};
std::string last_target_{"Unknown"};
std::string device_target_;
@ -103,6 +109,7 @@ class ForwardExecutor {
InferOperationPtr infer_operation_;
MindrtBackendMap mindrt_backends_;
bool enable_async_ = false;
mutable std::vector<PrimitivePyPtr> op_run_prim_py_list_;
};
} // namespace pynative
} // namespace mindspore

View File

@ -63,7 +63,7 @@ std::string GetCellId(const py::object &obj, const py::args &args, const InputAr
return cell_id;
}
InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args, bool is_grad_top_cell,
InputArgsInfoPtr ParsePyArgsToInputArgsInfo(const py::object &obj, const py::args &args, bool is_grad_top_cell,
bool is_high_order_top_cell) {
bool has_custom_bprop = py::hasattr(obj, parse::CUSTOM_BPROP_NAME);
const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
@ -82,6 +82,7 @@ InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args, b
}
pipeline::CheckArgsValid(obj, args);
}
input_args_info->is_run_cell = py::isinstance<Cell>(obj);
input_args_info->cell_id = GetCellId(obj, args, input_args_info);
MS_LOG(DEBUG) << "cell_id is " << obj_id << ", is grad top cell " << (is_grad_top_cell || is_high_order_top_cell);
return input_args_info;
@ -200,10 +201,10 @@ ForwardExecutorPtr GradExecutor::forward() const {
}
std::string GradExecutor::GetCurCellOrder() const {
if (input_args_info_stack_.empty()) {
MS_LOG(EXCEPTION) << "The input_args_info_stack_ is empty!";
if (cur_cell_id_.empty()) {
MS_LOG(EXCEPTION) << "The cur_cell_id_ is empty!";
}
return input_args_info_stack_.top()->cell_id + "_" + std::to_string(cell_order_);
return cur_cell_id_ + "_" + std::to_string(cell_order_);
}
TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
@ -300,12 +301,18 @@ void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_i
auto graph_info_cg = std::make_shared<GraphInfo>();
top_cell()->SetGraphInfoMap(curr_g(), graph_info_cg);
HandleInputArgsForTopCell(input_args_info, false);
top_cell()->set_need_compile_graph(true);
top_cell()->set_init_kpynative(true);
}
}
void GradExecutor::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph) const {
top_cell()->set_need_compile_graph(need_compile_graph);
top_cell()->set_forward_already_run(forward_already_run);
}
void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) {
const auto &input_args_info = GetInputArgsInfo(obj, args, input_args_info_stack_.empty(), is_high_order_top_cell());
const auto input_args_info = GetInputArgsInfo(obj, args);
PushInputArgsInfoStack(input_args_info);
if (input_args_info->has_custom_bprop) {
@ -317,17 +324,21 @@ void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) {
}
input_args_info->grad_order = grad_order_;
// Maybe this can run asynchronously here
if (enable_async_) {
AsyncNewGraphImpl(input_args_info);
} else {
NewGraphImpl(input_args_info);
}
InputArgsInfoPtr GradExecutor::GetInputArgsInfo(const py::object &obj, const py::args &args) {
auto input_args_info =
ParsePyArgsToInputArgsInfo(obj, args, input_args_info_stack_.empty(), is_high_order_top_cell());
input_args_info->use_dynamic_shape_process = use_dynamic_shape_process_;
return input_args_info;
}
void GradExecutor::NewGraphImpl(const InputArgsInfoPtr &input_args_info) {
MS_EXCEPTION_IF_NULL(input_args_info);
++cell_order_;
const auto &cell_id = input_args_info->cell_id;
cur_cell_id_ = cell_id;
MS_LOG(DEBUG) << "NewGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id
<< ", input args info ptr " << input_args_info.get();
// Make top graph and init resource
@ -357,12 +368,18 @@ void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) {
auto fg = std::make_shared<FuncGraph>();
fg->debug_info()->set_name("pynative_forward_graph");
auto resource = std::make_shared<pipeline::Resource>();
const auto &already_run_cell_id = input_args_info->cell_id + std::to_string(input_args_info->grad_order);
top_cell_ = std::make_shared<TopCellInfo>(input_args_info->grad_order, input_args_info->cell_id, already_run_cell_id,
resource, fg);
const auto &already_run_cell_id = GetAlreadyRunCellId(input_args_info->cell_id);
top_cell_ = std::make_shared<TopCellInfo>(input_args_info->grad_order, input_args_info->obj_id,
input_args_info->cell_id, already_run_cell_id, resource, fg);
top_cell_->set_forward_already_run(true);
top_cell_->set_is_run_cell(input_args_info->is_run_cell);
top_cell_->set_input_args_id(input_args_info->input_args_id);
PushHighOrderGraphStack(top_cell_);
(void)top_cell_list_.emplace_back(top_cell_);
const auto &cell_id = input_args_info->obj_id.append("_").append(std::to_string(grad_order_));
is_cell_id_in_dynamic_detect_nodes_map_ =
(cell_id_with_dynamic_detect_nodes_.find(cell_id) != cell_id_with_dynamic_detect_nodes_.end());
MS_LOG(DEBUG) << "New top graph, fg ptr " << fg.get() << " resource ptr " << resource.get();
}
@ -387,8 +404,12 @@ void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
auto sens_v = ConvertOutputValueToTensor(v);
auto cloned_value = ShallowCopyTensorValue(sens_v);
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
AsyncUpdateOutputNodeOfTopCell(output_node, cloned_value);
} else {
auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value);
}
}
void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, const py::args &args) {
if (input_args_info_stack_.empty()) {
@ -400,17 +421,14 @@ void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, c
GetCustomBpropPrim(obj, args, out, input_args_info);
}
input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out);
input_args_info->use_dynamic_shape_process = use_dynamic_shape_process_;
PopInputArgsInfoStack();
if (input_args_info->is_grad_topest_cell) {
set_grad_flag(false);
}
// Maybe this can run asynchronously here
if (enable_async_) {
AsyncEndGraphImpl(input_args_info);
} else {
EndGraphImpl(input_args_info);
}
}
void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
MS_EXCEPTION_IF_NULL(input_args_info);
@ -453,6 +471,7 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
SetForwardLastNodeInfo(out_value, out_id);
}
top_cell()->CheckSubCellHookChanged();
CheckNeedCompileGraph(input_args_info);
top_input_args_info_ = input_args_info;
}
}
@ -478,7 +497,15 @@ void GradExecutor::DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info,
op_run_info->input_size = input_args_info->input_arg_value_vec.size();
op_run_info->input_value_id = input_args_info->input_arg_id_vec;
auto cnode = ConstructForwardGraph(op_run_info);
if (grad_is_running_ && !bprop_grad_stack_.top().second) {
MS_LOG(DEBUG) << "Custom bprop, no need do op grad";
return;
}
DoOpGrad(op_run_info, cnode, input_args_info->out_value);
CheckGraphDynamic(cnode, top_cell()->op_index());
top_cell()->IncreaseOpIndex();
SaveOutputNodeMap(out_id, op_run_info, cnode);
}
@ -535,6 +562,56 @@ void GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &arg
input_args_info->custom_bprp_prim = fake_prim;
}
void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info) {
const auto &new_top_cell = top_cell();
const auto &already_top_cell_id = new_top_cell->already_run_cell_id();
// Update top cell by current cell op info
if (already_run_top_cell_.find(already_top_cell_id) == already_run_top_cell_.end()) {
MS_LOG(DEBUG) << "Cell " << already_top_cell_id << " has never been ran, need compile graph";
already_run_top_cell_[already_top_cell_id] = new_top_cell;
pre_top_cell_ = top_cell();
return;
}
MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has been ran";
auto pre_top_cell = already_run_top_cell_.at(already_top_cell_id);
MS_EXCEPTION_IF_NULL(pre_top_cell);
if (input_args_info->use_dynamic_shape_process || !input_args_info->is_run_cell) {
// A function needs to be compiled every time.
MS_LOG(DEBUG) << "The graph is dynamic, need to compile graph again";
EraseTopCellFromTopCellList(pre_top_cell);
{
py::gil_scoped_acquire acquire;
pre_top_cell->Clear();
}
already_run_top_cell_[already_top_cell_id] = new_top_cell;
pre_top_cell_ = nullptr;
} else {
MS_LOG(DEBUG) << "no need to compile graph again";
pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
// In high order situations, the internal top cell remains unchanged, but the external top cell has changed. Then
// the graph info of the internal top cell needs to be updated so that the external top cell can perceive it.
if (!input_args_info->is_grad_topest_cell) {
pre_top_cell->SetGraphInfoMap(pre_top_cell->fg(), new_top_cell->graph_info_map().at(new_top_cell->fg()));
}
pre_top_cell_ = pre_top_cell;
pre_top_cell->set_forward_already_run(true);
}
}
void GradExecutor::EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell) {
MS_EXCEPTION_IF_NULL(top_cell);
auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
[&](const TopCellInfoPtr &elem) { return elem.get() == top_cell.get(); });
if (iter == top_cell_list_.end()) {
MS_LOG(WARNING) << "Can not find top cell " << top_cell.get() << " cell id " << top_cell->cell_id()
<< " from top cell list";
} else {
(void)top_cell_list_.erase(iter);
}
}
void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
const py::object &grad_position, const py::args &args) {
{
@ -558,7 +635,19 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
(void)top_input_args_info_->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_v));
top_input_args_info_->has_sens = true;
}
if (pre_top_cell_ != nullptr) {
set_top_cell(pre_top_cell_);
}
if (!top_cell()->need_compile_graph()) {
MS_LOG(DEBUG) << "No need compile graph";
top_cell_list_.pop_back();
UpdateTopCellInfo(false, false);
return;
}
MS_LOG(DEBUG) << "Need compile graph";
top_cell()->set_grad_operation(grad_operation_);
SetBpropGraphJitLevel(obj);
bool weight_param_is_tuple = true;
auto w_args = GetWeightsArgs(weights, &weight_param_is_tuple);
@ -568,12 +657,22 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
GetGradGraph(grad_attr, w_args, p_args);
}
std::string GradExecutor::GetAlreadyRunCellId(const std::string &cell_id) const {
std::string already_run_cell_id(cell_id);
already_run_cell_id += std::to_string(grad_order_ == 0 ? 1 : grad_order_);
already_run_cell_id += "_" + grad_operation_;
MS_LOG(DEBUG) << "Get already run top cell id " << already_run_cell_id;
return already_run_cell_id;
}
void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector<AnfNodePtr> &w_args,
const std::vector<size_t> &p_args) {
// Get bprop graph of top cell
auto bprop_graph = GetBpropGraph(grad_attr, w_args, p_args);
MS_EXCEPTION_IF_NULL(bprop_graph);
bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
bool use_dynamic_shape_process = (forward()->device_target() == kAscendDevice ? false : use_dynamic_shape_process_);
bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process);
MS_EXCEPTION_IF_NULL(top_input_args_info_);
bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id));
auto resource = top_cell()->resource();
@ -583,14 +682,13 @@ void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(bprop_graph, true);
PyNativeAlgo::Common::DumpGraphIR("launch_bprop_graph.ir", bprop_graph);
if (backends_.find(top_input_args_info_->obj_id) == backends_.end()) {
backends_[top_input_args_info_->obj_id] = compile::CreateBackend();
}
resource->SetBackendAsync([&]() { return backends_[top_input_args_info_->obj_id]; });
SaveForwardTensorInfoInBpropGraph(resource);
resource->SetBackendAsync([]() { return compile::CreateBackend(); });
MS_LOG(DEBUG) << "Start task emit action";
(void)TaskEmitAction(resource);
MS_LOG(DEBUG) << "Start execute action";
(void)ExecuteAction(resource);
UpdateTopCellInfo(false, false);
resource->Clean();
}
@ -761,10 +859,18 @@ void GradExecutor::SetGradOrder(const std::string &cell_id) {
}
py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj,
const py::args &args) {
const py::object &grad_hash_id, const py::args &args) {
auto cell_id = GetCellId(obj, args, nullptr);
// Check current cell grad order and erase it if in current top cell list
SetGradOrder(cell_id);
// Include weight param size and required grad flag
std::string grad_hash_id_str;
if (!py::isinstance<py::none>(grad_hash_id)) {
grad_hash_id_str = std::string(py::str(grad_hash_id));
}
grad_operation_ = std::to_string(static_cast<int>(grad->get_all_)) +
std::to_string(static_cast<int>(grad->get_by_list_)) + grad_hash_id_str;
std::string input_args_id;
for (size_t i = 0; i < args.size(); ++i) {
@ -774,8 +880,9 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con
// check whether need to run forward process
bool forward_run = false;
if (input_args_info_stack_.empty() && top_cell_ != nullptr) {
cell_id += std::to_string(grad_order_ == 0 ? 1 : grad_order_);
if (CanGetTopCell(cell_id)) {
const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
auto find_top_cell = GetTopCell(check_already_run_cell_id);
if (find_top_cell != nullptr) {
MS_LOG(DEBUG) << "Find already run top cell";
forward_run = top_cell()->forward_already_run();
bool input_args_changed = !top_cell()->input_args_id().empty() && top_cell()->input_args_id() != input_args_id;
@ -948,7 +1055,12 @@ void GradExecutor::ClearGradRes() {
if (top_cell_ != nullptr) {
top_cell_->ClearDeviceMemory();
}
if (use_dynamic_shape_process_ ||
already_run_top_cell_.find(top_cell_->already_run_cell_id()) != already_run_top_cell_.end()) {
top_cell_ = nullptr;
}
DecreaseGradOrder();
ClearGlobalRes();
}
@ -959,13 +1071,21 @@ void GradExecutor::ClearRes() {
grad_is_running_ = false;
need_renormalize_ = false;
eliminate_forward_ = true;
use_dynamic_shape_process_ = false;
is_cell_id_in_dynamic_detect_nodes_map_ = false;
custom_bprop_cell_count_ = 0;
grad_order_ = 0;
top_cell_ = nullptr;
top_input_args_info_ = nullptr;
bprop_cell_list_.clear();
backends_.clear();
async_executor_->Reset();
for (const auto &cell_ptr : top_cell_list_) {
MS_EXCEPTION_IF_NULL(cell_ptr);
cell_ptr->Clear();
}
top_cell_list_.clear();
already_run_top_cell_.clear();
cell_id_with_dynamic_detect_nodes_.clear();
std::stack<InputArgsInfoPtr>().swap(input_args_info_stack_);
std::stack<std::pair<std::string, bool>>().swap(bprop_grad_stack_);
std::stack<TopCellInfoPtr>().swap(high_order_stack_);
@ -1072,6 +1192,8 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str
auto cnode = curr_g()->NewCNode(inputs);
MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode, -1, false);
CheckGraphDynamic(cnode, top_cell()->op_index());
top_cell()->IncreaseOpIndex();
return cnode;
}
@ -1099,10 +1221,42 @@ AnfNodePtr GradExecutor::CreateTupleGetItemNode(const std::string &obj_id,
c_node->set_abstract(prim_abs);
}
}
CheckGraphDynamic(c_node, top_cell()->op_index());
top_cell()->IncreaseOpIndex();
MS_LOG(DEBUG) << "Get input node " << c_node->ToString() << ", id " << obj_id;
return c_node;
}
TopCellInfoPtr GradExecutor::GetTopCell(const std::string &already_run_cell_id) {
TopCellInfoPtr find_top_cell = nullptr;
for (const auto &top_cell : top_cell_list_) {
MS_EXCEPTION_IF_NULL(top_cell);
// Complete match means the grad operation was run first
if (top_cell->already_run_cell_id() == already_run_cell_id) {
return top_cell;
}
// Partial match means forward was run first
if (already_run_cell_id.find(top_cell->already_run_cell_id()) != std::string::npos &&
top_cell->already_run_cell_id().back() == '_') {
find_top_cell = top_cell;
break;
}
}
// Same topcell info, but grad operation is not the same, construct backward graph again
if (find_top_cell != nullptr) {
if (!find_top_cell->grad_operation().empty() && find_top_cell->grad_operation() != grad_operation_) {
MS_LOG(DEBUG) << "Already exist grad operation " << find_top_cell->grad_operation() << " is different with new "
<< grad_operation_;
EraseTopCellFromTopCellList(find_top_cell);
(void)already_run_top_cell_.erase(find_top_cell->already_run_cell_id());
return nullptr;
} else {
return find_top_cell;
}
}
return nullptr;
}
void GradExecutor::SetHookChanged(const py::object &cell) const {
if (top_cell_ == nullptr) {
return;
@ -1118,24 +1272,19 @@ void GradExecutor::SetHookChanged(const py::object &cell) const {
void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const {
MS_EXCEPTION_IF_NULL(op_run_info);
if (!op_run_info->grad_flag) {
MS_LOG(DEBUG) << "Grad flag is false";
return;
}
// Set the forward output flag for releasing memory
PyNativeAlgo::Common::SetForwardOutputFlag(op_run_info->out_value);
// Const value does not need op grad
if (op_run_info->output_get_by_infer_value) {
return;
}
// Do op grad and save node info. If the cell has a custom bprop, op grad is not needed; otherwise, it is.
if (op_run_info->custom_bprop_cell_count <= 0) {
const auto &cnode = ConstructForwardGraph(op_run_info);
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_abstract(op_run_info->base_op_run_info.abstract);
SaveOutputNodeMap(op_run_info->out_value_id, op_run_info, cnode);
if (grad_is_running_ && !bprop_grad_stack_.top().second) {
MS_LOG(DEBUG) << "Custom bprop, no need do op grad";
return;
}
DoOpGrad(op_run_info, cnode, op_run_info->out_value);
CheckGraphDynamic(cnode, top_cell()->op_index());
UpdateForwardTensorInfoInBpropGraph(op_run_info);
}
}
@ -1163,10 +1312,6 @@ void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const FrontendOp
// Run ad grad for curr op and connect grad graph with previous op
void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode,
const ValuePtr &op_out) const {
if (grad_is_running_ && !bprop_grad_stack_.top().second) {
MS_LOG(DEBUG) << "Custom bprop, no need do op grad";
return;
}
MS_EXCEPTION_IF_NULL(op_run_info);
// Clone to avoid out existing in the tape bprop and to avoid out being modified.
@ -1175,14 +1320,196 @@ void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNode
std::back_inserter(cloned_op_args),
[](const ValuePtr &value) { return ShallowCopyTensorValue(value); });
ValuePtr cloned_out = ShallowCopyTensorValue(op_out);
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(cloned_out, &tensors);
for (auto tensor : tensors) {
tensor->set_is_forward_output(true);
}
if (!ad::GradPynativeOp(top_cell()->auto_grad_cell_ptr(), cnode, cloned_op_args, cloned_out)) {
MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_run_info->base_op_run_info.op_name;
}
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
AsyncGradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out);
} else {
GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out);
}
}
void GradExecutor::GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode,
const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const {
if (!ad::GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out)) {
MS_LOG(EXCEPTION) << "Failed to run ad grad for op ";
}
}
void GradExecutor::AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode,
const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const {
const auto fn = [this, auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out]() {
this->GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out);
};
auto task = std::make_shared<BpropTask>(fn);
async_executor_->Push(task);
}
void GradExecutor::AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &cloned_value) const {
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
const auto fn = [auto_grad_cell_ptr, output_node, cloned_value]() {
auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value);
};
auto task = std::make_shared<BpropTask>(fn);
async_executor_->Push(task);
}
void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const {
MS_EXCEPTION_IF_NULL(op_run_info);
if (op_run_info->base_op_run_info.use_dynamic_shape_process) {
MS_LOG(DEBUG) << "Get dynamic shape process";
return;
}
top_cell()->GetOpInfo(op_run_info);
MS_LOG(DEBUG) << "Current op info: " << op_run_info->op_info;
std::vector<tensor::TensorPtr> op_output_tensors;
// Get output tensors
TensorValueToTensor(op_run_info->out_value, &op_output_tensors);
// Save all tensors info of current op
top_cell()->set_opinfo_with_tensor_id(op_run_info->op_info, op_output_tensors);
// First run top cell
if (already_run_top_cell_.find(top_cell_->already_run_cell_id()) == already_run_top_cell_.end()) {
MS_LOG(DEBUG) << "Top cell " << top_cell_->cell_id() << " run firstly";
return;
}
// Non-first run
const auto &pre_top_cell = already_run_top_cell_.at(top_cell_->already_run_cell_id());
MS_EXCEPTION_IF_NULL(pre_top_cell);
if (pre_top_cell->op_info_with_tensor_id().find(op_run_info->op_info) ==
pre_top_cell->op_info_with_tensor_id().end()) {
MS_LOG(DEBUG) << "Can not find op info " << op_run_info->op_info << " in op info with tensor id map. Top cell "
<< top_cell_->cell_id();
return;
}
// Update new output tensor info in bprop graph
const auto &pre_op_tensor_id = pre_top_cell->op_info_with_tensor_id().at(op_run_info->op_info);
if (pre_op_tensor_id.size() != op_output_tensors.size()) {
MS_LOG(EXCEPTION) << "The size of op pre output tensor size: " << pre_op_tensor_id.size()
<< " is not equal to current " << op_output_tensors.size();
}
// For each tensor in a value node of the bprop graph, take its tensor id and save it in tensor_id_with_tensor_object.
// Then take the output of the op and check whether that output is used by tensor_id_with_tensor_object;
// if so, the tensor needs to be replaced.
const auto &pre_tensor_id_with_tensor_object = pre_top_cell->tensor_id_with_tensor_object();
for (size_t i = 0; i < pre_op_tensor_id.size(); ++i) {
auto pre_id = pre_op_tensor_id[i];
if (pre_tensor_id_with_tensor_object.find(pre_id) == pre_tensor_id_with_tensor_object.end()) {
continue;
}
// Since the output size of the op is fixed, the index can be used.
const auto &new_tensor = op_output_tensors[i];
const auto &pre_tensor_object = pre_tensor_id_with_tensor_object.at(pre_id);
UpdatePreTensorInfo(new_tensor, pre_tensor_object);
}
}
void GradExecutor::UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor,
const std::vector<tensor::TensorPtr> &pre_tensors) const {
MS_EXCEPTION_IF_NULL(new_tensor);
if (pre_tensors.empty() || new_tensor->device_address() == nullptr) {
MS_LOG(DEBUG) << "The number of pre tensors is zero or the device address of new tensor is nullptr.";
return;
}
const auto &device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
for (auto &pre_tensor : pre_tensors) {
MS_EXCEPTION_IF_NULL(pre_tensor);
MS_LOG(DEBUG) << "Replace Old tensor id " << pre_tensor->id() << " device_address: " << pre_tensor->device_address()
<< " shape and type " << pre_tensor->GetShapeAndDataTypeInfo() << " with New tensor id "
<< new_tensor->id() << " device_address " << new_tensor->device_address() << " shape and dtype "
<< new_tensor->GetShapeAndDataTypeInfo();
(void)pre_tensor->set_shape(new_tensor->shape());
(void)pre_tensor->set_data_type(new_tensor->data_type());
auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
MS_EXCEPTION_IF_NULL(device_address);
if (device_target != kCPUDevice && device_address->GetDeviceType() != device::DeviceType::kCPU) {
pre_tensor->set_device_address(new_tensor->device_address());
continue;
}
for (const auto &item : PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->mindrt_backend()) {
MS_EXCEPTION_IF_NULL(item.second);
item.second->WaitTaskFinish();
}
// Replace data in the device address when running on the CPU device.
if (pre_tensor->device_address() != nullptr) {
// If the tensor has a dynamic shape, just replace the device address.
if (PyNativeAlgo::Common::ValueHasDynamicShape(pre_tensor)) {
pre_tensor->set_device_address(new_tensor->device_address());
continue;
}
auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(pre_tensor->device_address());
MS_EXCEPTION_IF_NULL(old_device_address);
auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
MS_EXCEPTION_IF_NULL(new_device_address);
// CPU host tensor data_c is different from device address if the address is from mem_pool.
if (new_device_address->from_mem_pool()) {
pre_tensor->set_device_address(new_device_address);
continue;
}
auto old_ptr = old_device_address->GetMutablePtr();
MS_EXCEPTION_IF_NULL(old_ptr);
auto new_ptr = new_device_address->GetPtr();
MS_EXCEPTION_IF_NULL(new_ptr);
MS_EXCEPTION_IF_CHECK_FAIL(old_device_address->GetSize() == new_device_address->GetSize(), "Size not equal");
if (old_device_address->GetSize() < SECUREC_MEM_MAX_LEN) {
auto ret_code = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize());
MS_EXCEPTION_IF_CHECK_FAIL(ret_code == EOK, "Memory copy failed, ret code: " + std::to_string(ret_code));
} else {
auto ret_code = std::memcpy(old_ptr, new_ptr, old_device_address->GetSize());
MS_EXCEPTION_IF_CHECK_FAIL(ret_code == old_ptr, "Memory copy failed");
}
} else {
pre_tensor->set_device_address(device_address);
pre_tensor->data_sync();
pre_tensor->set_device_address(nullptr);
pre_tensor->set_sync_status(kNeedSyncHostToDevice);
}
}
}
void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const {
if (use_dynamic_shape_process_) {
return;
}
// Get all tensors id of forward op
mindspore::HashSet<std::string> forward_op_tensor_id;
const auto &op_info_with_tensor_id = top_cell()->op_info_with_tensor_id();
for (const auto &record : op_info_with_tensor_id) {
(void)std::for_each(
record.second.begin(), record.second.end(),
[&forward_op_tensor_id](const std::string &tensor_id) { (void)forward_op_tensor_id.emplace(tensor_id); });
}
// Get all tensors obj in value node of bprop graph
MS_EXCEPTION_IF_NULL(resource);
const auto &bprop_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(bprop_graph);
const auto &value_node_list = bprop_graph->value_nodes();
std::vector<tensor::TensorPtr> tensors_in_bprop_graph;
for (const auto &elem : value_node_list) {
auto value_node = elem.first->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
}
// Save tensor in value node of bprop graph
for (const auto &tensor : tensors_in_bprop_graph) {
MS_EXCEPTION_IF_NULL(tensor);
if (forward_op_tensor_id.find(tensor->id()) == forward_op_tensor_id.end() || tensor->device_address() == nullptr) {
continue;
}
tensor->set_is_forward_output(true);
top_cell()->set_tensor_id_with_tensor_object(tensor->id(), tensor);
MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
<< " device address: " << tensor->device_address() << " shape and dtype "
<< tensor->GetShapeAndDataTypeInfo();
}
}
AnfNodePtr GradExecutor::GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const {
@ -1241,6 +1568,7 @@ CNodePtr GradExecutor::ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_
if (IsPrimitiveCNode(cnode, prim::kPrimCellBackwardHook)) {
top_cell()->RecordCellBackwardHookOp(GetCurCellOrder(), cnode);
}
MS_LOG(DEBUG) << "Make CNode for " << op_run_info->base_op_run_info.op_name << ", new cnode is "
<< cnode->DebugString();
return cnode;
@ -1260,5 +1588,235 @@ void GradExecutor::SetBpropGraphJitLevel(const py::object &obj) const {
MS_EXCEPTION_IF_NULL(graph_executor);
graph_executor->SetJitConfig(jit_config_dict);
}
void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t &node_idx,
bool is_ms_function_node,
const std::string &graph_phase) const {
MS_EXCEPTION_IF_NULL(cnode);
auto node_info = std::make_shared<DynamicDetectNodeInfo>();
if (!is_ms_function_node) {
node_info->prim = GetCNodePrimitive(cnode);
for (size_t i = 1; i < cnode->inputs().size(); i++) {
const auto &input_node = cnode->input(i);
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<ValueNode>()) {
node_info->input_values[i] = GetValueNode(input_node);
} else if (input_node->isa<CNode>()) {
const auto &node_abs = input_node->abstract();
auto op_index = top_cell()->get_op_index_by_cnode_hash(input_node->hash());
node_info->input_cnode_info[i] = std::make_pair(op_index, node_abs);
} else {
if (!input_node->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "input_node:" << input_node->fullname_with_scope()
<< " is none of value node, cnode and parameter.";
}
const auto &param = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
node_info->input_param_infos[i] = param->param_info();
}
}
node_info->output_abs = cnode->abstract();
} else {
node_info->is_graph_node = true;
node_info->graph_phase = graph_phase;
}
top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx);
const auto &cell_id = top_cell()->c_cell_id() + "_" + std::to_string(top_cell()->grad_order());
(void)cell_id_with_dynamic_detect_nodes_[cell_id].emplace_back(node_info);
}
bool IsAbsDifferent(const AbstractBasePtr &old_abs, const AbstractBasePtr &new_abs) {
if (old_abs == new_abs) {
return false;
}
if (old_abs == nullptr || new_abs == nullptr) {
MS_LOG(DEBUG) << "graph is dynamic, old_abs is different with new_abs";
return true;
}
if (!common::IsEqual(old_abs->BuildType(), new_abs->BuildType()) ||
!common::IsEqual(old_abs->BuildShape(), new_abs->BuildShape())) {
MS_LOG(DEBUG) << "graph is dynamic, old_abs is different with new_abs, old abs:" << old_abs->ToString()
<< " new abs:" << new_abs->ToString();
return true;
}
return false;
}
bool IsValuePtrEqual(const ValuePtr &v1, const ValuePtr &v2) {
if (v1 == v2) {
return true;
}
if (v1 == nullptr || v2 == nullptr) {
return false;
}
if (v1->isa<tensor::Tensor>() && v2->isa<tensor::Tensor>()) {
return v1->cast<tensor::TensorPtr>()->ValueEqual(*(v2->cast<tensor::TensorPtr>()));
}
return *v1 == *v2;
}
bool IsParamInfoEqual(const ParamInfoPtr &p1, const ParamInfoPtr &p2) {
if (p1 == p2) {
return true;
}
if (p1 == nullptr || p2 == nullptr) {
return false;
}
return p1->key() == p2->key();
}
bool GradExecutor::IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info,
const std::vector<AnfNodePtr> &new_anf_inputs) const {
MS_EXCEPTION_IF_NULL(old_node_info);
auto old_input_size = old_node_info->input_cnode_info.size() + old_node_info->input_values.size() +
old_node_info->input_param_infos.size();
if (old_input_size != new_anf_inputs.size() - 1) {
MS_LOG(DEBUG) << "graph is dynamic, old input size:" << old_input_size
<< " new input_infos:" << (new_anf_inputs.size() - 1);
return true;
}
for (size_t i = 1; i < new_anf_inputs.size(); i++) {
const auto &new_anf_input = new_anf_inputs[i];
MS_EXCEPTION_IF_NULL(new_anf_input);
if (new_anf_input->isa<ValueNode>()) {
const auto &value_iter = old_node_info->input_values.find(i);
if (value_iter == old_node_info->input_values.end()) {
MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a value, old input is not a value.";
return true;
}
if (!IsValuePtrEqual(value_iter->second, GetValueNode(new_anf_input))) {
MS_LOG(DEBUG) << "The " << i << "th input, value is different.";
return true;
}
} else if (new_anf_input->isa<CNode>()) {
// Compare cnode abstract.
const auto &node_iter = old_node_info->input_cnode_info.find(i);
if (node_iter == old_node_info->input_cnode_info.end()) {
MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a cnode, old input is not a cnode.";
return true;
}
size_t old_op_index = 0;
AbstractBasePtr old_abs = nullptr;
std::tie(old_op_index, old_abs) = node_iter->second;
if (IsAbsDifferent(old_abs, new_anf_input->abstract())) {
MS_LOG(DEBUG) << "The " << i << "th input, abs is different.";
return true;
}
// Compare cnode edge.
if (old_op_index != top_cell()->get_op_index_by_cnode_hash(new_anf_input->hash())) {
MS_LOG(DEBUG) << "The " << i << "th input, op_index is different, old op_index:" << old_op_index
<< " new op_index:" << top_cell()->get_op_index_by_cnode_hash(new_anf_input->hash());
return true;
}
} else {
// Compare parameter.
if (!new_anf_input->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "new_anf_input:" << new_anf_input->fullname_with_scope()
<< " is none of value node, cnode and parameter.";
}
const auto &node_iter = old_node_info->input_param_infos.find(i);
if (node_iter == old_node_info->input_param_infos.end()) {
MS_LOG(DEBUG) << "The " << i
<< "th input is different, cur input is a parameter, old input is not a parameter.";
return true;
}
const auto &param = new_anf_input->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
if (!IsParamInfoEqual(node_iter->second, param->param_info())) {
MS_LOG(DEBUG) << "The " << i << "th input, param info is different.";
return true;
}
}
}
return false;
}
bool GradExecutor::IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info,
const CNodePtr &new_cnode, bool is_ms_function_node,
const std::string &graph_phase) const {
MS_EXCEPTION_IF_NULL(new_cnode);
MS_EXCEPTION_IF_NULL(old_node_info);
// 1.Detect ms_function phase
if (is_ms_function_node != old_node_info->is_graph_node ||
(is_ms_function_node && graph_phase != old_node_info->graph_phase)) {
MS_LOG(DEBUG) << "graph is dynamic, old is_graph_node:" << old_node_info->is_graph_node
<< " new is_graph_node:" << is_ms_function_node << " old graph_phase" << old_node_info->graph_phase
<< " new graph_phase:" << graph_phase;
return true;
}
// 2.Detect cnode prim
auto new_prim = GetCNodePrimitive(new_cnode);
if (!common::IsEqual(new_prim, old_node_info->prim)) {
MS_LOG(DEBUG) << "graph is dynamic, old prim:" << (old_node_info->prim == nullptr ? 0 : old_node_info->prim->name())
<< " new prim:" << (new_prim == nullptr ? 0 : new_prim->name());
return true;
}
// 3.Detect output abs
if (IsAbsDifferent(old_node_info->output_abs, new_cnode->abstract())) {
MS_LOG(DEBUG) << "graph is dynamic, output_abs is different";
return true;
}
// 4.Detect inputs
return IsCnodeInputsDynamic(old_node_info, new_cnode->inputs());
}
bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
const std::string &graph_phase) const {
MS_EXCEPTION_IF_NULL(cnode);
if (!is_cell_id_in_dynamic_detect_nodes_map_) {
SaveDynamicDetectNodeInfoInFirstTime(cnode, node_idx, is_ms_function_node, graph_phase);
// The net is regarded as a static net by default the first time.
return false;
}
const auto &cell_id = top_cell()->c_cell_id() + "_" + std::to_string(top_cell()->grad_order());
const auto &dynamic_nodes = cell_id_with_dynamic_detect_nodes_[cell_id];
if (node_idx >= dynamic_nodes.size()) {
MS_LOG(DEBUG) << "old dynamic_nodes size:" << dynamic_nodes.size() << " cur node_idx is:" << node_idx
<< ", graph is dynamic.";
return true;
}
if (IsDynamicDetectNodeInfoChange(dynamic_nodes[node_idx], cnode, is_ms_function_node, graph_phase)) {
MS_LOG(DEBUG) << "graph is dynamic, node_idx:" << node_idx
<< " is different, cnode:" << cnode->fullname_with_scope();
return true;
}
top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx);
return false;
}
void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
const std::string &graph_phase) const {
if (!top_cell()->is_run_cell() || use_dynamic_shape_process_) {
return;
}
use_dynamic_shape_process_ = IsGraphDynamic(cnode, node_idx, is_ms_function_node, graph_phase);
if (use_dynamic_shape_process_) {
MS_LOG(DEBUG) << "cnode:" << cnode->fullname_with_scope() << ",node_idx:" << node_idx
<< ",is_ms_function_node:" << is_ms_function_node << ",graph_phase:" << graph_phase
<< ",use_dynamic_shape_process_:" << use_dynamic_shape_process_;
cell_id_with_dynamic_detect_nodes_.clear();
}
}
} // namespace pynative
} // namespace mindspore
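
To recap the detection logic added in this file: the first run of a cell records one DynamicDetectNodeInfo per op index in cell_id_with_dynamic_detect_nodes_, and later runs compare every new cnode against that record (prim, output abstract, value inputs, parameter infos, and cnode edges); any mismatch flips use_dynamic_shape_process_ and clears the recorded nodes. Below is a minimal, self-contained sketch of this record-then-compare pattern; the DynamicDetector class and its single signature string are hypothetical simplifications of the real per-field comparison, and the first-run bookkeeping is collapsed into "record anything not seen yet".

// Hypothetical, stripped-down model of the per-node dynamic detection added to
// GradExecutor; not the real MindSpore implementation.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct NodeInfo {
  std::string signature;  // stand-in for prim + input/output abstracts
};

class DynamicDetector {
 public:
  // Returns true once the recorded and observed node sequences diverge.
  bool Check(const std::string &cell_id, size_t node_idx, const NodeInfo &node) {
    if (use_dynamic_) return true;
    auto &recorded = cell_nodes_[cell_id];
    if (node_idx >= recorded.size()) {  // unseen index: treat as the recording (first) run
      recorded.push_back(node);
      return false;
    }
    if (recorded[node_idx].signature != node.signature) {
      use_dynamic_ = true;   // structure changed between runs
      cell_nodes_.clear();   // no further bookkeeping needed, mirrors clearing the node map
    }
    return use_dynamic_;
  }

 private:
  bool use_dynamic_{false};
  std::unordered_map<std::string, std::vector<NodeInfo>> cell_nodes_;
};

int main() {
  DynamicDetector detector;
  // Run 1: record Add -> Mul.
  detector.Check("cell_0", 0, {"Add:f32[2,2]"});
  detector.Check("cell_0", 1, {"Mul:f32[2,2]"});
  // Run 2: the second op now sees a different shape -> graph is dynamic.
  std::cout << std::boolalpha
            << detector.Check("cell_0", 1, {"Mul:f32[4,2]"}) << "\n";  // prints: true
  return 0;
}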

View File

@ -36,6 +36,17 @@ class ForwardExecutor;
using ForwardExecutorPtr = std::shared_ptr<ForwardExecutor>;
using ForwardExecutorWeakPtr = std::weak_ptr<ForwardExecutor>;
struct DynamicDetectNodeInfo {
PrimitivePtr prim{nullptr};
AbstractBasePtr output_abs{nullptr};
bool is_graph_node{false};
std::string graph_phase;
mindspore::HashMap<size_t, std::pair<size_t, AbstractBasePtr>> input_cnode_info;
mindspore::HashMap<size_t, ValuePtr> input_values;
mindspore::HashMap<size_t, ParamInfoPtr> input_param_infos;
};
using DynamicDetectNodeInfoPtr = std::shared_ptr<DynamicDetectNodeInfo>;
class GradExecutor {
public:
GradExecutor() = default;
@ -43,8 +54,7 @@ class GradExecutor {
explicit GradExecutor(const ForwardExecutorPtr &forward_executor = nullptr)
: forward_executor_(ForwardExecutorWeakPtr(forward_executor)),
ms_function_(std::make_shared<MsFunction>()),
async_executor_(std::make_unique<AsyncQueue>()),
enable_async_(std::getenv("ENABLE_ASYNC")) {}
async_executor_(std::make_shared<AsyncQueue>()) {}
std::function<void(const py::object &, const py::args &)> InitGraph = [this](auto &&PH1, auto &&PH2) {
NewGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
@ -69,6 +79,10 @@ class GradExecutor {
MS_EXCEPTION_IF_NULL(ms_function_);
return ms_function_;
}
inline void set_use_dynamic_shape_process(bool use_dynamic_shape_process) {
use_dynamic_shape_process_ = use_dynamic_shape_process;
}
inline bool need_renormalize() const { return need_renormalize_; }
inline void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
inline bool grad_flag() const { return grad_flag_; }
@ -77,12 +91,16 @@ class GradExecutor {
inline bool eliminate_forward() const { return eliminate_forward_; }
inline void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; }
inline size_t custom_bprop_cell_count() const { return custom_bprop_cell_count_; }
inline bool use_dynamic_shape_process() const { return use_dynamic_shape_process_; }
inline std::shared_ptr<AsyncQueue> async_executor() const { return async_executor_; }
void SetHookChanged(const py::object &cell) const;
void GradNetInner(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
const py::object &grad_position, const py::args &args);
py::object RunGradGraph();
CNodePtr ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const;
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args);
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id,
const py::args &args);
TopCellInfoPtr GetTopCell(const std::string &already_run_cell_id);
void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
@ -90,25 +108,34 @@ class GradExecutor {
AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
void AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info);
AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const;
void UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const;
void UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor,
const std::vector<tensor::TensorPtr> &pre_tensors) const;
void ClearRes();
void WorkerJoin() { async_executor_->WorkerJoin(); }
void CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node = false,
const std::string &graph_phase = "") const;
private:
ForwardExecutorPtr forward() const;
inline FuncGraphPtr curr_g() const { return top_cell()->fg(); }
inline void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell) { high_order_stack_.push(top_cell); }
inline bool CanGetTopCell(const string &already_run_cell_id) {
return already_run_cell_id.find(top_cell()->already_run_cell_id()) != std::string::npos;
}
std::string GetCurCellOrder() const;
void SetGradOrder(const std::string &cell_id);
void SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
const CNodePtr &cnode) const;
void DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode, const ValuePtr &op_out) const;
void GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode,
const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const;
void AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode,
const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const;
void AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &cloned_value) const;
AnfNodePtr GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const;
void SetBpropGraphJitLevel(const py::object &obj) const;
void ClearGlobalRes();
void ClearGradRes();
std::string GetAlreadyRunCellId(const std::string &cell_id) const;
// Higher derivative
inline bool IsNestedGrad() const { return grad_order_ > 1; }
@ -121,6 +148,7 @@ class GradExecutor {
inline bool is_high_order_top_cell() const {
return !input_args_info_stack_.empty() && IsNestedGrad() && top_cell()->grad_order() != grad_order_;
}
void SwitchTopCell();
void DoParameterReplace(const FuncGraphPtr &first_grad_fg, const std::vector<ValuePtr> &forward_args,
std::vector<AnfNodePtr> *inputs, ValuePtrList *weights_args);
@ -132,15 +160,20 @@ class GradExecutor {
void HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_info, bool is_bprop_top) const;
void InitResourceAndDfBuilder(const InputArgsInfoPtr &cell_info);
void MakeNewTopGraph(const InputArgsInfoPtr &input_args_info);
void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph) const;
// Manage resource when run grad process.
bool IsBpropGraph(const std::string &cell_id) const;
void NewGraphInner(const py::object &obj, const py::args &args);
InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args);
void NewGraphImpl(const InputArgsInfoPtr &input_args_info);
void AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info);
void SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const;
void GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out,
const InputArgsInfoPtr &input_args_info);
void DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info, const std::string &out_id);
void CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info);
void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell);
void GetGradGraph(const ad::GradAttr &grad_attr, const std::vector<AnfNodePtr> &w_args,
const std::vector<size_t> &p_args);
FuncGraphPtr GetBpropGraph(const ad::GradAttr &grad_attr, const vector<AnfNodePtr> &w_args,
@ -151,22 +184,38 @@ class GradExecutor {
const abstract::AbstractBasePtr &param_tensor_abs, const std::string &input_shape);
void UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args, const FuncGraphPtr &bprop_graph, bool has_sens);
std::vector<size_t> GetGradPositionArgs(const py::object &grad_position, bool get_by_position) const;
void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const;
// Manage resource for construct forward graph.
AnfNodePtr GetOutputNodeAsInput(const std::string &obj_id) const;
AnfNodePtr GetValueSequenceInput(const ValuePtr &v, const std::string &obj_id) const;
AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id,
const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const;
void SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
const std::string &graph_phase) const;
bool IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
const std::string &graph_phase) const;
bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info,
const std::vector<AnfNodePtr> &new_anf_inputs) const;
bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, const CNodePtr &new_cnode,
bool is_ms_function_node, const std::string &graph_phase) const;
bool grad_flag_{false};
bool grad_is_running_{false};
bool need_renormalize_{false};
bool eliminate_forward_{true};
mutable bool use_dynamic_shape_process_{false};
mutable bool is_cell_id_in_dynamic_detect_nodes_map_{false};
int custom_bprop_cell_count_{0};
// Used in sub thread
size_t cell_order_{0};
std::string cur_cell_id_{""};
// If grad_order=1, indicate first derivative; grad_order=2, indicate second derivative; ...
size_t grad_order_{0};
std::string grad_operation_;
TopCellInfoPtr top_cell_{nullptr};
TopCellInfoPtr pre_top_cell_{nullptr};
InputArgsInfoPtr top_input_args_info_{nullptr};
// Record every cell's info for sharing, regardless of whether the grad graph needs to be constructed
std::stack<InputArgsInfoPtr> input_args_info_stack_;
@ -175,11 +224,13 @@ class GradExecutor {
std::vector<std::string> bprop_cell_list_;
// For high grad order
std::stack<TopCellInfoPtr> high_order_stack_;
std::vector<TopCellInfoPtr> top_cell_list_;
// Record all top cells which have been run
mindspore::HashMap<std::string, TopCellInfoPtr> already_run_top_cell_;
ForwardExecutorWeakPtr forward_executor_;
MsFunctionPtr ms_function_;
std::unique_ptr<AsyncQueue> async_executor_;
std::map<std::string, compile::BackendPtr> backends_;
bool enable_async_ = false;
std::shared_ptr<AsyncQueue> async_executor_;
mutable mindspore::HashMap<std::string, std::vector<DynamicDetectNodeInfoPtr>> cell_id_with_dynamic_detect_nodes_;
};
} // namespace pynative
} // namespace mindspore

View File

@ -19,6 +19,8 @@
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/parallel_context.h"
#include "ir/func_graph_cloner.h"
#include "runtime/pynative/async/async_queue.h"
#include "pipeline/pynative/grad/bprop_task.h"
namespace mindspore {
namespace pynative {
@ -151,6 +153,35 @@ void MsFunction::ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, co
RunReplace(added_make_tuple, total_output_tensors, grad_graph);
}
void MsFunction::UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const string &op_info,
const ValuePtr &new_forward_value) const {
if (grad_executor->use_dynamic_shape_process()) {
MS_LOG(DEBUG) << "Get dynamic shape process";
return;
}
MS_EXCEPTION_IF_NULL(new_forward_value);
MS_LOG(DEBUG) << "Ms func graph has already ran before. The graph phase is: " << graph_phase_;
MS_LOG(DEBUG) << "The output values of added forward nodes are: " << new_forward_value->ToString();
std::vector<tensor::TensorPtr> new_tensors;
TensorValueToTensor(new_forward_value, &new_tensors);
if (new_tensors.empty()) {
MS_LOG(DEBUG) << "The size of added forward tensors is zero, no need to update.";
return;
}
MS_EXCEPTION_IF_NULL(grad_executor);
const auto &top_cell = grad_executor->top_cell();
const auto &old_tensors = top_cell->op_info_with_ms_func_forward_tensors().at(op_info);
if (old_tensors.size() != new_tensors.size()) {
MS_LOG(EXCEPTION) << "The size of old tensors is: " << old_tensors.size()
<< ", but the size of new tensors is: " << new_tensors.size()
<< ", the current op info is: " << op_info;
}
for (size_t i = 0; i < new_tensors.size(); ++i) {
grad_executor->UpdatePreTensorInfo(new_tensors[i], {old_tensors[i]});
old_tensors[i]->set_sync_status(kNeedSyncDeviceToHost);
}
}
void MsFunction::GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, AnfNodePtrList *input_nodes,
const GradExecutor *grad_executor) const {
MS_EXCEPTION_IF_NULL(op_run_info);
@ -213,6 +244,7 @@ void MsFunction::GetWeightsNode(const FrontendOpRunInfoPtr &op_run_info, const G
void MsFunction::MakeCNodeForMsFunction(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
const FuncGraphPtr &ms_func_graph, CNodePtr *ms_function_cnode) const {
MS_EXCEPTION_IF_NULL(op_run_info);
// Get input node info of ms_function
std::vector<AnfNodePtr> input_nodes{NewValueNode(ms_func_graph)};
MS_EXCEPTION_IF_NULL(grad_executor);
@ -222,6 +254,7 @@ void MsFunction::MakeCNodeForMsFunction(const FrontendOpRunInfoPtr &op_run_info,
// Make a CNode which includes ms_function fprop graph and inputs node
MS_EXCEPTION_IF_NULL(ms_function_cnode);
*ms_function_cnode = grad_executor->top_cell()->fg()->NewCNode(input_nodes);
MS_LOG(DEBUG) << "Make ms function forward CNode: " << (*ms_function_cnode)->DebugString();
}
@ -242,6 +275,10 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
auto grad_param =
std::make_shared<ad::GradParam>(ms_function_cnode, op_run_info->input_value, op_run_info->out_value, grad_graph);
{
py::gil_scoped_release gil_release;
grad_executor->async_executor()->Wait();
}
if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) {
MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: "
<< ms_function_cnode->DebugString();
@ -250,21 +287,55 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run
return ms_function_cnode;
}
void MsFunction::AsyncKPynativeWithFProp(const GradExecutor *grad_executor,
const ad::AutoGradCellImplPtr &auto_grad_cell_ptr,
const ad::GradParamPtr &grad_param) const {
MS_EXCEPTION_IF_NULL(grad_executor);
const auto fn = [this, grad_param, auto_grad_cell_ptr]() {
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) {
MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode";
}
};
auto task = std::make_shared<BpropTask>(fn);
grad_executor->async_executor()->Push(task);
}
void MsFunction::AsyncGradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
const ValuePtr &added_out_v, const FuncGraphPtr &ms_func_graph,
const FuncGraphPtr &grad_graph) const {
const auto fn = [this, op_run_info, grad_executor, added_out_v, ms_func_graph, grad_graph]() {
this->GradMsFunctionInner(op_run_info, grad_executor, added_out_v, ms_func_graph, grad_graph);
};
auto task = std::make_shared<BpropTask>(fn);
grad_executor->async_executor()->Push(task);
}
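The two Async* helpers above share one pattern: capture the work in a closure, wrap it in a BpropTask, and push it onto the grad executor's async queue, which MakeAdjointForMsFunction later drains via Wait(). As a conceptual illustration only, here is a minimal Python sketch of such a push/wait queue; AsyncExecutor and the lambda body are stand-ins, not MindSpore classes:

import queue
import threading

class AsyncExecutor:
    """Toy single-worker queue mirroring the Push()/Wait() usage above."""

    def __init__(self):
        self._tasks = queue.Queue()
        threading.Thread(target=self._worker, daemon=True).start()

    def _worker(self):
        while True:
            task = self._tasks.get()   # block until a task is pushed
            try:
                task()                 # e.g. the captured GradMsFunctionInner call
            finally:
                self._tasks.task_done()

    def push(self, task):
        self._tasks.put(task)

    def wait(self):
        self._tasks.join()             # like async_executor()->Wait()

executor = AsyncExecutor()
executor.push(lambda: print("build adjoint for ms_function cnode"))
executor.wait()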
void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
const ValuePtr &added_out_v, const FuncGraphPtr &ms_func_graph,
const FuncGraphPtr &grad_graph) const {
MS_EXCEPTION_IF_NULL(op_run_info);
MS_EXCEPTION_IF_NULL(grad_executor);
MS_LOG(DEBUG) << "ms_function actual output value: " << op_run_info->out_value->ToString();
if (!grad_executor->grad_flag()) {
MS_LOG(EXCEPTION) << "The flag of need construct graph is False.";
// Step 1: Update actual output tensors used in grad graph.
MS_EXCEPTION_IF_NULL(op_run_info->out_value);
MS_LOG(DEBUG) << "ms_function actual output value: " << op_run_info->out_value->ToString();
// The output of ms_function may be used in subsequent PyNative process
grad_executor->UpdateForwardTensorInfoInBpropGraph(op_run_info);
// Step 2: Update output tensors of added forward nodes, which are added to return node of ms_function func graph.
if (grad_executor->top_cell()->op_info_with_ms_func_forward_tensors().find(op_run_info->op_info) !=
grad_executor->top_cell()->op_info_with_ms_func_forward_tensors().end()) {
UpdateMsFunctionForwardTensors(grad_executor, op_run_info->op_info, added_out_v);
}
// Update actual output tensors used in grad graph.
ReplaceNewTensorsInGradGraph(grad_executor->top_cell(), added_out_v, ms_func_graph, grad_graph);
// Clone new ms_function func graph and grad graph.
auto new_ms_func_graph = BasicClone(ms_func_graph);
auto new_grad_graph = BasicClone(grad_graph, true);
auto new_make_tuple = new_ms_func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(new_make_tuple);
new_ms_func_graph->set_output(new_make_tuple->input(1));
@ -273,6 +344,11 @@ void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, co
const auto &ms_function_cnode =
MakeAdjointForMsFunction(op_run_info, grad_executor, new_ms_func_graph, new_grad_graph);
ms_function_cnode->set_abstract(new_ms_func_graph->output()->abstract()->Broaden());
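// Let the grad executor check whether this ms_function node makes the bprop graph structure dynamic.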
auto grad_exec_ptr = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
MS_EXCEPTION_IF_NULL(grad_exec_ptr);
grad_exec_ptr->CheckGraphDynamic(ms_function_cnode, op_run_info->op_index, true,
op_run_info->base_op_run_info.op_name);
}
void MsFunction::SetMsFuncGraphParameters(const FuncGraphPtr &ms_func_graph) {
@ -316,6 +392,9 @@ py::object MsFunction::GradMsFunction(const py::object &out, const py::args &arg
const auto &op_run_info = GetOpRunInfo(out, args, graph_phase_, &added_out_v);
FuncGraphPtr grad_graph = executor->GetGradGraph(graph_phase_);
PyNativeAlgo::Common::DumpGraphIR("ms_func_forward_graph.ir", ms_func_graph);
if (!grad_executor->grad_flag()) {
MS_LOG(EXCEPTION) << "The flag of need construct graph is False.";
}
GradMsFunctionInner(op_run_info, grad_executor.get(), added_out_v, ms_func_graph, grad_graph);
SetMsFuncGraphParameters(ms_func_graph);
graph_phase_.clear();

View File

@ -42,11 +42,18 @@ class MsFunction {
void GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
const ValuePtr &added_out_v, const FuncGraphPtr &ms_func_graph,
const FuncGraphPtr &grad_graph) const;
void AsyncGradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, const GradExecutor *grad_executor,
const ValuePtr &added_out_v, const FuncGraphPtr &ms_func_graph,
const FuncGraphPtr &grad_graph) const;
void AsyncKPynativeWithFProp(const GradExecutor *grad_executor, const ad::AutoGradCellImplPtr &auto_grad_cell_ptr,
const ad::GradParamPtr &grad_param) const;
// Update device address of value node in grad graph by forward tensors.
void RunReplace(const CNodePtr &added_make_tuple, const std::vector<tensor::TensorPtr> &total_output_tensors,
const FuncGraphPtr &grad_graph) const;
void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const ValuePtr &added_out,
const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const;
void UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const string &op_info,
const ValuePtr &new_forward_value) const;
// Make CNode for ms_function forward graph.
void GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, AnfNodePtrList *input_nodes,
const GradExecutor *grad_executor) const;

View File

@ -57,6 +57,40 @@ void TopCellInfo::RecordCellBackwardHookOp(const std::string &cell_order, const
}
}
void TopCellInfo::GetOpInfo(const FrontendOpRunInfoPtr &op_run_info) {
MS_EXCEPTION_IF_NULL(op_run_info);
std::string input_args_info;
// Record input args info (weight or data)
// self.p = Parameter();
// def construct(x, y)
// if y:
// x = x + x
// else:
// x = x + self.p
// return x
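// Depending on the branch taken at run time, the same op may consume (data, data) or (data, weight),
// so the kind of every input is encoded into op_info as "d"/"w" below.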
for (size_t i = 0; i < op_run_info->base_op_run_info.input_tensor.size(); i++) {
const auto &t = op_run_info->base_op_run_info.input_tensor[i];
MS_EXCEPTION_IF_NULL(t);
if (t->is_parameter() && t->param_info() != nullptr && t->param_info()->requires_grad()) {
input_args_info += "w";
} else {
input_args_info += "d";
}
}
// Record op name and index
op_run_info->op_info.clear();
op_run_info->op_info +=
op_run_info->base_op_run_info.op_name + "-" + std::to_string(op_index_) + "-" + input_args_info;
const auto &out_abs = op_run_info->base_op_run_info.abstract;
auto shape = out_abs->BuildShape();
MS_EXCEPTION_IF_NULL(shape);
if (!shape->isa<abstract::NoShape>() && !shape->IsDimZero()) {
op_run_info->op_info += "-" + shape->ToString();
}
op_run_info->op_index = op_index_;
++op_index_;
}
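To make the key format above concrete, a rough Python rendering is sketched below. It only mirrors the string layout built by GetOpInfo; the helper name and the requires_grad check are illustrative, not MindSpore APIs.

def build_op_info(op_name, op_index, input_tensors, out_shape=None):
    """Sketch of the op_info key, e.g. 'Conv2D-0-dw-(32, 64, 14, 14)'."""
    # "w" for a trainable weight (Parameter with requires_grad), "d" for plain data.
    kinds = "".join("w" if getattr(t, "requires_grad", False) else "d" for t in input_tensors)
    op_info = "{}-{}-{}".format(op_name, op_index, kinds)
    if out_shape:  # appended only when the output shape is known and non-scalar
        op_info += "-" + str(tuple(out_shape))
    return op_info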
void TopCellInfo::ClearDeviceMemory() const {
MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
auto ms_context = MsContext::GetInstance();
@ -154,11 +188,40 @@ void TopCellInfo::SetNestedMultipleOutputToGraphInfoMap(const string &id, const
}
}
void TopCellInfo::Clear() {
MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
hook_changed_ = false;
ms_function_flag_ = false;
is_init_kpynative_ = false;
need_compile_graph_ = false;
forward_already_run_ = false;
op_index_ = 0;
resource_ = nullptr;
fg_ = nullptr;
graph_info_map_.clear();
op_info_with_tensor_id_.clear();
tensor_id_with_tensor_object_.clear();
op_info_with_ms_func_forward_tensors_.clear();
cnode_hash_with_op_index_.clear();
}
void TopCellInfo::SetUnpackOutputToGraphInfoMap(const std::string &id, const AnfNodePtr &node,
const std::vector<int64_t> &index) const {
auto &graph_info = graph_info_map().at(fg());
MS_EXCEPTION_IF_NULL(graph_info);
graph_info->node_map[id] = std::make_pair(node, index);
}
void TopCellInfo::set_opinfo_with_tensor_id(const std::string &op_info,
const std::vector<tensor::TensorPtr> &op_out_tensors) {
if (op_info_with_tensor_id_.find(op_info) != op_info_with_tensor_id_.end()) {
MS_LOG(EXCEPTION) << "Top cell: " << cell_id_ << " records op info with tensor id, but get op info " << op_info
<< " in op_info_with_tensor_id map";
}
// Record the relationship between the forward op and its output tensor id
(void)std::for_each(op_out_tensors.begin(), op_out_tensors.end(), [this, &op_info](const tensor::TensorPtr &tensor) {
(void)op_info_with_tensor_id_[op_info].emplace_back(tensor->id());
});
}
} // namespace pynative
} // namespace mindspore

View File

@ -42,6 +42,9 @@ namespace mindspore {
namespace pynative {
namespace py = pybind11;
class GradExecutor;
using OpInfoWithTensorId = mindspore::HashMap<std::string, std::vector<std::string>>;
using TensorIdWithTensorObject = mindspore::HashMap<std::string, std::vector<tensor::TensorPtr>>;
using OpInfoWithMsFuncForwardTensors = mindspore::HashMap<std::string, std::vector<tensor::TensorPtr>>;
using CellIdWithBackwardHookOp = mindspore::HashMap<std::string, std::vector<AnfNodePtr>>;
struct GraphInfo {
@ -55,9 +58,10 @@ using GraphInfoPtr = std::shared_ptr<GraphInfo>;
class TopCellInfo {
public:
~TopCellInfo() = default;
TopCellInfo(size_t grad_order, std::string cellid, std::string already_run_cell_id, pipeline::ResourcePtr r,
FuncGraphPtr fg)
TopCellInfo(size_t grad_order, std::string c_cell_id, std::string cellid, std::string already_run_cell_id,
pipeline::ResourcePtr r, FuncGraphPtr fg)
: grad_order_(grad_order),
c_cell_id_(std::move(c_cell_id)),
cell_id_(std::move(cellid)),
already_run_cell_id_(std::move(already_run_cell_id)),
resource_(std::move(r)),
@ -70,11 +74,14 @@ class TopCellInfo {
inline void set_sub_cell_hook_changed(const std::string &sub_cell) { (void)sub_cell_hook_changed_.emplace(sub_cell); }
inline const CellIdWithBackwardHookOp &cell_backward_hook_op() const { return cell_backward_hook_op_; }
void RecordCellBackwardHookOp(const std::string &cell_order, const AnfNodePtr &hook_op);
void GetOpInfo(const FrontendOpRunInfoPtr &op_run_info);
inline void ClearCellHookOp() { cell_backward_hook_op_.clear(); }
inline bool ms_function_flag() const { return ms_function_flag_; }
inline void set_ms_function_flag(bool ms_function_flag) { ms_function_flag_ = ms_function_flag; }
inline bool forward_already_run() const { return forward_already_run_; }
inline void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
inline bool need_compile_graph() const { return need_compile_graph_; }
inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
inline pipeline::ResourcePtr resource() const { return resource_; }
inline FuncGraphPtr fg() const {
MS_EXCEPTION_IF_NULL(fg_);
@ -82,18 +89,51 @@ class TopCellInfo {
}
inline void set_fg(const FuncGraphPtr &fg) { fg_ = fg; }
inline const std::string &cell_id() const { return cell_id_; }
inline const std::string &c_cell_id() const { return c_cell_id_; }
inline const std::string &already_run_cell_id() const { return already_run_cell_id_; }
inline void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
inline const std::string &input_args_id() const { return input_args_id_; }
const std::string &grad_operation() const { return grad_operation_; }
void set_grad_operation(const std::string &grad_operation) { grad_operation_ = grad_operation; }
inline void CheckSubCellHookChanged() { sub_cell_hook_changed_.clear(); }
inline void SetGraphInfoMap(const FuncGraphPtr &fg, const GraphInfoPtr &graph_info) {
graph_info_map_[fg] = graph_info;
}
inline void set_is_run_cell(bool is_run_cell) { is_run_cell_ = is_run_cell; }
inline bool is_run_cell() { return is_run_cell_; }
inline const OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() const { return graph_info_map_; }
inline ad::AutoGradCellImplPtr auto_grad_cell_ptr() const { return auto_grad_cell_ptr_; }
inline ad::AutoGradCellImplPtr auto_grad_cell_ptr() const {
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr_);
return auto_grad_cell_ptr_;
}
void set_auto_grad_cell_ptr(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr) {
auto_grad_cell_ptr_ = auto_grad_cell_ptr;
}
inline const OpInfoWithTensorId &op_info_with_tensor_id() const { return op_info_with_tensor_id_; }
void set_opinfo_with_tensor_id(const std::string &op_info, const std::vector<tensor::TensorPtr> &op_out_tensors);
inline const TensorIdWithTensorObject &tensor_id_with_tensor_object() const { return tensor_id_with_tensor_object_; }
inline void set_tensor_id_with_tensor_object(const std::string &id, const tensor::TensorPtr &tensor) {
(void)tensor_id_with_tensor_object_[id].emplace_back(tensor);
}
inline const OpInfoWithMsFuncForwardTensors &op_info_with_ms_func_forward_tensors() const {
return op_info_with_ms_func_forward_tensors_;
}
inline size_t op_index() const { return op_index_; }
inline void IncreaseOpIndex() { op_index_++; }
inline void set_cnode_hash_with_op_index(const size_t &node_hash, const size_t &op_index) {
cnode_hash_with_op_index_[node_hash] = op_index;
}
inline size_t get_op_index_by_cnode_hash(const size_t &node_hash) {
auto iter = cnode_hash_with_op_index_.find(node_hash);
if (iter == cnode_hash_with_op_index_.end()) {
MS_LOG(EXCEPTION) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_";
}
return iter->second;
}
void Clear();
void DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id);
void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param, bool is_weight = false) const;
void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1,
@ -111,7 +151,11 @@ class TopCellInfo {
bool ms_function_flag_{false};
bool is_init_kpynative_{false};
bool forward_already_run_{false};
bool need_compile_graph_{false};
bool is_run_cell_{false};
size_t op_index_{0};
size_t grad_order_{0};
std::string c_cell_id_;
std::string cell_id_;
std::string already_run_cell_id_;
std::string input_args_id_;
@ -126,6 +170,10 @@ class TopCellInfo {
// Record backward hook ops for each cell object.
// Each cell object has two backward hook ops.
CellIdWithBackwardHookOp cell_backward_hook_op_;
OpInfoWithTensorId op_info_with_tensor_id_;
TensorIdWithTensorObject tensor_id_with_tensor_object_;
OpInfoWithMsFuncForwardTensors op_info_with_ms_func_forward_tensors_;
mindspore::HashMap<size_t, size_t> cnode_hash_with_op_index_;
};
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;
} // namespace pynative

View File

@ -139,6 +139,7 @@ void PyNativeExecutor::ClearRes() const {
void PyNativeExecutor::Init() {
MS_LOG(DEBUG) << "Init PyNativeExecutor";
forward_executor_ = std::make_shared<ForwardExecutor>();
forward_executor_->Init();
grad_executor_ = std::make_shared<GradExecutor>(forward_executor_);
forward_executor_->set_grad_executor(grad_executor_);
}
@ -161,8 +162,8 @@ bool PyNativeExecutor::grad_flag() const { return grad_executor()->grad_flag();
void PyNativeExecutor::set_grad_flag(bool flag) const { grad_executor()->set_grad_flag(flag); }
py::object PyNativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj,
const py::args &args) const {
return grad_executor()->CheckAlreadyRun(grad, obj, args);
const py::object &grad_hash_id, const py::args &args) const {
return grad_executor()->CheckAlreadyRun(grad, obj, grad_hash_id, args);
}
void PyNativeExecutor::NewGraph(const py::object &obj, const py::args &args) const {
@ -187,7 +188,10 @@ void PyNativeExecutor::EndGraph(const py::object &obj, const py::object &out, co
forward_executor()->ProcessAfterEndGraph(obj, is_cell);
}
py::object PyNativeExecutor::Run() const { return PyNativeExecutorTry(grad_executor()->RunGraph); }
py::object PyNativeExecutor::Run() const {
const auto &ret = PyNativeExecutorTry(grad_executor()->RunGraph);
return ret;
}
void PyNativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::object &grad_position, const py::args &args) const {
@ -195,13 +199,22 @@ void PyNativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::obj
}
py::object PyNativeExecutor::GradMsFunction(const py::object &out, const py::args &args) const {
return grad_executor()->ms_function()->GradMsFunction(out, args);
const auto &ret = grad_executor()->ms_function()->GradMsFunction(out, args);
return ret;
}
void PyNativeExecutor::SetLazyBuild(bool enable) const { forward_executor()->set_lazy_build(enable); }
bool PyNativeExecutor::IsFirstCell() const { return forward_executor()->IsFirstCell(); }
void PyNativeExecutor::SetMsFunctionCompileStatus(bool is_compiling) const {
forward_executor()->set_is_ms_function_compiling(is_compiling);
}
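// Registering dynamic inputs only switches grad processing to the dynamic shape path; the cell and args are not inspected here.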
void PyNativeExecutor::SetDynamicInput(const py::object &cell, const py::args &args) const {
grad_executor()->set_use_dynamic_shape_process(true);
}
void RegPyNativeExecutor(const py::module *m) {
(void)py::class_<PyNativeExecutor, std::shared_ptr<PyNativeExecutor>>(*m, "PyNativeExecutor_")
.def_static("get_instance", &PyNativeExecutor::GetInstance, "PyNativeExecutor get_instance.")
@ -220,10 +233,13 @@ void RegPyNativeExecutor(const py::module *m) {
.def("set_hook_changed", &PyNativeExecutor::SetHookChanged, "set pynative hook changed")
.def("set_grad_flag", &PyNativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
"Executor set grad flag.")
.def("set_dynamic_input", &PyNativeExecutor::SetDynamicInput, "set dynamic input")
.def("set_py_exe_path", &PyNativeExecutor::set_py_exe_path, py::arg("py_exe_path") = py::str(""),
"set python executable path.")
.def("set_kernel_build_server_dir", &PyNativeExecutor::set_kernel_build_server_dir,
py::arg("kernel_build_server_dir") = py::str(""), "set kernel build server directory path.")
.def("set_ms_function_compile_status", &PyNativeExecutor::SetMsFunctionCompileStatus,
"set ms_funciton compile status.")
.def("real_run_op", &PyNativeExecutor::RealRunOp, "Run op pynatively.")
.def("constant_folding", &PyNativeExecutor::CallConstantFolding, "Call Constant Folding Primitive");
}

View File

@ -65,13 +65,17 @@ class PyNativeExecutor : public std::enable_shared_from_this<PyNativeExecutor> {
void GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::object &grad_position, const py::args &args) const;
py::object GradMsFunction(const py::object &out, const py::args &args) const;
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args) const;
void SetDynamicInput(const py::object &cell, const py::args &args) const;
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id,
const py::args &args) const;
void ClearRes() const;
// Sync stream
void Sync() const;
void SetLazyBuild(bool enable) const;
bool IsFirstCell() const;
void WorkerJoin() { grad_executor_->WorkerJoin(); }
void SetMsFunctionCompileStatus(bool is_compiling) const;
private:
PyNativeExecutor() = default;

View File

@ -602,13 +602,16 @@ TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel,
}
void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info,
bool use_dynamic_shape_process,
session::BackendOpRunInfoPtr *op_run_info, GraphInfo *graph_info,
const GraphOutputInfo *const graph_output_info) {
MS_EXCEPTION_IF_NULL(session_);
MS_EXCEPTION_IF_NULL(graph_info);
*op_run_info = session_->GetSingleOpRunInfo(kernel, *graph_info, tensor_info, graph_output_info);
session_->GetSingleOpGraphInfo(kernel, tensor_info, graph_info, *op_run_info);
MS_EXCEPTION_IF_NULL(*op_run_info);
(*op_run_info)->base_op_run_info.graph_info = *graph_info;
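// Record whether this single op should take the dynamic shape process when it is launched.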
(*op_run_info)->base_op_run_info.use_dynamic_shape_process = use_dynamic_shape_process;
}
void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const {

View File

@ -130,8 +130,8 @@ class GraphCompiler {
// Get OpRunInfo and GraphInfo for single op compile and run.
void GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info,
session::BackendOpRunInfoPtr *op_run_info, GraphInfo *graph_info,
const GraphOutputInfo *const graph_output_info);
bool use_dynamic_shape_process, session::BackendOpRunInfoPtr *op_run_info,
GraphInfo *graph_info, const GraphOutputInfo *const graph_output_info);
// Calculate ref count of PyNative back propagation operators.
void CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const;

View File

@ -296,7 +296,9 @@ class _MindsporeFunctionExecutor:
args_list = args
if self.obj is not None:
args_list = args_list[1:]
_pynative_executor.set_ms_function_compile_status(True)
phase = self.compile(args_list, self.fn.__name__)
_pynative_executor.set_ms_function_compile_status(False)
if context.get_context("precompile_only"):
return None
new_inputs = self._generate_run_args(args_list)
@ -428,6 +430,7 @@ class _MindsporeFunctionExecutor:
self.input_signature.append(args_list[-1])
Validator.check_dynamic_shape(self.input_signature, args_list)
compile_args = tuple(self.input_signature)
_pynative_executor.set_dynamic_input(self.obj, *compile_args)
return compile_args
def _generate_run_args(self, args_list):
@ -1012,7 +1015,7 @@ class _PyNativeExecutor:
"""
self._executor.end_graph(obj, output, *args, *(kwargs.values()))
def check_run(self, grad, obj, *args, **kwargs):
def check_run(self, grad, obj, grad_hash_id, *args, **kwargs):
"""
Whether the forward graph needs to be constructed.
@ -1026,7 +1029,7 @@ class _PyNativeExecutor:
Return:
bool, specifies whether the forward graph needs to be constructed.
"""
return self._executor.check_run(grad, obj, *args, *(kwargs.values()))
return self._executor.check_run(grad, obj, grad_hash_id, *args, *(kwargs.values()))
def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
"""
@ -1122,6 +1125,30 @@ class _PyNativeExecutor:
"""
self._executor.set_grad_flag(flag)
def set_ms_function_compile_status(self, status):
"""
Set whether ms_function is currently compiling.
Args:
status (bool): The ms_function compile status.
Return:
None.
"""
self._executor.set_ms_function_compile_status(status)
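The status is toggled around the compile step of _MindsporeFunctionExecutor shown earlier in this diff. A minimal sketch of that calling pattern, where executor and compile_fn are placeholders rather than fixed APIs:

def compile_with_status(executor, compile_fn, args_list, method_name):
    """Wrap a compile step with the ms_function compile status flags (sketch only)."""
    executor.set_ms_function_compile_status(True)
    phase = compile_fn(args_list, method_name)
    executor.set_ms_function_compile_status(False)
    return phase

# Usage mirrors the _MindsporeFunctionExecutor code above:
# phase = compile_with_status(_pynative_executor, self.compile, args_list, self.fn.__name__)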
def set_dynamic_input(self, obj, *args):
"""
Set dynamic shape tensors of the input arguments.
Args:
obj (Function/Cell): The function or cell instance.
args (tuple): Function or cell dynamic input arguments.
Return:
None.
"""
self._executor.set_dynamic_input(obj, *args)
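In PyNative mode this entry point is reached from the Cell change later in this diff, which forwards the registered dynamic-shape inputs here. A hedged end-to-end usage sketch, assuming the usual Cell.set_inputs API with shape-only Tensors as dynamic placeholders:

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(x)

ms.set_context(mode=ms.PYNATIVE_MODE)
net = Net()
# A Tensor built from shape/dtype only (no data) marks the first dimension as dynamic.
net.set_inputs(Tensor(shape=[None, 3], dtype=ms.float32))
out = net(Tensor(np.ones((2, 3), np.float32)))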
def is_first_cell(self):
"""
The flag of the first cell instance.

View File

@ -891,6 +891,8 @@ class Cell(Cell_):
self._check_construct_args(*inputs)
if self._dynamic_shape_inputs:
ds.config.set_dynamic_shape(True)
if context._get_mode() == context.PYNATIVE_MODE:
_pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
def get_inputs(self):
"""

View File

@ -392,14 +392,14 @@ class GradOperation(GradOperation_):
new_kwargs = kwargs.copy()
new_kwargs.pop('sens')
if isinstance(fn, (FunctionType, MethodType)):
if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs):
_pynative_executor.set_grad_flag(True)
_pynative_executor.new_graph(fn, *args, **new_kwargs)
output = fn(*args, **new_kwargs)
_pynative_executor.end_graph(fn, output, *args, **new_kwargs)
else:
# Check if fn have run already
if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs):
fn.set_grad()
fn(*args, **new_kwargs)
fn.set_grad(False)
@ -465,6 +465,7 @@ class _Grad(GradOperation_):
self.pynative_ = False
self.grad_position = None
self.weights_id = None
self.grad_hash_id = None
def __call__(self, fn, weights=None, grad_position=0):
weights_id = _get_grad_weights_id(weights)
@ -537,6 +538,7 @@ class _Grad(GradOperation_):
self.fn = fn
self.grad_position = grad_position
self.weights_id = weights_id
self.grad_hash_id = (grad_position, weights_id)
return self.grad_fn
def _pynative_forward_run(self, fn, grad, args, kwargs):
@ -550,7 +552,7 @@ class _Grad(GradOperation_):
else:
args = args[:-1]
if isinstance(fn, (FunctionType, MethodType)):
if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
_pynative_executor.set_grad_flag(True)
_pynative_executor.new_graph(fn, *args, **new_kwargs)
outputs = fn(*args, **new_kwargs)
@ -558,7 +560,7 @@ class _Grad(GradOperation_):
return outputs
else:
# Check if fn has run already.
if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
fn.set_grad()
outputs = fn(*args, **new_kwargs)
fn.set_grad(False)