!34784 PyNative Bprop run in Graph.
Merge pull request !34784 from caifubi/master-pynative-run-in-graph-dev-ci
commit 3d0e975c90

@@ -115,6 +115,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
     first_step_ = graph.first_step_;
     has_optimizer_ = graph.has_optimizer_;
     is_dynamic_shape_ = graph.is_dynamic_shape_;
+    front_outputs_ = graph.front_outputs_;
   }

   ~KernelGraph() override;
@@ -445,6 +446,9 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
     return iter->second;
   }

+  AnfNodePtrList front_outputs() const { return front_outputs_; }
+  void set_front_outputs(const AnfNodePtrList &outputs) { front_outputs_ = outputs; }
+
  private:
   // remove value node from graph
   bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
@@ -494,6 +498,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
   std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
   // parameters that will be updated when graph is executed
   mindspore::HashSet<ParameterPtr> updated_parameters_;

   // graph needn't execute
   bool executable_{false};
   // exist summary node in graph
@@ -515,6 +520,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
   CNodePtr start_label_;
   CNodePtr end_goto_;

+  AnfNodePtrList front_outputs_;
   // Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
   // related to the input of this kernel graph. The first of unordered map is the input of this kernel graph, the second
   // of unordered map is front node corresponding to the output of previous kernel graph.

@@ -1419,7 +1419,8 @@ void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector
   MS_EXCEPTION_IF_NULL(parallel_context);
   auto parallel_mode = parallel_context->parallel_mode();
   bool is_parallel_forward_ms_function =
-    !graph->is_bprop() && (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
+    !graph->has_flag(kFlagIsPynativeBpropGraph) &&
+    (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
   for (const auto &input_node : graph->input_nodes()) {
     auto params = common::AnfAlgo::GetAllOutput(input_node);
     for (const auto &param : params) {
@@ -2617,12 +2618,12 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
     HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
     HandleOpOutputs(kernel, op_outputs, cnode_refcount, &op_output_map, &graph_output_info);
     // Save grad node to Bucket
-    if (kernel_graph->is_bprop()) {
+    if (kernel_graph->has_flag(kFlagIsPynativeBpropGraph)) {
       AddGradAddrToBucket(graph_id, graph_output_info.graph_output_tensors);
     }
   }
   // Clear bucket resources every step
-  if (kernel_graph->is_bprop()) {
+  if (kernel_graph->has_flag(kFlagIsPynativeBpropGraph)) {
     ClearAllBucket(graph_id);
   }
@@ -2699,10 +2700,8 @@ void SetGraphBpropAttr(const KernelGraphPtr &graph) {
   auto &execution_orders = graph->execution_order();
   if (std::any_of(execution_orders.begin(), execution_orders.end(),
                   [](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) {
-    graph->set_is_bprop(true);
+    graph->set_flag(kFlagIsPynativeBpropGraph, true);
     MS_LOG(INFO) << "Match bprop graph";
-  } else {
-    graph->set_is_bprop(false);
   }
 }
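SetGraphBpropAttr keys off node scope names: a graph is flagged as a PyNative bprop graph when any node in its execution order lives under a scope whose name starts with "Gradient". A standalone sketch (Node is a stand-in struct, not MindSpore code) of that prefix test:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string scope_name;  // e.g. "Gradients/Default/network-Net"
};

bool IsBpropGraph(const std::vector<Node> &execution_order) {
  return std::any_of(execution_order.begin(), execution_order.end(), [](const Node &node) {
    // rfind(prefix, 0) == 0 is the idiomatic "starts_with" check before C++20.
    return node.scope_name.rfind("Gradient", 0) == 0;
  });
}

int main() {
  std::vector<Node> order = {{"Default/op1"}, {"Gradients/Default/op1"}};
  std::cout << std::boolalpha << IsBpropGraph(order) << '\n';  // true
  return 0;
}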
@@ -2798,7 +2797,7 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph, const device::Devi
   }
   SetGraphBpropAttr(graph);

-  if (!graph->is_bprop()) {
+  if (!graph->has_flag(kFlagIsPynativeBpropGraph)) {
     return;
   }

@@ -37,6 +37,7 @@
 #include "runtime/hardware/device_context_manager.h"
 #include "runtime/graph_scheduler/graph_compiler.h"
 #include "runtime/pynative/run_op_helper.h"
+#include "runtime/pynative/graph_adapter.h"
 #include "distributed/recovery/recovery_context.h"
 #include "include/common/utils/scoped_long_running.h"
 #ifdef ENABLE_D
@@ -440,6 +441,7 @@ MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string
   SetDebuggerInit();
 #endif
   runtime::GraphScheduler::GetInstance().Initialize();
+  pynative_run_in_graph_ = device_context->GetDeviceType() != device::DeviceType::kAscend;
 }

 void MindRTBackend::ProcessNotSupportCnode(const FuncGraphPtr &func_graph,
@@ -889,6 +891,74 @@ void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &param
   input_tensor->push_back(tensor_input);
 }

+void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph) {
+  bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
+                                distributed::recovery::RecoveryContext::GetInstance()->need_reset());
+  if (need_contruct_output) {
+    // Update device address for output node of graph.
+    // Summary processing will use the output device address, so this must run after summary processing.
+    actor_set->output_actor_->UpdateOutputDeviceAddress();
+
+    // Fetch outputs.
+    MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
+    auto &output_tensors = actor_set->output_actor_->outputs();
+    if (!output_tensors.empty()) {
+      size_t output_position = 0;
+      ConstructOutputs(root_graph->output(), output_tensors, &output_position, outputs);
+    }
+  }
+}
+
+void MindRTBackend::RunGraphIntergated(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
+                                       const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
+  WaitTaskFinish();
+  MS_EXCEPTION_IF_NULL(graph_compiler_);
+  auto graphs = graph_compiler_info.graphs_;
+  auto &op_executor = runtime::OpExecutor::GetInstance();
+  op_executor.Register([this]() { BatchBuildCallback(); });
+  for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
+    const auto &graph = graphs[graph_index];
+    MS_EXCEPTION_IF_NULL(graph);
+    // TODO(caifubi): Update parameter format for Ascend.
+    if (!graph->has_flag(kFlagGraphCompiled)) {
+      graph_compiler_->CompileGraphImpl(graph, graph_compiler_info.device_contexts_.front());
+      graph->CacheGraphOutputToFrontNodeWithIndex({graph->output()}, graph->front_outputs());
+      // Clear front outputs.
+      graph->set_front_outputs({});
+      // Transform graph to actor DAG, and schedule the actor DAG.
+      const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(graph_compiler_info);
+      runtime::GraphScheduler::GetInstance().Schedule(actor_set);
+      pynative::GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(graph);
+      pynative::GraphAdapter::GenerateRefCountForBpropValueNode(graph);
+      graph->set_flag(kFlagGraphCompiled, true);
+    }
+    pynative::GraphAdapter::UpdateForwardOutputInBpropGraph(graph);
+  }
+
+  // Run actor DAG.
+  mindspore::ScopedLongRunning long_running;
+  const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
+  MS_EXCEPTION_IF_NULL(actor_set);
+  runtime::GraphScheduler::GetInstance().Run(actor_set, graph_compiler_info.device_contexts_, inputs);
+
+  MS_EXCEPTION_IF_NULL(graph_compiler_);
+  graph_compiler_->Summary(graph_compiler_info.graphs_);
+
+  ConstructOutputs(actor_set, outputs, root_graph_);
+  // Clear bucket resources every step.
+  for (auto &graph : graphs) {
+    if (graph->has_flag(kFlagIsPynativeBpropGraph)) {
+      graph_compiler_->AddGradAddrToBucket(graph->graph_id(), actor_set->output_actor_->outputs());
+      graph_compiler_->ClearAllBucket(graph->graph_id());
+    }
+  }
+
+  runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
+  // Close abstract_lock for dynamic_shape.
+  AnfUtils::CloseAbstractLock();
+  MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
+}
+
 void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
                                        const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
   WaitTaskFinish();
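RunGraphIntergated compiles and schedules each kernel graph only on its first run, guarded by kFlagGraphCompiled, then reuses the actor DAG on later steps; only the forward-output refresh runs every step. A hedged sketch of that compile-once guard, with Graph, CompileAndSchedule, and UpdateForwardOutputs as hypothetical stand-ins:

#include <string>
#include <unordered_map>

// Stand-in graph type holding string-keyed boolean flags, mirroring
// FuncGraph::set_flag / has_flag in the diff.
class Graph {
 public:
  bool has_flag(const std::string &key) const {
    auto it = flags_.find(key);
    return it != flags_.end() && it->second;
  }
  void set_flag(const std::string &key, bool value) { flags_[key] = value; }

 private:
  std::unordered_map<std::string, bool> flags_;
};

constexpr const char *kFlagGraphCompiled = "graph_compiled";

// Hypothetical stand-ins for the compile-time and per-step work.
void CompileAndSchedule(Graph * /*graph*/) { /* compile, transform to actors, schedule */ }
void UpdateForwardOutputs(Graph * /*graph*/) { /* refresh forward-output device addresses */ }

void RunStep(Graph *graph) {
  if (!graph->has_flag(kFlagGraphCompiled)) {
    CompileAndSchedule(graph);  // first step only
    graph->set_flag(kFlagGraphCompiled, true);
  }
  UpdateForwardOutputs(graph);  // every step, compiled or not
}

int main() {
  Graph graph;
  RunStep(&graph);  // compiles, then runs
  RunStep(&graph);  // skips compilation, runs directly
  return 0;
}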
@@ -941,18 +1011,42 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
       graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);

       // Save grad node to Bucket
-      if (graph->is_bprop() && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) && !kernel->is_parallel()) {
+      if (graph->has_flag(kFlagIsPynativeBpropGraph) && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) &&
+          !kernel->is_parallel()) {
         graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
       }
     }
     WaitTaskFinish();
     // Clear bucket resources every step
-    if (graph->is_bprop()) {
+    if (graph->has_flag(kFlagIsPynativeBpropGraph)) {
       graph_compiler_->ClearAllBucket(graph->graph_id());
     }
   }
 }

+void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
+                                        const std::vector<std::vector<tensor::TensorPtr>> &input_tensors,
+                                        VectorRef *outputs) {
+  bool contain_cut_graph = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
+                                       [](const KernelGraphPtr &graph) { return graph->has_flag(kFlagsIsCutGraph); });
+  if (contain_cut_graph) {
+    // Python API will be called in cut_graph, so we cannot release gil here.
+    RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
+  } else {
+    // Release python gil.
+    mindspore::ScopedLongRunning long_running;
+    bool is_dynamic_shape = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
+                                        [](const KernelGraphPtr &graph) { return graph->is_dynamic_shape(); });
+    // TODO(caifubi): PyNative dynamic shape does not support running in graph yet.
+    if (pynative_run_in_graph_ && !is_dynamic_shape && !root_graph_->has_flag(kFlagIsDynamicStructure)) {
+      RunGraphIntergated(actor_info, graph_compiler_info, input_tensors, outputs);
+    } else {
+      RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
+    }
+  }
+  MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
+}
+
 void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(root_graph_);
   if (IsGraphOutputValueNodeOrParameter(root_graph_->output(), args, outputs)) {
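RunGraphByCondition concentrates the dispatch rule in one place: cut graphs must stay on the op-by-op path because they call back into Python, and (for now) dynamic shapes and dynamic structures do too. A standalone sketch of that decision, with GraphTraits as a stand-in aggregate rather than a MindSpore type:

#include <iostream>

struct GraphTraits {
  bool contains_cut_graph = false;
  bool is_dynamic_shape = false;
  bool is_dynamic_structure = false;
  bool backend_supports_whole_graph = true;  // plays the role of pynative_run_in_graph_ above
};

enum class RunMode { kWholeGraph, kSingleOp };

RunMode ChooseRunMode(const GraphTraits &t) {
  if (t.contains_cut_graph) {
    return RunMode::kSingleOp;  // Python API may be called; the GIL cannot be released
  }
  if (!t.backend_supports_whole_graph || t.is_dynamic_shape || t.is_dynamic_structure) {
    return RunMode::kSingleOp;  // whole-graph path not supported yet for these cases
  }
  return RunMode::kWholeGraph;
}

int main() {
  GraphTraits traits;
  std::cout << (ChooseRunMode(traits) == RunMode::kWholeGraph) << '\n';  // 1
  return 0;
}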
@@ -1013,16 +1107,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
   MS_EXCEPTION_IF_NULL(outputs);
   // There will be more than one kernel graph in heterogeneous scenario in a ms function of PyNative Mode.
   if (real_execution_mode_ == kPynativeMode) {
-    bool is_cut_graph = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
-                                    [](const KernelGraphPtr &graph) { return graph->has_flag(kFlagsIsCutGraph); });
-    if (is_cut_graph) {
-      RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
-    } else {
-      // Release python gil.
-      mindspore::ScopedLongRunning long_running;
-      RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
-    }
-    MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
+    RunGraphByCondition(actor_info, graph_compiler_info, input_tensors, outputs);
     return;
   }
@@ -1036,21 +1121,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
   MS_EXCEPTION_IF_NULL(graph_compiler_);
   graph_compiler_->Summary(graph_compiler_info.graphs_);

-  bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
-                                distributed::recovery::RecoveryContext::GetInstance()->need_reset());
-  if (need_contruct_output) {
-    // Update device address for output node of graph.
-    // Summary processing will use the output device address, so must be after the summary processing.
-    actor_set->output_actor_->UpdateOutputDeviceAddress();
-
-    // Fetch outputs.
-    MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
-    auto &output_tensors = actor_set->output_actor_->outputs();
-    if (output_tensors.size() > 0) {
-      size_t output_position = 0;
-      ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
-    }
-  }
+  ConstructOutputs(actor_set, outputs, root_graph_);

   runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
   // Close abstract_lock for dynamic_shape

@@ -143,6 +143,8 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
   // Get saved OpBuildTask in OpExecutor and build all the kernels together in PyNative mode.
   void CompileSingleOpGraphs(const std::vector<std::shared_ptr<runtime::OpBuildTask>> &build_tasks);

+  void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph);
+
   // Restore the outputs tuple by the origin funcGraph output node and output tensors.
   void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors,
                         size_t *output_position, VectorRef *outputs);
@@ -173,11 +175,16 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
   void DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs, GraphCompilerInfo *graph_compiler_info,
                       OpRunInfo *op_run_info);

+  void RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
+                           const std::vector<std::vector<tensor::TensorPtr>> &input_tensors, VectorRef *outputs);
   // Split complete kernel graph to single op graph in PyNative back
   // propagation, then compile and run single op graph.
   void RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
                           const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);

+  void RunGraphIntergated(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
+                          const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);
+
   void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs);

   void ReleaseForwardOutput(const std::vector<TensorPtr> &input_tensors);
@@ -212,6 +219,8 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
   void CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode = device::RunMode::kUnknown);
   void ProcessNotSupportCnode(const FuncGraphPtr &func_graph, const device::DeviceType &old_target,
                               const device::DeviceType &new_target);
+  // TODO(caifubi): Remove this flag when Ascend backend is ok.
+  bool pynative_run_in_graph_{false};
 };
 using MindRTBackendPtr = std::shared_ptr<compile::MindRTBackend>;
 }  // namespace compile

@@ -584,12 +584,15 @@ constexpr auto kActualAbstract = "actual_abstract";
 constexpr auto kAttrZeroInfinity = "zero_infinity";
 constexpr auto kAttrBlank = "blank";

+// FuncGraph Flags
+constexpr auto kFlagsIsCutGraph = "is_cut_graph";
+constexpr auto kFlagGraphCompiled = "graph_compiled";
+constexpr auto kFlagIsDynamicStructure = "is_dynamic_structure";
+constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph";
+
 // TODO(dsj): for ms_function running in graph_mode. should be delete later
 constexpr auto kAttrMSFunction = "ms_function_graph";

-// KernelGraph Flags
-constexpr auto kFlagsIsCutGraph = "is_cut_graph";
-
 // custom operator func type
 constexpr auto kCustomTypeAOT = "aot";
 constexpr auto kCustomTypeJULIA = "julia";

@@ -1016,9 +1016,10 @@ bool TaskEmitAction(const ResourcePtr &resource) {
   DisableMindRT(resource);
   auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
   auto is_parallel = (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
-  bool pynative_switch_to_graph_mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
-                                       (!func_graph->is_bprop() || func_graph->manager()->func_graphs().size() > 1) &&
-                                       !is_parallel;
+  bool pynative_switch_to_graph_mode =
+    context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
+    (!func_graph->has_flag(kFlagIsPynativeBpropGraph) || func_graph->manager()->func_graphs().size() > 1) &&
+    !is_parallel;
   SetRunMode(resource, pynative_switch_to_graph_mode);
   auto bc_ptr = resource->GetResult(kBackend).cast<compile::BackendPtr>();
   MS_EXCEPTION_IF_NULL(bc_ptr);

@@ -872,6 +872,8 @@ void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const OpExecIn
   std::vector<tensor::TensorPtr> total_output_tensors;
   TensorValueToTensor(added_out, &total_output_tensors);
   RunReplace(added_make_tuple, total_output_tensors, grad_graph);
+  std::for_each(total_output_tensors.begin(), total_output_tensors.end(),
+                [](tensor::TensorPtr &tensor) { tensor->set_is_forward_output(true); });
   top_cell->set_op_info_with_ms_func_forward_tensors(op_exec_info->op_info, total_output_tensors);
 }
@@ -926,6 +928,13 @@ void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<ten
   MS_EXCEPTION_IF_NULL(old_device_address);
   auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
   MS_EXCEPTION_IF_NULL(new_device_address);
+
+  // CPU host tensor data_c is different from device address if the address is from mem_pool.
+  if (new_device_address->from_mem_pool()) {
+    pre_tensor->set_device_address(new_device_address);
+    continue;
+  }
+
   auto old_ptr = old_device_address->GetMutablePtr();
   MS_EXCEPTION_IF_NULL(old_ptr);
   auto new_ptr = new_device_address->GetPtr();
@@ -3338,7 +3347,10 @@ void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &g
   // Get bprop graph of top cell
   auto bprop_graph = GetBpropGraph(grad, cell, w_args, p_args, size, args);
   MS_EXCEPTION_IF_NULL(bprop_graph);
-  bprop_graph->set_is_bprop(true);
+  if (top_cell()->is_dynamic_structure()) {
+    bprop_graph->set_flag(kFlagIsDynamicStructure, true);
+  }
+  bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
   resource->set_func_graph(bprop_graph);
   auto manager = resource->manager();
   MS_EXCEPTION_IF_NULL(manager);
@@ -3666,6 +3678,7 @@ void GradExecutor::CheckNeedCompileGraph() {
     pre_top_cell->Clear();
     already_run_top_cell_[already_top_cell_id] = new_top_cell;
     g_pyobj_id_cache.clear();
+    top_cell()->set_is_dynamic_structure(true);
   } else {
     MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again";
     pre_top_cell->set_input_args_id(new_top_cell->input_args_id());

@@ -104,6 +104,8 @@ class TopCellInfo {
   void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
   bool forward_already_run() const { return forward_already_run_; }
   void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
+  void set_is_dynamic_structure(bool is_dynamic_structure) { is_dynamic_structure_ = is_dynamic_structure; }
+  bool is_dynamic_structure() const { return is_dynamic_structure_; }
   pipeline::ResourcePtr resource() const { return resource_; }
   FuncGraphPtr df_builder() const { return df_builder_; }
   FuncGraphPtr fg() const { return fg_; }
@@ -150,6 +152,7 @@ class TopCellInfo {
   bool is_init_kpynative_{false};
   bool forward_already_run_{false};
   bool need_compile_graph_{false};
+  bool is_dynamic_structure_{false};
   size_t op_num_{0};
   size_t grad_order_{0};
   pipeline::ResourcePtr resource_{nullptr};

@@ -37,6 +37,22 @@ namespace device {
 namespace ascend {
 using AscendAutoMonad = mindspore::session::AscendAutoMonad;

+namespace {
+void RemoveUnusedValueNode(const KernelGraphPtr &graph) {
+  auto m = graph->manager();
+  auto node_users = m->node_users();
+  mindspore::HashSet<ValueNodePtr> unused_value_nodes;
+  for (auto &value_node : graph->graph_value_nodes()) {
+    if (node_users.find(value_node) == node_users.end()) {
+      unused_value_nodes.insert(value_node);
+    }
+  }
+  for (auto &value_node : unused_value_nodes) {
+    graph->RemoveNodeFromGraph(value_node);
+  }
+}
+}  // namespace
+
 void AscendGraphOptimization::Reset() {
   MS_LOG(INFO) << "Clear Ascend Graph Optimization Resource.";
   memo_.clear();
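RemoveUnusedValueNode deliberately gathers the unused value nodes into a separate set before erasing them, so the container being iterated is never mutated mid-loop. The same two-pass pattern in a minimal standalone form:

#include <iostream>
#include <set>
#include <vector>

int main() {
  std::set<int> nodes = {1, 2, 3, 4, 5};
  auto is_unused = [](int n) { return n % 2 == 0; };

  // Pass 1: collect candidates without touching `nodes`.
  std::vector<int> unused;
  for (int n : nodes) {
    if (is_unused(n)) {
      unused.push_back(n);
    }
  }
  // Pass 2: remove the collected entries.
  for (int n : unused) {
    nodes.erase(n);
  }
  std::cout << nodes.size() << '\n';  // 3
  return 0;
}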
@@ -60,7 +76,9 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
   OptimizeGraphWithDeviceInfo(graph);
   OptimizeExecutionOrder(graph);
   PostOptimization(graph);
   // must clear memo_ which holds kernel graph after using AscendGraphOptimization class.
+  RemoveUnusedValueNode(graph);
+
   memo_.clear();
   // clear and reset graph_manager_ after optimization
   graph_manager_ = MakeManager();

@@ -22,6 +22,7 @@
 #include <memory>
 #include <map>
+#include <utility>
 #include "ir/tensor.h"
 #include "ir/dtype.h"
 #include "ir/device_sync.h"
 #include "utils/shape_utils.h"
@@ -107,6 +108,10 @@ class DeviceAddress : public mindspore::DeviceSync {
   std::string device_name() const { return device_name_; }
   uint32_t device_id() const { return device_id_; }

+  void set_from_tensor(const std::weak_ptr<tensor::Tensor> &from_tensor) { from_tensors_.emplace_back(from_tensor); }
+  std::vector<std::weak_ptr<tensor::Tensor>> from_tensors() const { return from_tensors_; }
+  void clear_from_tensors() { from_tensors_.clear(); }
+
   virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; }
   KernelWithIndex GetNodeIndex() const {
     return node_index_.first.expired() ? KernelWithIndex{nullptr, node_index_.second}
@@ -167,6 +172,7 @@ class DeviceAddress : public mindspore::DeviceSync {
   ShapeVector host_shape_{};
   // {node, out_index}
   std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
+  std::vector<std::weak_ptr<tensor::Tensor>> from_tensors_;
   // The device address of the node that owns the device address cannot be updated and replaced.
   // Application scenario: set to true when the hardware execution mode requires that ptr cannot be changed during
   // execution.

@@ -226,7 +226,21 @@ void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext
   device_tensor->DecreaseRefCount();
   if (device_tensor->ref_count() == 0) {
     if (device_tensor->GetPtr() != nullptr) {
-      FreeMemory(device_tensor, device_context);
+      auto from_tensors = device_tensor->from_tensors();
+      if (from_tensors.empty()) {
+        FreeMemory(device_tensor, device_context);
+      } else {
+        std::for_each(from_tensors.begin(), from_tensors.end(), [](const std::weak_ptr<tensor::Tensor> &t) {
+          auto tensor = t.lock();
+          if (tensor != nullptr) {
+            tensor->set_device_address(nullptr);
+          }
+        });
+        device_tensor->clear_from_tensors();
+        // Reset device address status
+        device_tensor->set_original_ref_count(SIZE_MAX);
+        device_tensor->ResetRefCount();
+      }
     }
     device_tensor->ResetRefCount();
   }

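The new branch keeps device memory alive when it is still owned by forward-output tensors: instead of freeing, it detaches the tensors through the recorded weak_ptr back-references. A standalone illustration of that detach pattern, with Tensor and DeviceBuffer as stand-in types rather than MindSpore ones:

#include <algorithm>
#include <memory>
#include <vector>

struct DeviceBuffer;

struct Tensor {
  std::shared_ptr<DeviceBuffer> device_address;
};

struct DeviceBuffer {
  std::vector<std::weak_ptr<Tensor>> from_tensors;  // weak back-references to owning tensors
};

void ReleaseOrDetach(DeviceBuffer *buffer) {
  if (buffer->from_tensors.empty()) {
    // No owning tensors: the memory could actually be freed here.
    return;
  }
  // Owned by forward-output tensors: detach them instead of freeing.
  std::for_each(buffer->from_tensors.begin(), buffer->from_tensors.end(),
                [](const std::weak_ptr<Tensor> &weak) {
                  if (auto tensor = weak.lock()) {
                    tensor->device_address = nullptr;
                  }
                });
  buffer->from_tensors.clear();
}

int main() {
  auto buffer = std::make_shared<DeviceBuffer>();
  auto tensor = std::make_shared<Tensor>();
  tensor->device_address = buffer;
  buffer->from_tensors.push_back(tensor);
  ReleaseOrDetach(buffer.get());
  // The tensor is detached; the buffer itself was not freed.
  return tensor->device_address == nullptr ? 0 : 1;
}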
@@ -304,7 +304,8 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no
   // Assign tensor address to input data node and set `ref_count` to `SIZE_MAX` for avoiding clean
   AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
   tensor_address->SetNodeIndex(input_node, 0);
-  tensor_address->set_ref_count(SIZE_MAX);
+  tensor_address->set_original_ref_count(SIZE_MAX);
+  tensor_address->ResetRefCount();
 } else if (device_address->GetPtr() != nullptr) {
   // The `device_address` may come from another previous tensor. In order to prevent pollute the device data of
   // previous tensor, creating a new device address for holding current input tensor data.
@@ -313,7 +314,8 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no
   MS_EXCEPTION_IF_NULL(new_device_address);
   AnfAlgo::SetOutputAddr(new_device_address, 0, input_node.get());
   new_device_address->SetNodeIndex(input_node, 0);
-  new_device_address->set_ref_count(SIZE_MAX);
+  new_device_address->set_original_ref_count(SIZE_MAX);
+  new_device_address->ResetRefCount();
 }
 }

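Both hunks replace set_ref_count(SIZE_MAX) with set_original_ref_count(SIZE_MAX) followed by ResetRefCount(). A minimal sketch of why, assuming (as the uses in this diff suggest) that ResetRefCount() restores the dynamic count from the original count at the start of each step, so only the original count survives across steps:

#include <cstdint>
#include <iostream>

class RefCounted {
 public:
  void set_original_ref_count(size_t count) { original_ref_count_ = count; }
  void ResetRefCount() { ref_count_ = original_ref_count_; }  // assumed per-step reset semantics
  void DecreaseRefCount() {
    if (ref_count_ != SIZE_MAX) {  // SIZE_MAX marks "never free"
      --ref_count_;
    }
  }
  size_t ref_count() const { return ref_count_; }

 private:
  size_t original_ref_count_{0};
  size_t ref_count_{0};
};

int main() {
  RefCounted address;
  // Setting only the dynamic count would be overwritten by the next reset;
  // setting the original count makes the marker survive every reset.
  address.set_original_ref_count(SIZE_MAX);
  address.ResetRefCount();
  address.DecreaseRefCount();  // no-op: the address is never freed by ref counting
  std::cout << (address.ref_count() == SIZE_MAX) << '\n';  // 1
  return 0;
}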
@@ -161,12 +161,10 @@ void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, cons
   }
   auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
   if (output_address != nullptr && output_address->GetDeviceType() == device_context->GetDeviceType()) {
-    // The input of PyNative bprop graph is ValueNode.
-    // Setting the address to the ValueNode will lead to memory leak.
-    if (!graph->is_bprop()) {
-      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
-                             value_node.get());
-    }
+    // We need to set tensor->device_address to ValueNode even if the tensor is a forward_output tensor
+    // in PyNative Bprop graph. ValueNode device_address is necessary for GraphSchedule::Transform.
+    AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
+                           value_node.get());
     continue;
   }
@@ -369,18 +367,6 @@ void SetSummaryNodesRefCount(const KernelGraph *graph) {
   }
 }

-void UpdateRefCountForGraphOutput(const std::vector<KernelWithIndex> &output_with_index) {
-  for (const auto &item_with_index : output_with_index) {
-    if (!AnfAlgo::OutputAddrExist(item_with_index.first, item_with_index.second, false)) {
-      continue;
-    }
-    auto device_address = AnfAlgo::GetMutableOutputAddr(item_with_index.first, item_with_index.second, false);
-    MS_EXCEPTION_IF_NULL(device_address);
-    device_address->set_original_ref_count(SIZE_MAX);
-    device_address->ResetRefCount();
-  }
-}
-
 void SetGraphInputNodeActualAbstract(const session::OpRunInfo &op_run_info, const KernelGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   if (!op_run_info.output_is_dynamic_shape && !op_run_info.input_is_dynamic_shape) {
@@ -446,11 +432,13 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
     MS_EXCEPTION_IF_NULL(session_);
     // Graph kernel does not support pynative mode now, print a warning here.
     graphkernel::GraphKernelFlags::GetInstance().CheckSupport();
-    session_->InitAllBucket(graph, device_context);
     graph_id = graph->graph_id();
   } else {
     graph_id = CompileGraphImpl(graph, device_context);
   }
+  session_->InitAllBucket(graph, device_context);
+
+  graph->set_front_outputs(outputs);

   session_->DumpGraphs({graph});
@@ -559,7 +547,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
   graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));

   MS_EXCEPTION_IF_NULL(session_);
-  session_->InitAllBucket(graph, device_context);
   SetSummaryNodesRefCount(graph.get());
 #ifdef ENABLE_DUMP_IR
   // Dump .pb graph after graph optimization.
@@ -629,7 +616,6 @@ GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, bool
     (void)outputs_with_index.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(node, 0, false));
   }

-  UpdateRefCountForGraphOutput(outputs_with_index);
   AnfAlgo::UpdateGraphValidRefPair(graph);
   return graph->graph_id();
 }

@@ -183,13 +183,13 @@ class GraphCompiler {
   // Remove single op kernel graph cache and output nodes cache.
   void EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id);

-  private:
-  DISABLE_COPY_AND_ASSIGN(GraphCompiler);
-
   // The implementation of compiling graph in Graph Mode, including optimizing graph,
   // setting operator info, creating kernel and transforming kernel graph to ActorSet.
   GraphId CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const;

+ private:
+  DISABLE_COPY_AND_ASSIGN(GraphCompiler);
+
   // Add operators' output and input reference map to the graph.
   void AddOutInRefToGraph(const KernelGraphPtr &graph) const;

@@ -0,0 +1,200 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/pynative/graph_adapter.h"
+
+#include <string>
+#include <memory>
+#include <vector>
+#include "ir/tensor.h"
+#include "include/common/utils/convert_utils.h"
+#include "include/common/utils/anfalgo.h"
+#include "backend/common/session/anf_runtime_algorithm.h"
+#include "runtime/graph_scheduler/device_tensor_store.h"
+
+namespace mindspore::pynative {
+namespace {
+constexpr auto kAttrBpropValueNodeRefCount = "bprop_value_node_ref_count";
+constexpr auto kAttrValueNodeForwardOuputFlags = "value_node_forward_output_flags";
+
+tensor::TensorPtr GetTensorFromValueNode(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<ValueNode>()) {
+    return nullptr;
+  }
+  auto value_node = node->cast<ValueNodePtr>();
+  MS_EXCEPTION_IF_NULL(value_node);
+  auto value = value_node->value();
+  MS_EXCEPTION_IF_NULL(value);
+  // ValueTuple is already expanded into tensors in backend.
+  if (!value->isa<tensor::Tensor>()) {
+    MS_LOG(DEBUG) << "Only need to process forward output tensor. value:" << value->ToString();
+    return nullptr;
+  }
+
+  auto tensor = value->cast<tensor::TensorPtr>();
+  return tensor;
+}
+}  // namespace
+
+void GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  for (auto &value_node : graph->graph_value_nodes()) {
+    MS_EXCEPTION_IF_NULL(value_node);
+    auto value = value_node->value();
+    MS_EXCEPTION_IF_NULL(value);
+    if (value->isa<tensor::Tensor>()) {
+      auto tensor = value->cast<tensor::TensorPtr>();
+      MS_EXCEPTION_IF_NULL(tensor);
+      if (tensor->is_forward_output()) {
+        AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
+      }
+    }
+  }
+}
+
+// The device address of a graph value node needs to be released
+// if the value node is an output of forward_graph in PyNative mode.
+void GraphAdapter::GenerateRefCountForBpropValueNode(const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  HashMap<std::string, size_t> tensor_counts;
+  auto execution_nodes = graph->execution_order();
+  for (auto &node : execution_nodes) {
+    std::vector<session::KernelWithIndex> real_inputs;
+    common::AnfAlgo::GetRealInputs(node, &real_inputs);
+    for (auto &real_input : real_inputs) {
+      auto forward_output_tensor = GetTensorFromValueNode(real_input.first);
+      if (forward_output_tensor == nullptr || !forward_output_tensor->is_forward_output()) {
+        continue;
+      }
+      tensor_counts[forward_output_tensor->id()] += 1;
+    }
+  }
+
+  std::vector<size_t> value_node_ref_count;
+  std::vector<bool> value_node_forward_output_flags;
+  for (auto &value_node : graph->graph_value_nodes()) {
+    auto tensor = GetTensorFromValueNode(value_node);
+    if (tensor == nullptr || !tensor->is_forward_output()) {
+      value_node_ref_count.emplace_back(SIZE_MAX);
+      value_node_forward_output_flags.emplace_back(false);
+      continue;
+    }
+    auto iter = tensor_counts.find(tensor->id());
+    if (iter == tensor_counts.end()) {
+      // The tensor is in bp graph but not used.
+      // e.g. %1-MakeTuple(T1, T2) -> TupleGetItem(%1, 0). T2 is not used.
+      MS_LOG(DEBUG) << "Tensor " << tensor->ToString() << " is not found in value node";
+      value_node_ref_count.emplace_back(SIZE_MAX);
+      value_node_forward_output_flags.emplace_back(false);
+      continue;
+    }
+
+    value_node_ref_count.emplace_back(iter->second);
+    value_node_forward_output_flags.emplace_back(true);
+  }
+  graph->set_attr(kAttrBpropValueNodeRefCount, MakeValue(value_node_ref_count));
+  graph->set_attr(kAttrValueNodeForwardOuputFlags, MakeValue(value_node_forward_output_flags));
+}
+
+void GraphAdapter::UpdateForwardOutputInBpropGraph(const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_LOG(DEBUG) << "Update start";
+  auto value_node_ref_counts = GetValue<std::vector<size_t>>(graph->get_attr(kAttrBpropValueNodeRefCount));
+  auto value_node_forward_output_flags = GetValue<std::vector<bool>>(graph->get_attr(kAttrValueNodeForwardOuputFlags));
+  size_t value_node_size = graph->graph_value_nodes().size();
+  if (value_node_ref_counts.size() != value_node_size || value_node_forward_output_flags.size() != value_node_size) {
+    MS_LOG(EXCEPTION) << "value_node_ref_count.size " << value_node_ref_counts.size()
+                      << " value_node_forward_output_flags.size " << value_node_forward_output_flags.size()
+                      << " not equal to " << value_node_size;
+  }
+
+  size_t value_node_index = 0;
+  HashMap<device::DeviceAddressPtr, size_t> address_ref_count;
+  // Update ValueNode device address
+  for (auto &value_node : graph->graph_value_nodes()) {
+    auto is_forward_output = value_node_forward_output_flags[value_node_index];
+    if (!is_forward_output) {
+      value_node_index++;
+      continue;
+    }
+    size_t value_node_ref_count = value_node_ref_counts[value_node_index++];
+    auto tensor = GetTensorFromValueNode(value_node);
+    MS_EXCEPTION_IF_NULL(tensor);
+    auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
+    if (device_address == nullptr) {
+      MS_LOG(WARNING) << "Forward output " << tensor->ToString() << " device address is null";
+      continue;
+    }
+
+    if (device_address->GetDeviceType() != device::DeviceType::kCPU) {
+      address_ref_count[device_address] += value_node_ref_count;
+      device_address->set_from_tensor(tensor);
+    }
+
+    auto front_node = AnfAlgo::FetchFrontNodeByBackendNode(value_node, *graph);
+    runtime::DeviceTensorStore::GetInstance().Insert(front_node.get(), device_address);
+  }
+
+  for (auto &[address, ref_count] : address_ref_count) {
+    address->set_original_ref_count(ref_count);
+    address->ResetRefCount();
+    MS_LOG(DEBUG) << "device_address " << address.get() << " ref_count " << address->ref_count();
+  }
+  MS_LOG(DEBUG) << "Update end";
+}
+
+bool GraphAdapter::ReplaceBpropGraphParameter(const KernelGraphPtr &graph,
+                                              const std::vector<tensor::TensorPtr> &input_tensors) {
+  size_t index = 0;
+  bool changed = false;
+  for (const auto &input_node : graph->input_nodes()) {
+    auto params = common::AnfAlgo::GetAllOutput(input_node);
+    for (const auto &param : params) {
+      if (index >= input_tensors.size()) {
+        MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
+                          << ", input size: " << input_tensors.size();
+      }
+      const auto &input_tensor = input_tensors[index++];
+      MS_EXCEPTION_IF_NULL(input_tensor);
+      const auto &tensor_address = input_tensor->device_address();
+      auto address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address);
+      if (address != nullptr) {
+        auto tensor_format = address->format();
+        auto param_format = AnfAlgo::GetOutputFormat(param, 0);
+        if (tensor_format != param_format) {
+          // Update parameter format
+          auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+          MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
+          kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{address->format()});
+          kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{address->type_id()});
+          kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
+          AnfAlgo::SetOutputAddr(address, 0, param.get());
+          AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
+
+          // Update abstract
+          auto type_of_tensor = input_tensor->Dtype();
+          auto shape_of_tensor = input_tensor->shape();
+          auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
+          param->set_abstract(abstract);
+          changed = true;
+        }
+      }
+    }
+  }
+  return changed;
+}
+}  // namespace mindspore::pynative
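GenerateRefCountForBpropValueNode derives each forward-output value node's ref count by counting how many real inputs across the execution order consume the same tensor id. The counting idea in a standalone sketch (TensorId and NodeInputs are stand-ins for the MindSpore types):

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

using TensorId = std::string;
using NodeInputs = std::vector<TensorId>;  // forward-output tensor ids one node consumes

std::unordered_map<TensorId, size_t> CountTensorUses(const std::vector<NodeInputs> &execution_order) {
  std::unordered_map<TensorId, size_t> counts;
  for (const auto &inputs : execution_order) {
    for (const auto &id : inputs) {
      counts[id] += 1;  // one more consumer found for this tensor
    }
  }
  return counts;
}

int main() {
  auto counts = CountTensorUses({{"t1"}, {"t1", "t2"}});
  std::cout << counts["t1"] << ' ' << counts["t2"] << '\n';  // 2 1
  return 0;
}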
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_
+#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_
+
+#include <vector>
+#include "backend/common/session/kernel_graph.h"
+
+namespace mindspore::pynative {
+class GraphAdapter {
+ public:
+  static void UpdateForwardOutputInBpropGraph(const KernelGraphPtr &graph);
+  static bool ReplaceBpropGraphParameter(const KernelGraphPtr &graph,
+                                         const std::vector<tensor::TensorPtr> &input_tensors);
+  static void GenerateRefCountForBpropValueNode(const KernelGraphPtr &graph);
+  static void ClearForwardOutputValueNodeDeviceAddress(const KernelGraphPtr &graph);
+};
+}  // namespace mindspore::pynative
+#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_
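Per the backend changes above, ClearForwardOutputValueNodeDeviceAddress and GenerateRefCountForBpropValueNode are invoked once when a graph is first compiled in RunGraphIntergated, while UpdateForwardOutputInBpropGraph runs before every execution of the bprop graph. ReplaceBpropGraphParameter is declared and defined but not yet called in this change; it appears intended for the Ascend parameter-format update flagged by the TODO in RunGraphIntergated.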
@@ -105,7 +105,6 @@ void UpdateInputNodeDeviceAddress(const std::vector<AnfNodePtr> &input_nodes,
   input_tensor->set_lazy_callback([]() { runtime::OpExecutor::GetInstance().Wait(); });
   node_address->set_from_persistent_mem(input_tensor->is_parameter());
-  node_address->SetNodeIndex(input_node, 0);
   UpdateRefCount(node_address.get(), true);
 }

 // The DeviceType and format of DeviceAddress is always the same after UpdateInputTensor
@@ -174,7 +173,6 @@ void CopyValueNodeTensorToDevice(const ValueNodePtr &node, const device::DeviceC
     return;
   }
   tensor->set_device_address(node_address);
-  UpdateRefCount(node_address.get(), true);
   CopyTensorDataToDevice(tensor, node, device_context);
 }
 }

@@ -43,7 +43,6 @@ FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info)
     kw_only_args_count_(0),
     hyper_param_count_(0),
     is_generated_(false),
-    is_bprop_(false),
     return_(nullptr),
     manager_(),
     debug_info_(std::move(debug_info)),
@@ -142,7 +141,7 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
   return p;
 }

-bool FuncGraph::has_flag(const std::string &key) {
+bool FuncGraph::has_flag(const std::string &key) const {
   auto iter = attrs_.find(key);
   if (iter != attrs_.cend()) {
     MS_EXCEPTION_IF_NULL(iter->second);

@@ -184,8 +184,6 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
   FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list);
   void set_is_generate(bool generated) { is_generated_ = generated; }
   bool is_generated() const { return is_generated_; }
-  void set_is_bprop(bool is_brop) { is_bprop_ = is_brop; }
-  bool is_bprop() const { return is_bprop_; }

   mindspore::HashMap<std::string, ValuePtr> &attrs() { return attrs_; }
   void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
@@ -193,7 +191,7 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
       attrs_[attr.first] = attr.second;
     }
   }
-  bool has_flag(const std::string &key);
+  bool has_flag(const std::string &key) const;
   void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
   void erase_flag(const std::string &key) { (void)attrs_.erase(key); }
@@ -425,9 +423,6 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
   size_t hyper_param_count_;
   // Argument input list for the graph used to generate this graph.
   bool is_generated_;
-
-  bool is_bprop_;
-
   // CNode that calls 'return' primitive.
   // We use shared pointer to manage it.
   CNodePtr return_;

@@ -381,6 +381,6 @@ def test_train_feed(num_classes=65536):
     model = Model(net, loss_fn=loss, optimizer=opt)
     model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
     loss_value = np.array(parallel_callback.loss_list)
-    expect_out = [11.087254, 10.876551, 10.045684]
+    expect_out = [11.087254, 10.876551, 10.142526]
     print(loss_value)
     assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)