!34784 PyNative Bprop run in Graph.

Merge pull request !34784 from caifubi/master-pynative-run-in-graph-dev-ci
i-robot 2022-05-30 01:13:52 +00:00 committed by Gitee
commit 3d0e975c90
20 changed files with 436 additions and 80 deletions

View File

@ -115,6 +115,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
first_step_ = graph.first_step_;
has_optimizer_ = graph.has_optimizer_;
is_dynamic_shape_ = graph.is_dynamic_shape_;
front_outputs_ = graph.front_outputs_;
}
~KernelGraph() override;
@ -445,6 +446,9 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
return iter->second;
}
AnfNodePtrList front_outputs() const { return front_outputs_; }
void set_front_outputs(const AnfNodePtrList &outputs) { front_outputs_ = outputs; }
private:
// remove value node from graph
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
@ -494,6 +498,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
// parameters that will be updated when graph is executed
mindspore::HashSet<ParameterPtr> updated_parameters_;
// graph needn't execute
bool executable_{false};
// exist summary node in graph
@ -515,6 +520,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
CNodePtr start_label_;
CNodePtr end_goto_;
AnfNodePtrList front_outputs_;
// An internal parameter is not an origin parameter of the func graph; it is an output of the previous kernel graph
// that is related to an input of this kernel graph. The key of the unordered map is the input of this kernel graph;
// the value is the front node corresponding to the output of the previous kernel graph.
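The declaration of that mapping is cut off in the hunk above. A minimal sketch of what such a member and its lookup might look like, assuming a HashMap keyed by the backend input parameter (the names internal_parameters_ and GetFrontNodeByInternalParameter are illustrative assumptions, not taken from this diff):

// Hypothetical sketch: maps an input parameter of this kernel graph to the
// front node corresponding to the output of the previous kernel graph.
mindspore::HashMap<AnfNodePtr, AnfNodeWeakPtr> internal_parameters_;

AnfNodePtr GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const {
  auto iter = internal_parameters_.find(parameter);
  if (iter == internal_parameters_.end()) {
    return nullptr;
  }
  return iter->second.lock();
}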

View File

@ -1419,7 +1419,8 @@ void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector
MS_EXCEPTION_IF_NULL(parallel_context);
auto parallel_mode = parallel_context->parallel_mode();
bool is_parallel_forward_ms_function =
!graph->is_bprop() && (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
!graph->has_flag(kFlagIsPynativeBpropGraph) &&
(parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
for (const auto &input_node : graph->input_nodes()) {
auto params = common::AnfAlgo::GetAllOutput(input_node);
for (const auto &param : params) {
@ -2617,12 +2618,12 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
HandleOpOutputs(kernel, op_outputs, cnode_refcount, &op_output_map, &graph_output_info);
// Save grad node to Bucket
if (kernel_graph->is_bprop()) {
if (kernel_graph->has_flag(kFlagIsPynativeBpropGraph)) {
AddGradAddrToBucket(graph_id, graph_output_info.graph_output_tensors);
}
}
// Clear bucket resources every step
if (kernel_graph->is_bprop()) {
if (kernel_graph->has_flag(kFlagIsPynativeBpropGraph)) {
ClearAllBucket(graph_id);
}
@ -2699,10 +2700,8 @@ void SetGraphBpropAttr(const KernelGraphPtr &graph) {
auto &execution_orders = graph->execution_order();
if (std::any_of(execution_orders.begin(), execution_orders.end(),
[](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) {
graph->set_is_bprop(true);
graph->set_flag(kFlagIsPynativeBpropGraph, true);
MS_LOG(INFO) << "Match bprop graph";
} else {
graph->set_is_bprop(false);
}
}
@ -2798,7 +2797,7 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph, const device::Devi
}
SetGraphBpropAttr(graph);
if (!graph->is_bprop()) {
if (!graph->has_flag(kFlagIsPynativeBpropGraph)) {
return;
}

View File

@ -37,6 +37,7 @@
#include "runtime/hardware/device_context_manager.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "runtime/pynative/run_op_helper.h"
#include "runtime/pynative/graph_adapter.h"
#include "distributed/recovery/recovery_context.h"
#include "include/common/utils/scoped_long_running.h"
#ifdef ENABLE_D
@ -440,6 +441,7 @@ MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string
SetDebuggerInit();
#endif
runtime::GraphScheduler::GetInstance().Initialize();
pynative_run_in_graph_ = device_context->GetDeviceType() != device::DeviceType::kAscend;
}
void MindRTBackend::ProcessNotSupportCnode(const FuncGraphPtr &func_graph,
@ -889,6 +891,74 @@ void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &param
input_tensor->push_back(tensor_input);
}
void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph) {
bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
distributed::recovery::RecoveryContext::GetInstance()->need_reset());
if (need_contruct_output) {
// Update the device address for the output node of the graph.
// Summary processing uses the output device address, so this must run after the summary processing.
actor_set->output_actor_->UpdateOutputDeviceAddress();
// Fetch outputs.
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
auto &output_tensors = actor_set->output_actor_->outputs();
if (!output_tensors.empty()) {
size_t output_position = 0;
ConstructOutputs(root_graph->output(), output_tensors, &output_position, outputs);
}
}
}
void MindRTBackend::RunGraphIntergated(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
WaitTaskFinish();
MS_EXCEPTION_IF_NULL(graph_compiler_);
auto graphs = graph_compiler_info.graphs_;
auto &op_executor = runtime::OpExecutor::GetInstance();
op_executor.Register([this]() { BatchBuildCallback(); });
for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
const auto &graph = graphs[graph_index];
MS_EXCEPTION_IF_NULL(graph);
// TODO(caifubi): Update parameter format for Ascend
if (!graph->has_flag(kFlagGraphCompiled)) {
graph_compiler_->CompileGraphImpl(graph, graph_compiler_info.device_contexts_.front());
graph->CacheGraphOutputToFrontNodeWithIndex({graph->output()}, graph->front_outputs());
// Clear front outputs
graph->set_front_outputs({});
// Transform graph to actor DAG, and schedule the actor DAG.
const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(graph_compiler_info);
runtime::GraphScheduler::GetInstance().Schedule(actor_set);
pynative::GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(graph);
pynative::GraphAdapter::GenerateRefCountForBpropValueNode(graph);
graph->set_flag(kFlagGraphCompiled, true);
}
pynative::GraphAdapter::UpdateForwardOutputInBpropGraph(graph);
}
// Run actor DAG.
mindspore::ScopedLongRunning long_running;
const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
MS_EXCEPTION_IF_NULL(actor_set);
runtime::GraphScheduler::GetInstance().Run(actor_set, graph_compiler_info.device_contexts_, inputs);
MS_EXCEPTION_IF_NULL(graph_compiler_);
graph_compiler_->Summary(graph_compiler_info.graphs_);
ConstructOutputs(actor_set, outputs, root_graph_);
// Clear bucket resources every step
for (auto &graph : graphs) {
if (graph->has_flag(kFlagIsPynativeBpropGraph)) {
graph_compiler_->AddGradAddrToBucket(graph->graph_id(), actor_set->output_actor_->outputs());
graph_compiler_->ClearAllBucket(graph->graph_id());
}
}
runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
// Close abstract_lock for dynamic_shape
AnfUtils::CloseAbstractLock();
MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
}
void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
WaitTaskFinish();
@ -941,18 +1011,42 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);
// Save grad node to Bucket
if (graph->is_bprop() && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) && !kernel->is_parallel()) {
if (graph->has_flag(kFlagIsPynativeBpropGraph) && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) &&
!kernel->is_parallel()) {
graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
}
}
WaitTaskFinish();
// Clear bucket resources every step
if (graph->is_bprop()) {
if (graph->has_flag(kFlagIsPynativeBpropGraph)) {
graph_compiler_->ClearAllBucket(graph->graph_id());
}
}
}
void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
const std::vector<std::vector<tensor::TensorPtr>> &input_tensors,
VectorRef *outputs) {
bool contain_cut_graph = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
[](const KernelGraphPtr &graph) { return graph->has_flag(kFlagsIsCutGraph); });
if (contain_cut_graph) {
// Python APIs will be called in cut_graph, so we cannot release the GIL here.
RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
} else {
// Release python gil.
mindspore::ScopedLongRunning long_running;
bool is_dynamic_shape = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
[](const KernelGraphPtr &graph) { return graph->is_dynamic_shape(); });
// TODO(caifubi): PyNative dynamic shape does not support running in graph yet.
if (pynative_run_in_graph_ && !is_dynamic_shape && !root_graph_->has_flag(kFlagIsDynamicStructure)) {
RunGraphIntergated(actor_info, graph_compiler_info, input_tensors, outputs);
} else {
RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
}
}
MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
}
void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(root_graph_);
if (IsGraphOutputValueNodeOrParameter(root_graph_->output(), args, outputs)) {
@ -1013,16 +1107,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
MS_EXCEPTION_IF_NULL(outputs);
// There will be more than one kernel graph in a heterogeneous scenario in an ms_function of PyNative mode.
if (real_execution_mode_ == kPynativeMode) {
bool is_cut_graph = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
[](const KernelGraphPtr &graph) { return graph->has_flag(kFlagsIsCutGraph); });
if (is_cut_graph) {
RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
} else {
// Release python gil.
mindspore::ScopedLongRunning long_running;
RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
}
MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
RunGraphByCondition(actor_info, graph_compiler_info, input_tensors, outputs);
return;
}
@ -1036,21 +1121,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
MS_EXCEPTION_IF_NULL(graph_compiler_);
graph_compiler_->Summary(graph_compiler_info.graphs_);
bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
distributed::recovery::RecoveryContext::GetInstance()->need_reset());
if (need_contruct_output) {
// Update device address for output node of graph.
// Summary processing will use the output device address, so must be after the summary processing.
actor_set->output_actor_->UpdateOutputDeviceAddress();
// Fetch outputs.
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
auto &output_tensors = actor_set->output_actor_->outputs();
if (output_tensors.size() > 0) {
size_t output_position = 0;
ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
}
}
ConstructOutputs(actor_set, outputs, root_graph_);
runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
// Close abstract_lock for dynamic_shape

View File

@ -143,6 +143,8 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
// Get saved OpBuildTask in OpExecutor and build all the kernels together in PyNative mode.
void CompileSingleOpGraphs(const std::vector<std::shared_ptr<runtime::OpBuildTask>> &build_tasks);
void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph);
// Restore the outputs tuple by the origin funcGraph output node and output tensors.
void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors,
size_t *output_position, VectorRef *outputs);
@ -173,11 +175,16 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
void DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs, GraphCompilerInfo *graph_compiler_info,
OpRunInfo *op_run_info);
void RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
const std::vector<std::vector<tensor::TensorPtr>> &input_tensors, VectorRef *outputs);
// Split the complete kernel graph into single-op graphs in PyNative back
// propagation, then compile and run each single-op graph.
void RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);
void RunGraphIntergated(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);
void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs);
void ReleaseForwardOutput(const std::vector<TensorPtr> &input_tensors);
@ -212,6 +219,8 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
void CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode = device::RunMode::kUnknown);
void ProcessNotSupportCnode(const FuncGraphPtr &func_graph, const device::DeviceType &old_target,
const device::DeviceType &new_target);
// TODO(caifubi): Remove this flag when Ascend backend is ok.
bool pynative_run_in_graph_{false};
};
using MindRTBackendPtr = std::shared_ptr<compile::MindRTBackend>;
} // namespace compile

View File

@ -584,12 +584,15 @@ constexpr auto kActualAbstract = "actual_abstract";
constexpr auto kAttrZeroInfinity = "zero_infinity";
constexpr auto kAttrBlank = "blank";
// FuncGraph Flags
constexpr auto kFlagsIsCutGraph = "is_cut_graph";
constexpr auto kFlagGraphCompiled = "graph_compiled";
constexpr auto kFlagIsDynamicStructure = "is_dynamic_structure";
constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph";
// TODO(dsj): for ms_function running in graph_mode. Should be deleted later.
constexpr auto kAttrMSFunction = "ms_function_graph";
// KernelGraph Flags
constexpr auto kFlagsIsCutGraph = "is_cut_graph";
// custom operator func type
constexpr auto kCustomTypeAOT = "aot";
constexpr auto kCustomTypeJULIA = "julia";
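The new flag constants above replace the dedicated is_bprop_ member that this commit removes from FuncGraph. A minimal sketch of the set/query pattern used throughout the diff (the wrapper function is hypothetical; the set_flag and has_flag calls are the ones shown in session_basic.cc and the grad executor):

// Illustrative only: how the bprop graph is marked and later branched on.
void MarkAndHandleBpropGraph(const KernelGraphPtr &graph) {
  graph->set_flag(kFlagIsPynativeBpropGraph, true);   // set once when the bprop graph is built
  if (graph->has_flag(kFlagIsPynativeBpropGraph)) {   // queried by session/backend code paths
    // bprop-specific handling, e.g. AddGradAddrToBucket / ClearAllBucket
  }
}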

View File

@ -1016,9 +1016,10 @@ bool TaskEmitAction(const ResourcePtr &resource) {
DisableMindRT(resource);
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
auto is_parallel = (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
bool pynative_switch_to_graph_mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
(!func_graph->is_bprop() || func_graph->manager()->func_graphs().size() > 1) &&
!is_parallel;
bool pynative_switch_to_graph_mode =
context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
(!func_graph->has_flag(kFlagIsPynativeBpropGraph) || func_graph->manager()->func_graphs().size() > 1) &&
!is_parallel;
SetRunMode(resource, pynative_switch_to_graph_mode);
auto bc_ptr = resource->GetResult(kBackend).cast<compile::BackendPtr>();
MS_EXCEPTION_IF_NULL(bc_ptr);

View File

@ -872,6 +872,8 @@ void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const OpExecIn
std::vector<tensor::TensorPtr> total_output_tensors;
TensorValueToTensor(added_out, &total_output_tensors);
RunReplace(added_make_tuple, total_output_tensors, grad_graph);
std::for_each(total_output_tensors.begin(), total_output_tensors.end(),
[](tensor::TensorPtr &tensor) { tensor->set_is_forward_output(true); });
top_cell->set_op_info_with_ms_func_forward_tensors(op_exec_info->op_info, total_output_tensors);
}
@ -926,6 +928,13 @@ void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<ten
MS_EXCEPTION_IF_NULL(old_device_address);
auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
MS_EXCEPTION_IF_NULL(new_device_address);
// The CPU host tensor's data_c differs from the device address ptr if the address comes from the mem_pool.
if (new_device_address->from_mem_pool()) {
pre_tensor->set_device_address(new_device_address);
continue;
}
auto old_ptr = old_device_address->GetMutablePtr();
MS_EXCEPTION_IF_NULL(old_ptr);
auto new_ptr = new_device_address->GetPtr();
@ -3338,7 +3347,10 @@ void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &g
// Get bprop graph of top cell
auto bprop_graph = GetBpropGraph(grad, cell, w_args, p_args, size, args);
MS_EXCEPTION_IF_NULL(bprop_graph);
bprop_graph->set_is_bprop(true);
if (top_cell()->is_dynamic_structure()) {
bprop_graph->set_flag(kFlagIsDynamicStructure, true);
}
bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
resource->set_func_graph(bprop_graph);
auto manager = resource->manager();
MS_EXCEPTION_IF_NULL(manager);
@ -3666,6 +3678,7 @@ void GradExecutor::CheckNeedCompileGraph() {
pre_top_cell->Clear();
already_run_top_cell_[already_top_cell_id] = new_top_cell;
g_pyobj_id_cache.clear();
top_cell()->set_is_dynamic_structure(true);
} else {
MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again";
pre_top_cell->set_input_args_id(new_top_cell->input_args_id());

View File

@ -104,6 +104,8 @@ class TopCellInfo {
void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
bool forward_already_run() const { return forward_already_run_; }
void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
void set_is_dynamic_structure(bool is_dynamic_structure) { is_dynamic_structure_ = is_dynamic_structure; }
bool is_dynamic_structure() const { return is_dynamic_structure_; }
pipeline::ResourcePtr resource() const { return resource_; }
FuncGraphPtr df_builder() const { return df_builder_; }
FuncGraphPtr fg() const { return fg_; }
@ -150,6 +152,7 @@ class TopCellInfo {
bool is_init_kpynative_{false};
bool forward_already_run_{false};
bool need_compile_graph_{false};
bool is_dynamic_structure_{false};
size_t op_num_{0};
size_t grad_order_{0};
pipeline::ResourcePtr resource_{nullptr};

View File

@ -37,6 +37,22 @@ namespace device {
namespace ascend {
using AscendAutoMonad = mindspore::session::AscendAutoMonad;
namespace {
void RemoveUnusedValueNode(const KernelGraphPtr &graph) {
auto m = graph->manager();
auto node_users = m->node_users();
mindspore::HashSet<ValueNodePtr> unused_value_nodes;
for (auto &value_node : graph->graph_value_nodes()) {
if (node_users.find(value_node) == node_users.end()) {
unused_value_nodes.insert(value_node);
}
}
for (auto &value_node : unused_value_nodes) {
graph->RemoveNodeFromGraph(value_node);
}
}
} // namespace
void AscendGraphOptimization::Reset() {
MS_LOG(INFO) << "Clear Ascend Graph Optimization Resource.";
memo_.clear();
@ -60,7 +76,9 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
OptimizeGraphWithDeviceInfo(graph);
OptimizeExecutionOrder(graph);
PostOptimization(graph);
// must clear memo_ which holds kernel graph after using AscendGraphOptimization class.
RemoveUnusedValueNode(graph);
memo_.clear();
// clear and reset graph_manager_ after optimization
graph_manager_ = MakeManager();

View File

@ -22,6 +22,7 @@
#include <memory>
#include <map>
#include <utility>
#include "ir/tensor.h"
#include "ir/dtype.h"
#include "ir/device_sync.h"
#include "utils/shape_utils.h"
@ -107,6 +108,10 @@ class DeviceAddress : public mindspore::DeviceSync {
std::string device_name() const { return device_name_; }
uint32_t device_id() const { return device_id_; }
void set_from_tensor(const std::weak_ptr<tensor::Tensor> &from_tensor) { from_tensors_.emplace_back(from_tensor); }
std::vector<std::weak_ptr<tensor::Tensor>> from_tensors() const { return from_tensors_; }
void clear_from_tensors() { from_tensors_.clear(); }
virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; }
KernelWithIndex GetNodeIndex() const {
return node_index_.first.expired() ? KernelWithIndex{nullptr, node_index_.second}
@ -167,6 +172,7 @@ class DeviceAddress : public mindspore::DeviceSync {
ShapeVector host_shape_{};
// {node, out_index}
std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
std::vector<std::weak_ptr<tensor::Tensor>> from_tensors_;
// The device address of the node that owns the device address cannot be updated and replaced.
// Application scenario: set to true when the hardware execution mode requires that ptr cannot be changed during
// execution.

View File

@ -226,7 +226,21 @@ void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext
device_tensor->DecreaseRefCount();
if (device_tensor->ref_count() == 0) {
if (device_tensor->GetPtr() != nullptr) {
FreeMemory(device_tensor, device_context);
auto from_tensors = device_tensor->from_tensors();
if (from_tensors.empty()) {
FreeMemory(device_tensor, device_context);
} else {
std::for_each(from_tensors.begin(), from_tensors.end(), [](const std::weak_ptr<tensor::Tensor> &t) {
auto tensor = t.lock();
if (tensor != nullptr) {
tensor->set_device_address(nullptr);
}
});
device_tensor->clear_from_tensors();
// Reset device address status
device_tensor->set_original_ref_count(SIZE_MAX);
device_tensor->ResetRefCount();
}
}
device_tensor->ResetRefCount();
}

View File

@ -304,7 +304,8 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no
// Assign the tensor address to the input data node and set `ref_count` to `SIZE_MAX` to avoid it being freed.
AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
tensor_address->SetNodeIndex(input_node, 0);
tensor_address->set_ref_count(SIZE_MAX);
tensor_address->set_original_ref_count(SIZE_MAX);
tensor_address->ResetRefCount();
} else if (device_address->GetPtr() != nullptr) {
// The `device_address` may come from another previous tensor. In order to avoid polluting the device data of the
// previous tensor, create a new device address to hold the current input tensor data.
@ -313,7 +314,8 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no
MS_EXCEPTION_IF_NULL(new_device_address);
AnfAlgo::SetOutputAddr(new_device_address, 0, input_node.get());
new_device_address->SetNodeIndex(input_node, 0);
new_device_address->set_ref_count(SIZE_MAX);
new_device_address->set_original_ref_count(SIZE_MAX);
new_device_address->ResetRefCount();
}
}

View File

@ -161,12 +161,10 @@ void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, cons
}
auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
if (output_address != nullptr && output_address->GetDeviceType() == device_context->GetDeviceType()) {
// The input of PyNative bprop graph is ValueNode.
// Setting the address to the ValueNode will lead to memory leak.
if (!graph->is_bprop()) {
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
value_node.get());
}
// We need to set tensor->device_address to the ValueNode even if the tensor is a forward_output tensor
// in the PyNative bprop graph. The ValueNode device_address is necessary for GraphScheduler::Transform.
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
value_node.get());
continue;
}
@ -369,18 +367,6 @@ void SetSummaryNodesRefCount(const KernelGraph *graph) {
}
}
void UpdateRefCountForGraphOutput(const std::vector<KernelWithIndex> &output_with_index) {
for (const auto &item_with_index : output_with_index) {
if (!AnfAlgo::OutputAddrExist(item_with_index.first, item_with_index.second, false)) {
continue;
}
auto device_address = AnfAlgo::GetMutableOutputAddr(item_with_index.first, item_with_index.second, false);
MS_EXCEPTION_IF_NULL(device_address);
device_address->set_original_ref_count(SIZE_MAX);
device_address->ResetRefCount();
}
}
void SetGraphInputNodeActualAbstract(const session::OpRunInfo &op_run_info, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
if (!op_run_info.output_is_dynamic_shape && !op_run_info.input_is_dynamic_shape) {
@ -446,11 +432,13 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
MS_EXCEPTION_IF_NULL(session_);
// Graph kernel does not support pynative mode now, print a warning here.
graphkernel::GraphKernelFlags::GetInstance().CheckSupport();
session_->InitAllBucket(graph, device_context);
graph_id = graph->graph_id();
} else {
graph_id = CompileGraphImpl(graph, device_context);
}
session_->InitAllBucket(graph, device_context);
graph->set_front_outputs(outputs);
session_->DumpGraphs({graph});
@ -559,7 +547,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
MS_EXCEPTION_IF_NULL(session_);
session_->InitAllBucket(graph, device_context);
SetSummaryNodesRefCount(graph.get());
#ifdef ENABLE_DUMP_IR
// Dump .pb graph after graph optimization.
@ -629,7 +616,6 @@ GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, bool
(void)outputs_with_index.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(node, 0, false));
}
UpdateRefCountForGraphOutput(outputs_with_index);
AnfAlgo::UpdateGraphValidRefPair(graph);
return graph->graph_id();
}

View File

@ -183,13 +183,13 @@ class GraphCompiler {
// Remove single op kernel graph cache and output nodes cache.
void EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id);
private:
DISABLE_COPY_AND_ASSIGN(GraphCompiler);
// The implementation of compiling graph in Graph Mode, including optimizing graph,
// setting operator info, creating kernel and transforming kernel graph to ActorSet.
GraphId CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const;
private:
DISABLE_COPY_AND_ASSIGN(GraphCompiler);
// Add operators' output and input reference map to the graph.
void AddOutInRefToGraph(const KernelGraphPtr &graph) const;

View File

@ -0,0 +1,200 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/pynative/graph_adapter.h"
#include <string>
#include <memory>
#include <vector>
#include "ir/tensor.h"
#include "include/common/utils/convert_utils.h"
#include "include/common/utils/anfalgo.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "runtime/graph_scheduler/device_tensor_store.h"
namespace mindspore::pynative {
namespace {
constexpr auto kAttrBpropValueNodeRefCount = "bprop_value_node_ref_count";
constexpr auto kAttrValueNodeForwardOuputFlags = "value_node_forward_output_flags";
tensor::TensorPtr GetTensorFromValueNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<ValueNode>()) {
return nullptr;
}
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
// ValueTuple is already expanded into tensors in backend.
if (!value->isa<tensor::Tensor>()) {
MS_LOG(DEBUG) << "Only need to process forward output tensor. value:" << value->ToString();
return nullptr;
}
auto tensor = value->cast<tensor::TensorPtr>();
return tensor;
}
} // namespace
void GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
for (auto &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
if (tensor->is_forward_output()) {
AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
}
}
}
}
// The device address of a graph value node needs to be released
// if the value node is an output of the forward graph in PyNative mode.
void GraphAdapter::GenerateRefCountForBpropValueNode(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
HashMap<std::string, size_t> tensor_counts;
auto execution_nodes = graph->execution_order();
for (auto &node : execution_nodes) {
std::vector<session::KernelWithIndex> real_inputs;
common::AnfAlgo::GetRealInputs(node, &real_inputs);
for (auto &real_input : real_inputs) {
auto forward_output_tensor = GetTensorFromValueNode(real_input.first);
if (forward_output_tensor == nullptr || !forward_output_tensor->is_forward_output()) {
continue;
}
tensor_counts[forward_output_tensor->id()] += 1;
}
}
std::vector<size_t> value_node_ref_count;
std::vector<bool> value_node_forward_output_flags;
for (auto &value_node : graph->graph_value_nodes()) {
auto tensor = GetTensorFromValueNode(value_node);
if (tensor == nullptr || !tensor->is_forward_output()) {
value_node_ref_count.emplace_back(SIZE_MAX);
value_node_forward_output_flags.emplace_back(false);
continue;
}
auto iter = tensor_counts.find(tensor->id());
if (iter == tensor_counts.end()) {
// The tensor is in bp graph but not used.
// e.g. %1-MakeTuple(T1, T2) -> TupleGetItem(%1, 0). T2 is not used.
MS_LOG(DEBUG) << "Tensor " << tensor->ToString() << " is not found in value node";
value_node_ref_count.emplace_back(SIZE_MAX);
value_node_forward_output_flags.emplace_back(false);
continue;
}
value_node_ref_count.emplace_back(iter->second);
value_node_forward_output_flags.emplace_back(true);
}
graph->set_attr(kAttrBpropValueNodeRefCount, MakeValue(value_node_ref_count));
graph->set_attr(kAttrValueNodeForwardOuputFlags, MakeValue(value_node_forward_output_flags));
}
void GraphAdapter::UpdateForwardOutputInBpropGraph(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(DEBUG) << "Update start";
auto value_node_ref_counts = GetValue<std::vector<size_t>>(graph->get_attr(kAttrBpropValueNodeRefCount));
auto value_node_forward_output_flags = GetValue<std::vector<bool>>(graph->get_attr(kAttrValueNodeForwardOuputFlags));
size_t value_node_size = graph->graph_value_nodes().size();
if (value_node_ref_counts.size() != value_node_size || value_node_forward_output_flags.size() != value_node_size) {
MS_LOG(EXCEPTION) << "value_node_ref_count.size " << value_node_ref_counts.size()
<< " value_node_forward_output_flags.size " << value_node_forward_output_flags.size()
<< " not equal to " << value_node_size;
}
size_t value_node_index = 0;
HashMap<device::DeviceAddressPtr, size_t> address_ref_count;
// Update ValueNode device address
for (auto &value_node : graph->graph_value_nodes()) {
auto is_forward_output = value_node_forward_output_flags[value_node_index];
if (!is_forward_output) {
value_node_index++;
continue;
}
size_t value_node_ref_count = value_node_ref_counts[value_node_index++];
auto tensor = GetTensorFromValueNode(value_node);
MS_EXCEPTION_IF_NULL(tensor);
auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
if (device_address == nullptr) {
MS_LOG(WARNING) << "Forward output " << tensor->ToString() << " device address is null";
continue;
}
if (device_address->GetDeviceType() != device::DeviceType::kCPU) {
address_ref_count[device_address] += value_node_ref_count;
device_address->set_from_tensor(tensor);
}
auto front_node = AnfAlgo::FetchFrontNodeByBackendNode(value_node, *graph);
runtime::DeviceTensorStore::GetInstance().Insert(front_node.get(), device_address);
}
for (auto &[address, ref_count] : address_ref_count) {
address->set_original_ref_count(ref_count);
address->ResetRefCount();
MS_LOG(DEBUG) << "device_address " << address.get() << " ref_count " << address->ref_count();
}
MS_LOG(DEBUG) << "Update end";
}
bool GraphAdapter::ReplaceBpropGraphParameter(const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
size_t index = 0;
bool changed = false;
for (const auto &input_node : graph->input_nodes()) {
auto params = common::AnfAlgo::GetAllOutput(input_node);
for (const auto &param : params) {
if (index >= input_tensors.size()) {
MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
<< ", input size: " << input_tensors.size();
}
const auto &input_tensor = input_tensors[index++];
MS_EXCEPTION_IF_NULL(input_tensor);
const auto &tensor_address = input_tensor->device_address();
auto address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address);
if (address != nullptr) {
auto tensor_format = address->format();
auto param_format = AnfAlgo::GetOutputFormat(param, 0);
if (tensor_format != param_format) {
// Update parameter format
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{address->format()});
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{address->type_id()});
kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
AnfAlgo::SetOutputAddr(address, 0, param.get());
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
// Update abstract
auto type_of_tensor = input_tensor->Dtype();
auto shape_of_tensor = input_tensor->shape();
auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
param->set_abstract(abstract);
changed = true;
}
}
}
}
return changed;
}
} // namespace mindspore::pynative

View File

@ -0,0 +1,33 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_
#include <vector>
#include "backend/common/session/kernel_graph.h"
namespace mindspore::pynative {
class GraphAdapter {
public:
static void UpdateForwardOutputInBpropGraph(const KernelGraphPtr &graph);
static bool ReplaceBpropGraphParameter(const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors);
static void GenerateRefCountForBpropValueNode(const KernelGraphPtr &graph);
static void ClearForwardOutputValueNodeDeviceAddress(const KernelGraphPtr &graph);
};
} // namespace mindspore::pynative
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_
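The backend uses these helpers in a compile-once / run-every-step pattern. A condensed sketch of the call order taken from RunGraphIntergated above (the wrapper function is hypothetical and omits compilation, scheduling, and actor execution):

// Sketch of how MindRTBackend drives GraphAdapter for a PyNative bprop graph.
void PrepareAndRefreshBpropGraph(const KernelGraphPtr &graph) {
  if (!graph->has_flag(kFlagGraphCompiled)) {
    // First run: clear stale forward-output addresses and precompute ref counts.
    pynative::GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(graph);
    pynative::GraphAdapter::GenerateRefCountForBpropValueNode(graph);
    graph->set_flag(kFlagGraphCompiled, true);
  }
  // Every step: push the latest forward-output device addresses into the graph.
  pynative::GraphAdapter::UpdateForwardOutputInBpropGraph(graph);
}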

View File

@ -105,7 +105,6 @@ void UpdateInputNodeDeviceAddress(const std::vector<AnfNodePtr> &input_nodes,
input_tensor->set_lazy_callback([]() { runtime::OpExecutor::GetInstance().Wait(); });
node_address->set_from_persistent_mem(input_tensor->is_parameter());
node_address->SetNodeIndex(input_node, 0);
UpdateRefCount(node_address.get(), true);
}
// The DeviceType and format of DeviceAddress is always the same after UpdateInputTensor
@ -174,7 +173,6 @@ void CopyValueNodeTensorToDevice(const ValueNodePtr &node, const device::DeviceC
return;
}
tensor->set_device_address(node_address);
UpdateRefCount(node_address.get(), true);
CopyTensorDataToDevice(tensor, node, device_context);
}
}

View File

@ -43,7 +43,6 @@ FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info)
kw_only_args_count_(0),
hyper_param_count_(0),
is_generated_(false),
is_bprop_(false),
return_(nullptr),
manager_(),
debug_info_(std::move(debug_info)),
@ -142,7 +141,7 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
return p;
}
bool FuncGraph::has_flag(const std::string &key) {
bool FuncGraph::has_flag(const std::string &key) const {
auto iter = attrs_.find(key);
if (iter != attrs_.cend()) {
MS_EXCEPTION_IF_NULL(iter->second);

View File

@ -184,8 +184,6 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list);
void set_is_generate(bool generated) { is_generated_ = generated; }
bool is_generated() const { return is_generated_; }
void set_is_bprop(bool is_brop) { is_bprop_ = is_brop; }
bool is_bprop() const { return is_bprop_; }
mindspore::HashMap<std::string, ValuePtr> &attrs() { return attrs_; }
void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
@ -193,7 +191,7 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
attrs_[attr.first] = attr.second;
}
}
bool has_flag(const std::string &key);
bool has_flag(const std::string &key) const;
void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
void erase_flag(const std::string &key) { (void)attrs_.erase(key); }
@ -425,9 +423,6 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
size_t hyper_param_count_;
// Argument input list for the graph used to generate this graph.
bool is_generated_;
bool is_bprop_;
// CNode that calls 'return' primitive.
// We use shared pointer to manage it.
CNodePtr return_;

View File

@ -381,6 +381,6 @@ def test_train_feed(num_classes=65536):
model = Model(net, loss_fn=loss, optimizer=opt)
model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
loss_value = np.array(parallel_callback.loss_list)
expect_out = [11.087254, 10.876551, 10.045684]
expect_out = [11.087254, 10.876551, 10.142526]
print(loss_value)
assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)