forked from mindspore-Ecosystem/mindspore
PyNative Bprop run in Graph.
1. A dynamic-structure bprop graph needs to be split and run op by op. 2. A static bprop graph can run as an integrated graph. (Only supported for CPU/GPU and non-dynamic-shape graphs.)
parent d8b5ce1e49
commit d490450270
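In outline, the change replaces the old FuncGraph::is_bprop_ member with FuncGraph flags (kFlagIsPynativeBpropGraph, kFlagIsDynamicStructure, kFlagGraphCompiled) and lets MindRTBackend::RunGraphByCondition choose the execution path from those flags. The following is only a condensed, non-authoritative sketch of that dispatch, paraphrased from the RunGraphByCondition hunk further down; all identifiers come from the diff, and the local variable static_bprop is introduced here purely for readability.

// Sketch only: condensed from the RunGraphByCondition hunk below.
void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
                                        const std::vector<std::vector<tensor::TensorPtr>> &input_tensors,
                                        VectorRef *outputs) {
  const auto &graphs = graph_compiler_info.graphs_;
  // Cut graphs call back into Python, so the GIL must be held and the graph runs op by op.
  bool contain_cut_graph = std::any_of(graphs.begin(), graphs.end(),
                                       [](const KernelGraphPtr &g) { return g->has_flag(kFlagsIsCutGraph); });
  bool is_dynamic_shape = std::any_of(graphs.begin(), graphs.end(),
                                      [](const KernelGraphPtr &g) { return g->is_dynamic_shape(); });
  bool static_bprop = pynative_run_in_graph_ &&  // currently false on Ascend
                      !is_dynamic_shape && !root_graph_->has_flag(kFlagIsDynamicStructure);
  if (!contain_cut_graph && static_bprop) {
    RunGraphIntergated(actor_info, graph_compiler_info, input_tensors, outputs);  // whole-graph execution
  } else {
    RunGraphBySingleOp(graphs, input_tensors, outputs);  // op-by-op fallback
  }
}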
@@ -115,6 +115,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
     first_step_ = graph.first_step_;
     has_optimizer_ = graph.has_optimizer_;
     is_dynamic_shape_ = graph.is_dynamic_shape_;
+    front_outputs_ = graph.front_outputs_;
   }

   ~KernelGraph() override;
@@ -445,6 +446,9 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
     return iter->second;
   }

+  AnfNodePtrList front_outputs() const { return front_outputs_; }
+  void set_front_outputs(const AnfNodePtrList &outputs) { front_outputs_ = outputs; }
+
  private:
   // remove value node form graph
   bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
@@ -494,6 +498,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
   std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
   // parameters that will be updated when graph is executed
   mindspore::HashSet<ParameterPtr> updated_parameters_;

   // graph needn't execute
   bool executable_{false};
   // exist summary node in graph
@@ -515,6 +520,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
   CNodePtr start_label_;
   CNodePtr end_goto_;

+  AnfNodePtrList front_outputs_;
   // Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
   // related to the input of this kernel graph. The first of unordered map is the input of this kernel graph, the second
   // of unordered map is front node corresponding to the output of previous kernel graph.
@@ -1419,7 +1419,8 @@ void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector
   MS_EXCEPTION_IF_NULL(parallel_context);
   auto parallel_mode = parallel_context->parallel_mode();
   bool is_parallel_forward_ms_function =
-    !graph->is_bprop() && (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
+    !graph->has_flag(kFlagIsPynativeBpropGraph) &&
+    (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
   for (const auto &input_node : graph->input_nodes()) {
     auto params = common::AnfAlgo::GetAllOutput(input_node);
     for (const auto &param : params) {
@@ -2617,12 +2618,12 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
     HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
     HandleOpOutputs(kernel, op_outputs, cnode_refcount, &op_output_map, &graph_output_info);
     // Save grad node to Bucket
-    if (kernel_graph->is_bprop()) {
+    if (kernel_graph->has_flag(kFlagIsPynativeBpropGraph)) {
       AddGradAddrToBucket(graph_id, graph_output_info.graph_output_tensors);
     }
   }
   // Clear bucket resources every step
-  if (kernel_graph->is_bprop()) {
+  if (kernel_graph->has_flag(kFlagIsPynativeBpropGraph)) {
     ClearAllBucket(graph_id);
   }

@@ -2699,10 +2700,8 @@ void SetGraphBpropAttr(const KernelGraphPtr &graph) {
   auto &execution_orders = graph->execution_order();
   if (std::any_of(execution_orders.begin(), execution_orders.end(),
                   [](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) {
-    graph->set_is_bprop(true);
+    graph->set_flag(kFlagIsPynativeBpropGraph, true);
     MS_LOG(INFO) << "Match bprop graph";
-  } else {
-    graph->set_is_bprop(false);
   }
 }

@@ -2798,7 +2797,7 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph, const device::Devi
   }
   SetGraphBpropAttr(graph);

-  if (!graph->is_bprop()) {
+  if (!graph->has_flag(kFlagIsPynativeBpropGraph)) {
     return;
   }

@@ -37,6 +37,7 @@
 #include "runtime/hardware/device_context_manager.h"
 #include "runtime/graph_scheduler/graph_compiler.h"
 #include "runtime/pynative/run_op_helper.h"
+#include "runtime/pynative/graph_adapter.h"
 #include "distributed/recovery/recovery_context.h"
 #include "include/common/utils/scoped_long_running.h"
 #ifdef ENABLE_D
@@ -440,6 +441,7 @@ MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string
   SetDebuggerInit();
 #endif
   runtime::GraphScheduler::GetInstance().Initialize();
+  pynative_run_in_graph_ = device_context->GetDeviceType() != device::DeviceType::kAscend;
 }

 void MindRTBackend::ProcessNotSupportCnode(const FuncGraphPtr &func_graph,
@@ -889,6 +891,74 @@ void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &param
   input_tensor->push_back(tensor_input);
 }

+void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph) {
+  bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
+                                distributed::recovery::RecoveryContext::GetInstance()->need_reset());
+  if (need_contruct_output) {
+    // Update device address for output node of graph.
+    // Summary processing will use the output device address, so must be after the summary processing.
+    actor_set->output_actor_->UpdateOutputDeviceAddress();
+
+    // Fetch outputs.
+    MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
+    auto &output_tensors = actor_set->output_actor_->outputs();
+    if (!output_tensors.empty()) {
+      size_t output_position = 0;
+      ConstructOutputs(root_graph->output(), output_tensors, &output_position, outputs);
+    }
+  }
+}
+
+void MindRTBackend::RunGraphIntergated(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
+                                       const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
+  WaitTaskFinish();
+  MS_EXCEPTION_IF_NULL(graph_compiler_);
+  auto graphs = graph_compiler_info.graphs_;
+  auto &op_executor = runtime::OpExecutor::GetInstance();
+  op_executor.Register([this]() { BatchBuildCallback(); });
+  for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
+    const auto &graph = graphs[graph_index];
+    MS_EXCEPTION_IF_NULL(graph);
+    // TODO(caifubi): Update parameter format for Ascend
+    if (!graph->has_flag(kFlagGraphCompiled)) {
+      graph_compiler_->CompileGraphImpl(graph, graph_compiler_info.device_contexts_.front());
+      graph->CacheGraphOutputToFrontNodeWithIndex({graph->output()}, graph->front_outputs());
+      // Clear front outputs
+      graph->set_front_outputs({});
+      // Transform graph to actor DAG, and schedule the actor DAG.
+      const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(graph_compiler_info);
+      runtime::GraphScheduler::GetInstance().Schedule(actor_set);
+      pynative::GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(graph);
+      pynative::GraphAdapter::GenerateRefCountForBpropValueNode(graph);
+      graph->set_flag(kFlagGraphCompiled, true);
+    }
+    pynative::GraphAdapter::UpdateForwardOutputInBpropGraph(graph);
+  }
+
+  // Run actor DAG.
+  mindspore::ScopedLongRunning long_running;
+  const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
+  MS_EXCEPTION_IF_NULL(actor_set);
+  runtime::GraphScheduler::GetInstance().Run(actor_set, graph_compiler_info.device_contexts_, inputs);
+
+  MS_EXCEPTION_IF_NULL(graph_compiler_);
+  graph_compiler_->Summary(graph_compiler_info.graphs_);
+
+  ConstructOutputs(actor_set, outputs, root_graph_);
+  // Clear bucket resources every step
+  for (auto &graph : graphs) {
+    if (graph->has_flag(kFlagIsPynativeBpropGraph)) {
+      graph_compiler_->AddGradAddrToBucket(graph->graph_id(), actor_set->output_actor_->outputs());
+      graph_compiler_->ClearAllBucket(graph->graph_id());
+    }
+  }
+
+  runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
+  // Close abstract_lock for dynamic_shape
+  AnfUtils::CloseAbstractLock();
+  MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
+}
+
 void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
                                        const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
   WaitTaskFinish();
@@ -941,18 +1011,42 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
       graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);

       // Save grad node to Bucket
-      if (graph->is_bprop() && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) && !kernel->is_parallel()) {
+      if (graph->has_flag(kFlagIsPynativeBpropGraph) && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) &&
+          !kernel->is_parallel()) {
         graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
       }
     }
     WaitTaskFinish();
     // Clear bucket resources every step
-    if (graph->is_bprop()) {
+    if (graph->has_flag(kFlagIsPynativeBpropGraph)) {
       graph_compiler_->ClearAllBucket(graph->graph_id());
     }
   }
 }

+void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
+                                        const std::vector<std::vector<tensor::TensorPtr>> &input_tensors,
+                                        VectorRef *outputs) {
+  bool contain_cut_graph = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
+                                       [](const KernelGraphPtr &graph) { return graph->has_flag(kFlagsIsCutGraph); });
+  if (contain_cut_graph) {
+    // Python API will be called in cut_graph, so we cannot release gil here.
+    RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
+  } else {
+    // Release python gil.
+    mindspore::ScopedLongRunning long_running;
+    bool is_dynamic_shape = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
+                                        [](const KernelGraphPtr &graph) { return graph->is_dynamic_shape(); });
+    // TODO(caifubi): PyNative dynamic shape not support run in graph now.
+    if (pynative_run_in_graph_ && !is_dynamic_shape && !root_graph_->has_flag(kFlagIsDynamicStructure)) {
+      RunGraphIntergated(actor_info, graph_compiler_info, input_tensors, outputs);
+    } else {
+      RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
+    }
+  }
+  MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
+}
+
 void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(root_graph_);
   if (IsGraphOutputValueNodeOrParameter(root_graph_->output(), args, outputs)) {
@@ -1013,16 +1107,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
   MS_EXCEPTION_IF_NULL(outputs);
   // There will be more than one kernel graph in heterogeneous scenario in a ms function of PyNative Mode.
   if (real_execution_mode_ == kPynativeMode) {
-    bool is_cut_graph = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
-                                    [](const KernelGraphPtr &graph) { return graph->has_flag(kFlagsIsCutGraph); });
-    if (is_cut_graph) {
-      RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
-    } else {
-      // Release python gil.
-      mindspore::ScopedLongRunning long_running;
-      RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
-    }
-    MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
+    RunGraphByCondition(actor_info, graph_compiler_info, input_tensors, outputs);
     return;
   }

@@ -1036,21 +1121,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
   MS_EXCEPTION_IF_NULL(graph_compiler_);
   graph_compiler_->Summary(graph_compiler_info.graphs_);

-  bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
-                                distributed::recovery::RecoveryContext::GetInstance()->need_reset());
-  if (need_contruct_output) {
-    // Update device address for output node of graph.
-    // Summary processing will use the output device address, so must be after the summary processing.
-    actor_set->output_actor_->UpdateOutputDeviceAddress();
-
-    // Fetch outputs.
-    MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
-    auto &output_tensors = actor_set->output_actor_->outputs();
-    if (output_tensors.size() > 0) {
-      size_t output_position = 0;
-      ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
-    }
-  }
+  ConstructOutputs(actor_set, outputs, root_graph_);

   runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
   // Close abstract_lock for dynamic_shape
@@ -143,6 +143,8 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
   // Get saved OpBuildTask in OpExecutor and build all the kernels together in PyNative mode.
   void CompileSingleOpGraphs(const std::vector<std::shared_ptr<runtime::OpBuildTask>> &build_tasks);

+  void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph);
+
   // Restore the outputs tuple by the origin funcGraph output node and output tensors.
   void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors,
                         size_t *output_position, VectorRef *outputs);
@@ -173,11 +175,16 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
   void DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs, GraphCompilerInfo *graph_compiler_info,
                       OpRunInfo *op_run_info);

+  void RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
+                           const std::vector<std::vector<tensor::TensorPtr>> &input_tensors, VectorRef *outputs);
   // Split complete kernel graph to single op graph in PyNative back
   // propagation, then compile and run single op graph.
   void RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
                           const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);

+  void RunGraphIntergated(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
+                          const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);
+
   void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs);

   void ReleaseForwardOutput(const std::vector<TensorPtr> &input_tensors);
@@ -212,6 +219,8 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
   void CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode = device::RunMode::kUnknown);
   void ProcessNotSupportCnode(const FuncGraphPtr &func_graph, const device::DeviceType &old_target,
                               const device::DeviceType &new_target);
+  // TODO(caifubi): Remove this flag when Ascend backend is ok.
+  bool pynative_run_in_graph_{false};
 };
 using MindRTBackendPtr = std::shared_ptr<compile::MindRTBackend>;
 }  // namespace compile
@@ -583,12 +583,15 @@ constexpr auto kActualAbstract = "actual_abstract";
 constexpr auto kAttrZeroInfinity = "zero_infinity";
 constexpr auto kAttrBlank = "blank";

+// FuncGraph Flags
+constexpr auto kFlagsIsCutGraph = "is_cut_graph";
+constexpr auto kFlagGraphCompiled = "graph_compiled";
+constexpr auto kFlagIsDynamicStructure = "is_dynamic_structure";
+constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph";
+
 // TODO(dsj): for ms_function running in graph_mode. should be delete later
 constexpr auto kAttrMSFunction = "ms_function_graph";

-// KernelGraph Flags
-constexpr auto kFlagsIsCutGraph = "is_cut_graph";
-
 // custom operator func type
 constexpr auto kCustomTypeAOT = "aot";
 constexpr auto kCustomTypeJULIA = "julia";
@@ -1033,9 +1033,10 @@ bool TaskEmitAction(const ResourcePtr &resource) {
   DisableMindRT(resource);
   auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
   auto is_parallel = (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
-  bool pynative_switch_to_graph_mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
-                                       (!func_graph->is_bprop() || func_graph->manager()->func_graphs().size() > 1) &&
-                                       !is_parallel;
+  bool pynative_switch_to_graph_mode =
+    context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
+    (!func_graph->has_flag(kFlagIsPynativeBpropGraph) || func_graph->manager()->func_graphs().size() > 1) &&
+    !is_parallel;
   SetRunMode(resource, pynative_switch_to_graph_mode);
   auto bc_ptr = resource->GetResult(kBackend).cast<compile::BackendPtr>();
   MS_EXCEPTION_IF_NULL(bc_ptr);
@@ -872,6 +872,8 @@ void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const OpExecIn
   std::vector<tensor::TensorPtr> total_output_tensors;
   TensorValueToTensor(added_out, &total_output_tensors);
   RunReplace(added_make_tuple, total_output_tensors, grad_graph);
+  std::for_each(total_output_tensors.begin(), total_output_tensors.end(),
+                [](tensor::TensorPtr &tensor) { tensor->set_is_forward_output(true); });
   top_cell->set_op_info_with_ms_func_forward_tensors(op_exec_info->op_info, total_output_tensors);
 }

@@ -926,6 +928,13 @@ void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<ten
     MS_EXCEPTION_IF_NULL(old_device_address);
     auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
     MS_EXCEPTION_IF_NULL(new_device_address);
+
+    // CPU host tensor data_c is different from device address if the address is from mem_pool.
+    if (new_device_address->from_mem_pool()) {
+      pre_tensor->set_device_address(new_device_address);
+      continue;
+    }
+
     auto old_ptr = old_device_address->GetMutablePtr();
     MS_EXCEPTION_IF_NULL(old_ptr);
     auto new_ptr = new_device_address->GetPtr();
@@ -3338,7 +3347,10 @@ void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &g
   // Get bprop graph of top cell
   auto bprop_graph = GetBpropGraph(grad, cell, w_args, p_args, size, args);
   MS_EXCEPTION_IF_NULL(bprop_graph);
-  bprop_graph->set_is_bprop(true);
+  if (top_cell()->is_dynamic_structure()) {
+    bprop_graph->set_flag(kFlagIsDynamicStructure, true);
+  }
+  bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
   resource->set_func_graph(bprop_graph);
   auto manager = resource->manager();
   MS_EXCEPTION_IF_NULL(manager);
@@ -3666,6 +3678,7 @@ void GradExecutor::CheckNeedCompileGraph() {
     pre_top_cell->Clear();
     already_run_top_cell_[already_top_cell_id] = new_top_cell;
     g_pyobj_id_cache.clear();
+    top_cell()->set_is_dynamic_structure(true);
   } else {
     MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again";
     pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
@@ -104,6 +104,8 @@ class TopCellInfo {
   void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
   bool forward_already_run() const { return forward_already_run_; }
   void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
+  void set_is_dynamic_structure(bool is_dynamic_structure) { is_dynamic_structure_ = is_dynamic_structure; }
+  bool is_dynamic_structure() const { return is_dynamic_structure_; }
   pipeline::ResourcePtr resource() const { return resource_; }
   FuncGraphPtr df_builder() const { return df_builder_; }
   FuncGraphPtr fg() const { return fg_; }
@@ -150,6 +152,7 @@ class TopCellInfo {
   bool is_init_kpynative_{false};
   bool forward_already_run_{false};
   bool need_compile_graph_{false};
+  bool is_dynamic_structure_{false};
   size_t op_num_{0};
   size_t grad_order_{0};
   pipeline::ResourcePtr resource_{nullptr};
@@ -37,6 +37,22 @@ namespace device {
 namespace ascend {
 using AscendAutoMonad = mindspore::session::AscendAutoMonad;

+namespace {
+void RemoveUnusedValueNode(const KernelGraphPtr &graph) {
+  auto m = graph->manager();
+  auto node_users = m->node_users();
+  mindspore::HashSet<ValueNodePtr> unused_value_nodes;
+  for (auto &value_node : graph->graph_value_nodes()) {
+    if (node_users.find(value_node) == node_users.end()) {
+      unused_value_nodes.insert(value_node);
+    }
+  }
+  for (auto &value_node : unused_value_nodes) {
+    graph->RemoveNodeFromGraph(value_node);
+  }
+}
+}  // namespace
+
 void AscendGraphOptimization::Reset() {
   MS_LOG(INFO) << "Clear Ascend Graph Optimization Resource.";
   memo_.clear();
@@ -60,7 +76,9 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
   OptimizeGraphWithDeviceInfo(graph);
   OptimizeExecutionOrder(graph);
   PostOptimization(graph);
   // must clear memo_ which holds kernel graph after using AscendGraphOptimization class.
+
+  RemoveUnusedValueNode(graph);
+
   memo_.clear();
   // clear and reset graph_manager_ after optimization
   graph_manager_ = MakeManager();
@@ -22,6 +22,7 @@
 #include <memory>
 #include <map>
 #include <utility>
 #include "ir/tensor.h"
 #include "ir/dtype.h"
 #include "ir/device_sync.h"
 #include "utils/shape_utils.h"
@@ -107,6 +108,10 @@ class DeviceAddress : public mindspore::DeviceSync {
   std::string device_name() const { return device_name_; }
   uint32_t device_id() const { return device_id_; }

+  void set_from_tensor(const std::weak_ptr<tensor::Tensor> &from_tensor) { from_tensors_.emplace_back(from_tensor); }
+  std::vector<std::weak_ptr<tensor::Tensor>> from_tensors() const { return from_tensors_; }
+  void clear_from_tensors() { from_tensors_.clear(); }
+
   virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; }
   KernelWithIndex GetNodeIndex() const {
     return node_index_.first.expired() ? KernelWithIndex{nullptr, node_index_.second}
@@ -167,6 +172,7 @@ class DeviceAddress : public mindspore::DeviceSync {
   ShapeVector host_shape_{};
   // {node, out_index}
   std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
+  std::vector<std::weak_ptr<tensor::Tensor>> from_tensors_;
   // The device address of the node that owns the device address cannot be updated and replaced.
   // Application scenario: set to true when the hardware execution mode requires that ptr cannot be changed during
   // execution.
@@ -226,7 +226,21 @@ void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext
   device_tensor->DecreaseRefCount();
   if (device_tensor->ref_count() == 0) {
     if (device_tensor->GetPtr() != nullptr) {
-      FreeMemory(device_tensor, device_context);
+      auto from_tensors = device_tensor->from_tensors();
+      if (from_tensors.empty()) {
+        FreeMemory(device_tensor, device_context);
+      } else {
+        std::for_each(from_tensors.begin(), from_tensors.end(), [](const std::weak_ptr<tensor::Tensor> &t) {
+          auto tensor = t.lock();
+          if (tensor != nullptr) {
+            tensor->set_device_address(nullptr);
+          }
+        });
+        device_tensor->clear_from_tensors();
+        // Reset device address status
+        device_tensor->set_original_ref_count(SIZE_MAX);
+        device_tensor->ResetRefCount();
+      }
     }
     device_tensor->ResetRefCount();
   }
@@ -304,7 +304,8 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no
     // Assign tensor address to input data node and set `ref_count` to `SIZE_MAX` for avoiding clean
     AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
     tensor_address->SetNodeIndex(input_node, 0);
-    tensor_address->set_ref_count(SIZE_MAX);
+    tensor_address->set_original_ref_count(SIZE_MAX);
+    tensor_address->ResetRefCount();
   } else if (device_address->GetPtr() != nullptr) {
     // The `device_address` may come from another previous tensor. In order to prevent pollute the device data of
     // previous tensor, creating a new device address for holding current input tensor data.
@@ -313,7 +314,8 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no
     MS_EXCEPTION_IF_NULL(new_device_address);
     AnfAlgo::SetOutputAddr(new_device_address, 0, input_node.get());
     new_device_address->SetNodeIndex(input_node, 0);
-    new_device_address->set_ref_count(SIZE_MAX);
+    new_device_address->set_original_ref_count(SIZE_MAX);
+    new_device_address->ResetRefCount();
   }
 }

@@ -161,12 +161,10 @@ void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, cons
     }
     auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
     if (output_address != nullptr && output_address->GetDeviceType() == device_context->GetDeviceType()) {
-      // The input of PyNative bprop graph is ValueNode.
-      // Setting the address to the ValueNode will lead to memory leak.
-      if (!graph->is_bprop()) {
-        AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
-                               value_node.get());
-      }
+      // We need to set tensor->device_address to ValueNode even if the tensor is a forward_output tensor
+      // in PyNative Bprop graph. ValueNode device_address is necessary for GraphSchedule::Transform.
+      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
+                             value_node.get());
       continue;
     }

@@ -369,18 +367,6 @@ void SetSummaryNodesRefCount(const KernelGraph *graph) {
   }
 }

-void UpdateRefCountForGraphOutput(const std::vector<KernelWithIndex> &output_with_index) {
-  for (const auto &item_with_index : output_with_index) {
-    if (!AnfAlgo::OutputAddrExist(item_with_index.first, item_with_index.second, false)) {
-      continue;
-    }
-    auto device_address = AnfAlgo::GetMutableOutputAddr(item_with_index.first, item_with_index.second, false);
-    MS_EXCEPTION_IF_NULL(device_address);
-    device_address->set_original_ref_count(SIZE_MAX);
-    device_address->ResetRefCount();
-  }
-}
-
 void SetGraphInputNodeActualAbstract(const session::OpRunInfo &op_run_info, const KernelGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   if (!op_run_info.output_is_dynamic_shape && !op_run_info.input_is_dynamic_shape) {
@@ -446,11 +432,13 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
     MS_EXCEPTION_IF_NULL(session_);
     // Graph kernel does not support pynative mode now, print a warning here.
     graphkernel::GraphKernelFlags::GetInstance().CheckSupport();
-    session_->InitAllBucket(graph, device_context);
     graph_id = graph->graph_id();
   } else {
     graph_id = CompileGraphImpl(graph, device_context);
   }
+  session_->InitAllBucket(graph, device_context);
+
+  graph->set_front_outputs(outputs);

   session_->DumpGraphs({graph});

@@ -559,7 +547,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
   graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));

   MS_EXCEPTION_IF_NULL(session_);
-  session_->InitAllBucket(graph, device_context);
   SetSummaryNodesRefCount(graph.get());
 #ifdef ENABLE_DUMP_IR
   // Dump .pb graph after graph optimization.
@@ -629,7 +616,6 @@ GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, bool
     (void)outputs_with_index.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(node, 0, false));
   }

-  UpdateRefCountForGraphOutput(outputs_with_index);
   AnfAlgo::UpdateGraphValidRefPair(graph);
   return graph->graph_id();
 }
@@ -183,13 +183,13 @@ class GraphCompiler {
   // Remove single op kernel graph cache and output nodes cache.
   void EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id);

- private:
-  DISABLE_COPY_AND_ASSIGN(GraphCompiler);
-
   // The implementation of compiling graph in Graph Mode, including optimizing graph,
   // setting operator info, creating kernel and transforming kernel graph to ActorSet.
   GraphId CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const;

+ private:
+  DISABLE_COPY_AND_ASSIGN(GraphCompiler);
+
   // Add operators' output and input reference map to the graph.
   void AddOutInRefToGraph(const KernelGraphPtr &graph) const;

@@ -0,0 +1,200 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/pynative/graph_adapter.h"
+
+#include <string>
+#include <memory>
+#include <vector>
+#include "ir/tensor.h"
+#include "include/common/utils/convert_utils.h"
+#include "include/common/utils/anfalgo.h"
+#include "backend/common/session/anf_runtime_algorithm.h"
+#include "runtime/graph_scheduler/device_tensor_store.h"
+
+namespace mindspore::pynative {
+namespace {
+constexpr auto kAttrBpropValueNodeRefCount = "bprop_value_node_ref_count";
+constexpr auto kAttrValueNodeForwardOuputFlags = "value_node_forward_output_flags";
+
+tensor::TensorPtr GetTensorFromValueNode(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<ValueNode>()) {
+    return nullptr;
+  }
+  auto value_node = node->cast<ValueNodePtr>();
+  MS_EXCEPTION_IF_NULL(value_node);
+  auto value = value_node->value();
+  MS_EXCEPTION_IF_NULL(value);
+  // ValueTuple is already expanded into tensors in backend.
+  if (!value->isa<tensor::Tensor>()) {
+    MS_LOG(DEBUG) << "Only need to process forward output tensor. value:" << value->ToString();
+    return nullptr;
+  }
+
+  auto tensor = value->cast<tensor::TensorPtr>();
+  return tensor;
+}
+}  // namespace
+
+void GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  for (auto &value_node : graph->graph_value_nodes()) {
+    MS_EXCEPTION_IF_NULL(value_node);
+    auto value = value_node->value();
+    MS_EXCEPTION_IF_NULL(value);
+    if (value->isa<tensor::Tensor>()) {
+      auto tensor = value->cast<tensor::TensorPtr>();
+      MS_EXCEPTION_IF_NULL(tensor);
+      if (tensor->is_forward_output()) {
+        AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
+      }
+    }
+  }
+}
+
+// The device address of graph value node need to release
+// if the value node is output of forward_graph in PyNative mode.
+void GraphAdapter::GenerateRefCountForBpropValueNode(const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  HashMap<std::string, size_t> tensor_counts;
+  auto execution_nodes = graph->execution_order();
+  for (auto &node : execution_nodes) {
+    std::vector<session::KernelWithIndex> real_inputs;
+    common::AnfAlgo::GetRealInputs(node, &real_inputs);
+    for (auto &real_input : real_inputs) {
+      auto forward_output_tensor = GetTensorFromValueNode(real_input.first);
+      if (forward_output_tensor == nullptr || !forward_output_tensor->is_forward_output()) {
+        continue;
+      }
+      tensor_counts[forward_output_tensor->id()] += 1;
+    }
+  }
+
+  std::vector<size_t> value_node_ref_count;
+  std::vector<bool> value_node_forward_output_flags;
+  for (auto &value_node : graph->graph_value_nodes()) {
+    auto tensor = GetTensorFromValueNode(value_node);
+    if (tensor == nullptr || !tensor->is_forward_output()) {
+      value_node_ref_count.emplace_back(SIZE_MAX);
+      value_node_forward_output_flags.emplace_back(false);
+      continue;
+    }
+    auto iter = tensor_counts.find(tensor->id());
+    if (iter == tensor_counts.end()) {
+      // The tensor is in bp graph but not used.
+      // e.g. %1-MakeTuple(T1, T2) -> TupleGetItem(%1, 0). T2 is not used.
+      MS_LOG(DEBUG) << "Tensor " << tensor->ToString() << " is not found in value node";
+      value_node_ref_count.emplace_back(SIZE_MAX);
+      value_node_forward_output_flags.emplace_back(false);
+      continue;
+    }
+
+    value_node_ref_count.emplace_back(iter->second);
+    value_node_forward_output_flags.emplace_back(true);
+  }
+  graph->set_attr(kAttrBpropValueNodeRefCount, MakeValue(value_node_ref_count));
+  graph->set_attr(kAttrValueNodeForwardOuputFlags, MakeValue(value_node_forward_output_flags));
+}
+
+void GraphAdapter::UpdateForwardOutputInBpropGraph(const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_LOG(DEBUG) << "Update start";
+  auto value_node_ref_counts = GetValue<std::vector<size_t>>(graph->get_attr(kAttrBpropValueNodeRefCount));
+  auto value_node_forward_output_flags = GetValue<std::vector<bool>>(graph->get_attr(kAttrValueNodeForwardOuputFlags));
+  size_t value_node_size = graph->graph_value_nodes().size();
+  if (value_node_ref_counts.size() != value_node_size || value_node_forward_output_flags.size() != value_node_size) {
+    MS_LOG(EXCEPTION) << "value_node_ref_count.size " << value_node_ref_counts.size()
+                      << " value_node_forward_output_flags.size " << value_node_forward_output_flags.size()
+                      << " not equal to " << value_node_size;
+  }
+
+  size_t value_node_index = 0;
+  HashMap<device::DeviceAddressPtr, size_t> address_ref_count;
+  // Update ValueNode device address
+  for (auto &value_node : graph->graph_value_nodes()) {
+    auto is_forward_output = value_node_forward_output_flags[value_node_index];
+    if (!is_forward_output) {
+      value_node_index++;
+      continue;
+    }
+    size_t value_node_ref_count = value_node_ref_counts[value_node_index++];
+    auto tensor = GetTensorFromValueNode(value_node);
+    MS_EXCEPTION_IF_NULL(tensor);
+    auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
+    if (device_address == nullptr) {
+      MS_LOG(WARNING) << "Forward output " << tensor->ToString() << " device address is null";
+      continue;
+    }
+
+    if (device_address->GetDeviceType() != device::DeviceType::kCPU) {
+      address_ref_count[device_address] += value_node_ref_count;
+      device_address->set_from_tensor(tensor);
+    }
+
+    auto front_node = AnfAlgo::FetchFrontNodeByBackendNode(value_node, *graph);
+    runtime::DeviceTensorStore::GetInstance().Insert(front_node.get(), device_address);
+  }
+
+  for (auto &[address, ref_count] : address_ref_count) {
+    address->set_original_ref_count(ref_count);
+    address->ResetRefCount();
+    MS_LOG(DEBUG) << "device_address " << address.get() << " ref_count " << address->ref_count();
+  }
+  MS_LOG(DEBUG) << "Update end";
+}
+
+bool GraphAdapter::ReplaceBpropGraphParameter(const KernelGraphPtr &graph,
+                                              const std::vector<tensor::TensorPtr> &input_tensors) {
+  size_t index = 0;
+  bool changed = false;
+  for (const auto &input_node : graph->input_nodes()) {
+    auto params = common::AnfAlgo::GetAllOutput(input_node);
+    for (const auto &param : params) {
+      if (index >= input_tensors.size()) {
+        MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
+                          << ", input size: " << input_tensors.size();
+      }
+      const auto &input_tensor = input_tensors[index++];
+      MS_EXCEPTION_IF_NULL(input_tensor);
+      const auto &tensor_address = input_tensor->device_address();
+      auto address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address);
+      if (address != nullptr) {
+        auto tensor_format = address->format();
+        auto param_format = AnfAlgo::GetOutputFormat(param, 0);
+        if (tensor_format != param_format) {
+          // Update parameter format
+          auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+          MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
+          kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{address->format()});
+          kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{address->type_id()});
+          kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
+          AnfAlgo::SetOutputAddr(address, 0, param.get());
+          AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
+
+          // Update abstract
+          auto type_of_tensor = input_tensor->Dtype();
+          auto shape_of_tensor = input_tensor->shape();
+          auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
+          param->set_abstract(abstract);
+          changed = true;
+        }
+      }
+    }
+  }
+  return changed;
+}
+}  // namespace mindspore::pynative
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_
+#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_
+
+#include <vector>
+#include "backend/common/session/kernel_graph.h"
+
+namespace mindspore::pynative {
+class GraphAdapter {
+ public:
+  static void UpdateForwardOutputInBpropGraph(const KernelGraphPtr &graph);
+  static bool ReplaceBpropGraphParameter(const KernelGraphPtr &graph,
+                                         const std::vector<tensor::TensorPtr> &input_tensors);
+  static void GenerateRefCountForBpropValueNode(const KernelGraphPtr &graph);
+  static void ClearForwardOutputValueNodeDeviceAddress(const KernelGraphPtr &graph);
+};
+}  // namespace mindspore::pynative
+#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_
@@ -105,7 +105,6 @@ void UpdateInputNodeDeviceAddress(const std::vector<AnfNodePtr> &input_nodes,
       input_tensor->set_lazy_callback([]() { runtime::OpExecutor::GetInstance().Wait(); });
       node_address->set_from_persistent_mem(input_tensor->is_parameter());
       node_address->SetNodeIndex(input_node, 0);
-      UpdateRefCount(node_address.get(), true);
     }

     // The DeviceType and format of DeviceAddress is always the same after UpdateInputTensor
@@ -174,7 +173,6 @@ void CopyValueNodeTensorToDevice(const ValueNodePtr &node, const device::DeviceC
       return;
     }
     tensor->set_device_address(node_address);
-    UpdateRefCount(node_address.get(), true);
     CopyTensorDataToDevice(tensor, node, device_context);
   }
 }
@@ -43,7 +43,6 @@ FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info)
       kw_only_args_count_(0),
       hyper_param_count_(0),
       is_generated_(false),
-      is_bprop_(false),
       return_(nullptr),
       manager_(),
       debug_info_(std::move(debug_info)),
@@ -142,7 +141,7 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
   return p;
 }

-bool FuncGraph::has_flag(const std::string &key) {
+bool FuncGraph::has_flag(const std::string &key) const {
   auto iter = attrs_.find(key);
   if (iter != attrs_.cend()) {
     MS_EXCEPTION_IF_NULL(iter->second);
@@ -184,8 +184,6 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
   FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list);
   void set_is_generate(bool generated) { is_generated_ = generated; }
   bool is_generated() const { return is_generated_; }
-  void set_is_bprop(bool is_brop) { is_bprop_ = is_brop; }
-  bool is_bprop() const { return is_bprop_; }

   mindspore::HashMap<std::string, ValuePtr> &attrs() { return attrs_; }
   void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
@@ -193,7 +191,7 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
       attrs_[attr.first] = attr.second;
     }
   }
-  bool has_flag(const std::string &key);
+  bool has_flag(const std::string &key) const;
   void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
   void erase_flag(const std::string &key) { (void)attrs_.erase(key); }

@@ -425,9 +423,6 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
   size_t hyper_param_count_;
   // Argument input list for the graph used to generate this graph.
   bool is_generated_;
-
-  bool is_bprop_;
-
   // CNode that calls 'return' primitive.
   // We use shared pointer to manage it.
   CNodePtr return_;
@@ -381,6 +381,6 @@ def test_train_feed(num_classes=65536):
     model = Model(net, loss_fn=loss, optimizer=opt)
     model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
     loss_value = np.array(parallel_callback.loss_list)
-    expect_out = [11.087254, 10.876551, 10.045684]
+    expect_out = [11.087254, 10.876551, 10.142526]
     print(loss_value)
     assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)