From d490450270b197032d68b20d68e548389cf0c162 Mon Sep 17 00:00:00 2001 From: caifubi Date: Mon, 23 May 2022 15:54:23 +0800 Subject: [PATCH] PyNative Bprop run in Graph. 1. A dynamic-structure bprop graph needs to be split and run op by op. 2. A static bprop graph can run as an integrated graph. (Only supported for CPU/GPU and non-dynamic-shape graphs) --- .../backend/common/session/kernel_graph.h | 6 + .../backend/common/session/session_basic.cc | 13 +- .../ccsrc/backend/graph_compiler/backend.cc | 125 ++++++++--- .../ccsrc/backend/graph_compiler/backend.h | 9 + mindspore/ccsrc/include/common/utils/utils.h | 9 +- mindspore/ccsrc/pipeline/jit/action.cc | 7 +- .../pipeline/pynative/pynative_execute.cc | 15 +- .../pipeline/pynative/pynative_execute.h | 3 + .../hal/hardware/ascend_graph_optimization.cc | 20 +- .../ccsrc/runtime/device/device_address.h | 6 + .../graph_scheduler/actor/actor_common.cc | 16 +- .../actor/data_prepare_actor.cc | 6 +- .../runtime/graph_scheduler/graph_compiler.cc | 28 +-- .../runtime/graph_scheduler/graph_compiler.h | 6 +- .../ccsrc/runtime/pynative/graph_adapter.cc | 200 ++++++++++++++++++ .../ccsrc/runtime/pynative/graph_adapter.h | 33 +++ .../ccsrc/runtime/pynative/run_op_helper.cc | 2 - mindspore/core/ir/func_graph.cc | 3 +- mindspore/core/ir/func_graph.h | 7 +- tests/st/auto_parallel/cell_shard.py | 2 +- 20 files changed, 436 insertions(+), 80 deletions(-) create mode 100644 mindspore/ccsrc/runtime/pynative/graph_adapter.cc create mode 100644 mindspore/ccsrc/runtime/pynative/graph_adapter.h diff --git a/mindspore/ccsrc/backend/common/session/kernel_graph.h b/mindspore/ccsrc/backend/common/session/kernel_graph.h index a0e3d28382d..ca8c39eaec8 100644 --- a/mindspore/ccsrc/backend/common/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/common/session/kernel_graph.h @@ -115,6 +115,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph { first_step_ = graph.first_step_; has_optimizer_ = graph.has_optimizer_; is_dynamic_shape_ = graph.is_dynamic_shape_; + front_outputs_ = graph.front_outputs_; } ~KernelGraph() override; @@ -445,6 +446,9 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph { return iter->second; } + AnfNodePtrList front_outputs() const { return front_outputs_; } + void set_front_outputs(const AnfNodePtrList &outputs) { front_outputs_ = outputs; } + private: // remove value node form graph bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); @@ -494,6 +498,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph { std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_; // parameters that will be updated when graph is executed mindspore::HashSet<ParameterPtr> updated_parameters_; + // graph needn't execute bool executable_{false}; // exist summary node in graph @@ -515,6 +520,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph { CNodePtr start_label_; CNodePtr end_goto_; + AnfNodePtrList front_outputs_; // Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is // related to the input of this kernel graph. The first of unordered map is the input of this kernel graph, the second // of unordered map is front node corresponding to the output of previous kernel graph.
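Throughout this patch, the dedicated is_bprop_ member on FuncGraph is retired in favor of a generic attribute flag, so KernelGraph (which inherits from FuncGraph) needs no parallel bookkeeping. A minimal sketch of the resulting API, assuming the flag constants sit in the mindspore namespace as declared in utils.h later in this patch; the two wrapper functions are hypothetical:

#include "ir/func_graph.h"
#include "include/common/utils/utils.h"

namespace mindspore {
void MarkBpropGraph(const FuncGraphPtr &graph) {
  // set_flag stores MakeValue(true) into the graph's attrs_ map under the key.
  graph->set_flag(kFlagIsPynativeBpropGraph, true);
}

bool IsBpropGraph(const FuncGraphPtr &graph) {
  // has_flag returns false when the key is absent, which is why the explicit
  // set_is_bprop(false) branch in SetGraphBpropAttr can simply be dropped.
  return graph->has_flag(kFlagIsPynativeBpropGraph);
}
}  // namespace mindspore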
diff --git a/mindspore/ccsrc/backend/common/session/session_basic.cc b/mindspore/ccsrc/backend/common/session/session_basic.cc index 888ff97cc7e..d79894f7dce 100644 --- a/mindspore/ccsrc/backend/common/session/session_basic.cc +++ b/mindspore/ccsrc/backend/common/session/session_basic.cc @@ -1419,7 +1419,8 @@ void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> MS_EXCEPTION_IF_NULL(parallel_context); auto parallel_mode = parallel_context->parallel_mode(); bool is_parallel_forward_ms_function = - !graph->is_bprop() && (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel); + !graph->has_flag(kFlagIsPynativeBpropGraph) && + (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel); for (const auto &input_node : graph->input_nodes()) { auto params = common::AnfAlgo::GetAllOutput(input_node); for (const auto &param : params) { @@ -2617,12 +2618,12 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map); HandleOpOutputs(kernel, op_outputs, cnode_refcount, &op_output_map, &graph_output_info); // Save grad node to Bucket - if (kernel_graph->is_bprop()) { + if (kernel_graph->has_flag(kFlagIsPynativeBpropGraph)) { AddGradAddrToBucket(graph_id, graph_output_info.graph_output_tensors); } } // Clear bucket resources every step - if (kernel_graph->is_bprop()) { + if (kernel_graph->has_flag(kFlagIsPynativeBpropGraph)) { ClearAllBucket(graph_id); } @@ -2699,10 +2700,8 @@ void SetGraphBpropAttr(const KernelGraphPtr &graph) { auto &execution_orders = graph->execution_order(); if (std::any_of(execution_orders.begin(), execution_orders.end(), [](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) { - graph->set_is_bprop(true); + graph->set_flag(kFlagIsPynativeBpropGraph, true); MS_LOG(INFO) << "Match bprop graph"; - } else { - graph->set_is_bprop(false); } } @@ -2798,7 +2797,7 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph, const device::Devi } SetGraphBpropAttr(graph); - if (!graph->is_bprop()) { + if (!graph->has_flag(kFlagIsPynativeBpropGraph)) { return; } diff --git a/mindspore/ccsrc/backend/graph_compiler/backend.cc b/mindspore/ccsrc/backend/graph_compiler/backend.cc index df354ce05e8..e102f528db1 100644 --- a/mindspore/ccsrc/backend/graph_compiler/backend.cc +++ b/mindspore/ccsrc/backend/graph_compiler/backend.cc @@ -37,6 +37,7 @@ #include "runtime/hardware/device_context_manager.h" #include "runtime/graph_scheduler/graph_compiler.h" #include "runtime/pynative/run_op_helper.h" +#include "runtime/pynative/graph_adapter.h" #include "distributed/recovery/recovery_context.h" #include "include/common/utils/scoped_long_running.h" #ifdef ENABLE_D @@ -440,6 +441,7 @@ MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string SetDebuggerInit(); #endif runtime::GraphScheduler::GetInstance().Initialize(); + pynative_run_in_graph_ = device_context->GetDeviceType() != device::DeviceType::kAscend; } void MindRTBackend::ProcessNotSupportCnode(const FuncGraphPtr &func_graph, @@ -889,6 +891,74 @@ void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &param input_tensor->push_back(tensor_input); } +void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph) { + bool need_construct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
distributed::recovery::RecoveryContext::GetInstance()->need_reset()); + if (need_construct_output) { + // Update device address for output node of graph. + // Summary processing will use the output device address, so this must run after the summary processing. + actor_set->output_actor_->UpdateOutputDeviceAddress(); + + // Fetch outputs. + MS_EXCEPTION_IF_NULL(actor_set->output_actor_); + auto &output_tensors = actor_set->output_actor_->outputs(); + if (!output_tensors.empty()) { + size_t output_position = 0; + ConstructOutputs(root_graph->output(), output_tensors, &output_position, outputs); + } + } +} + +void MindRTBackend::RunGraphIntegrated(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info, + const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) { + WaitTaskFinish(); + MS_EXCEPTION_IF_NULL(graph_compiler_); + auto graphs = graph_compiler_info.graphs_; + auto &op_executor = runtime::OpExecutor::GetInstance(); + op_executor.Register([this]() { BatchBuildCallback(); }); + for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) { + const auto &graph = graphs[graph_index]; + MS_EXCEPTION_IF_NULL(graph); + // TODO(caifubi): Update parameter format for Ascend + if (!graph->has_flag(kFlagGraphCompiled)) { + graph_compiler_->CompileGraphImpl(graph, graph_compiler_info.device_contexts_.front()); + graph->CacheGraphOutputToFrontNodeWithIndex({graph->output()}, graph->front_outputs()); + // Clear front outputs + graph->set_front_outputs({}); + // Transform graph to actor DAG, and schedule the actor DAG. + const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(graph_compiler_info); + runtime::GraphScheduler::GetInstance().Schedule(actor_set); + pynative::GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(graph); + pynative::GraphAdapter::GenerateRefCountForBpropValueNode(graph); + graph->set_flag(kFlagGraphCompiled, true); + } + pynative::GraphAdapter::UpdateForwardOutputInBpropGraph(graph); + } + + // Run actor DAG.
+ mindspore::ScopedLongRunning long_running; + const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info); + MS_EXCEPTION_IF_NULL(actor_set); + runtime::GraphScheduler::GetInstance().Run(actor_set, graph_compiler_info.device_contexts_, inputs); + + MS_EXCEPTION_IF_NULL(graph_compiler_); + graph_compiler_->Summary(graph_compiler_info.graphs_); + + ConstructOutputs(actor_set, outputs, root_graph_); + // Clear bucket resources every step + for (auto &graph : graphs) { + if (graph->has_flag(kFlagIsPynativeBpropGraph)) { + graph_compiler_->AddGradAddrToBucket(graph->graph_id(), actor_set->output_actor_->outputs()); + graph_compiler_->ClearAllBucket(graph->graph_id()); + } + } + + runtime::GraphScheduler::GetInstance().ClearActorData(actor_set); + // Close abstract_lock for dynamic_shape + AnfUtils::CloseAbstractLock(); + MS_LOG(INFO) << "Status record: end run actor: " << actor_info; +} + void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs, const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) { WaitTaskFinish(); @@ -941,18 +1011,42 @@ graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info); // Save grad node to Bucket - if (graph->is_bprop() && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) && !kernel->is_parallel()) { + if (graph->has_flag(kFlagIsPynativeBpropGraph) && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) && + !kernel->is_parallel()) { graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors); } } WaitTaskFinish(); // Clear bucket resources every step - if (graph->is_bprop()) { + if (graph->has_flag(kFlagIsPynativeBpropGraph)) { graph_compiler_->ClearAllBucket(graph->graph_id()); } } } +void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info, + const std::vector<std::vector<tensor::TensorPtr>> &input_tensors, + VectorRef *outputs) { + bool contain_cut_graph = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(), + [](const KernelGraphPtr &graph) { return graph->has_flag(kFlagsIsCutGraph); }); + if (contain_cut_graph) { + // Python APIs will be called in a cut graph, so we cannot release the GIL here. + RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs); + } else { + // Release the Python GIL. + mindspore::ScopedLongRunning long_running; + bool is_dynamic_shape = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(), + [](const KernelGraphPtr &graph) { return graph->is_dynamic_shape(); }); + // TODO(caifubi): PyNative dynamic shape does not support running in graph yet. + if (pynative_run_in_graph_ && !is_dynamic_shape && !root_graph_->has_flag(kFlagIsDynamicStructure)) { + RunGraphIntegrated(actor_info, graph_compiler_info, input_tensors, outputs); + } else { + RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs); + } + } + MS_LOG(INFO) << "Status record: end run actor: " << actor_info; +} + void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs) { MS_EXCEPTION_IF_NULL(root_graph_); if (IsGraphOutputValueNodeOrParameter(root_graph_->output(), args, outputs)) { @@ -1013,16 +1107,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, MS_EXCEPTION_IF_NULL(outputs); // There will be more than one kernel graph in heterogeneous scenario in a ms function of PyNative Mode.
if (real_execution_mode_ == kPynativeMode) { - bool is_cut_graph = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(), - [](const KernelGraphPtr &graph) { return graph->has_flag(kFlagsIsCutGraph); }); - if (is_cut_graph) { - RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs); - } else { - // Release python gil. - mindspore::ScopedLongRunning long_running; - RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs); - } - MS_LOG(INFO) << "Status record: end run actor: " << actor_info; + RunGraphByCondition(actor_info, graph_compiler_info, input_tensors, outputs); return; } @@ -1036,21 +1121,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, MS_EXCEPTION_IF_NULL(graph_compiler_); graph_compiler_->Summary(graph_compiler_info.graphs_); - bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() && - distributed::recovery::RecoveryContext::GetInstance()->need_reset()); - if (need_contruct_output) { - // Update device address for output node of graph. - // Summary processing will use the output device address, so must be after the summary processing. - actor_set->output_actor_->UpdateOutputDeviceAddress(); - - // Fetch outputs. - MS_EXCEPTION_IF_NULL(actor_set->output_actor_); - auto &output_tensors = actor_set->output_actor_->outputs(); - if (output_tensors.size() > 0) { - size_t output_position = 0; - ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs); - } - } + ConstructOutputs(actor_set, outputs, root_graph_); runtime::GraphScheduler::GetInstance().ClearActorData(actor_set); // Close abstract_lock for dynamic_shape diff --git a/mindspore/ccsrc/backend/graph_compiler/backend.h b/mindspore/ccsrc/backend/graph_compiler/backend.h index 999669d647f..5d9670013ef 100644 --- a/mindspore/ccsrc/backend/graph_compiler/backend.h +++ b/mindspore/ccsrc/backend/graph_compiler/backend.h @@ -143,6 +143,8 @@ class BACKEND_EXPORT MindRTBackend : public Backend { // Get saved OpBuildTask in OpExecutor and build all the kernels together in PyNative mode. void CompileSingleOpGraphs(const std::vector<std::shared_ptr<runtime::OpBuildTask>> &build_tasks); + void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph); + // Restore the outputs tuple by the origin funcGraph output node and output tensors. void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors, size_t *output_position, VectorRef *outputs); @@ -173,11 +175,16 @@ class BACKEND_EXPORT MindRTBackend : public Backend { void DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs, GraphCompilerInfo *graph_compiler_info, OpRunInfo *op_run_info); + void RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info, + const std::vector<std::vector<tensor::TensorPtr>> &input_tensors, VectorRef *outputs); // Split complete kernel graph to single op graph in PyNative back // propagation, then compile and run single op graph.
void RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs, const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs); + void RunGraphIntegrated(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info, + const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs); + void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs); void ReleaseForwardOutput(const std::vector<tensor::TensorPtr> &input_tensors); @@ -212,6 +219,8 @@ class BACKEND_EXPORT MindRTBackend : public Backend { void CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode = device::RunMode::kUnknown); void ProcessNotSupportCnode(const FuncGraphPtr &func_graph, const device::DeviceType &old_target, const device::DeviceType &new_target); + // TODO(caifubi): Remove this flag when the Ascend backend is ready. + bool pynative_run_in_graph_{false}; }; using MindRTBackendPtr = std::shared_ptr<MindRTBackend>; } // namespace compile diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h index 08eccdb56e4..b9393880104 100644 --- a/mindspore/ccsrc/include/common/utils/utils.h +++ b/mindspore/ccsrc/include/common/utils/utils.h @@ -583,12 +583,15 @@ constexpr auto kActualAbstract = "actual_abstract"; constexpr auto kAttrZeroInfinity = "zero_infinity"; constexpr auto kAttrBlank = "blank"; +// FuncGraph Flags +constexpr auto kFlagsIsCutGraph = "is_cut_graph"; +constexpr auto kFlagGraphCompiled = "graph_compiled"; +constexpr auto kFlagIsDynamicStructure = "is_dynamic_structure"; +constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph"; + // TODO(dsj): for ms_function running in graph_mode. should be delete later constexpr auto kAttrMSFunction = "ms_function_graph"; -// KernelGraph Flags -constexpr auto kFlagsIsCutGraph = "is_cut_graph"; - // custom operator func type constexpr auto kCustomTypeAOT = "aot"; constexpr auto kCustomTypeJULIA = "julia"; diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 2178e5d24c1..bcf55c8e434 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -1033,9 +1033,10 @@ bool TaskEmitAction(const ResourcePtr &resource) { DisableMindRT(resource); auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode(); auto is_parallel = (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel); - bool pynative_switch_to_graph_mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode && - (!func_graph->is_bprop() || func_graph->manager()->func_graphs().size() > 1) && - !is_parallel; + bool pynative_switch_to_graph_mode = + context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode && + (!func_graph->has_flag(kFlagIsPynativeBpropGraph) || func_graph->manager()->func_graphs().size() > 1) && + !is_parallel; SetRunMode(resource, pynative_switch_to_graph_mode); auto bc_ptr = resource->GetResult(kBackend).cast<compile::BackendPtr>(); MS_EXCEPTION_IF_NULL(bc_ptr); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 1afb68b5f7e..1e5e714faee 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -872,6 +872,8 @@ void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const OpExecIn std::vector<tensor::TensorPtr> total_output_tensors; TensorValueToTensor(added_out, &total_output_tensors); RunReplace(added_make_tuple, total_output_tensors, grad_graph); + std::for_each(total_output_tensors.begin(),
total_output_tensors.end(), + [](tensor::TensorPtr &tensor) { tensor->set_is_forward_output(true); }); top_cell->set_op_info_with_ms_func_forward_tensors(op_exec_info->op_info, total_output_tensors); } @@ -926,6 +928,13 @@ void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<tensor::TensorPtr> &pre_tensors) auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address()); MS_EXCEPTION_IF_NULL(new_device_address); + + // The CPU host tensor's data_c differs from the device address if the address comes from the mem_pool. + if (new_device_address->from_mem_pool()) { + pre_tensor->set_device_address(new_device_address); + continue; + } + auto old_ptr = old_device_address->GetMutablePtr(); MS_EXCEPTION_IF_NULL(old_ptr); auto new_ptr = new_device_address->GetPtr(); @@ -3338,7 +3347,10 @@ void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &g // Get bprop graph of top cell auto bprop_graph = GetBpropGraph(grad, cell, w_args, p_args, size, args); MS_EXCEPTION_IF_NULL(bprop_graph); - bprop_graph->set_is_bprop(true); + if (top_cell()->is_dynamic_structure()) { + bprop_graph->set_flag(kFlagIsDynamicStructure, true); + } + bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true); resource->set_func_graph(bprop_graph); auto manager = resource->manager(); MS_EXCEPTION_IF_NULL(manager); @@ -3666,6 +3678,7 @@ void GradExecutor::CheckNeedCompileGraph() { pre_top_cell->Clear(); already_run_top_cell_[already_top_cell_id] = new_top_cell; g_pyobj_id_cache.clear(); + top_cell()->set_is_dynamic_structure(true); } else { MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again"; pre_top_cell->set_input_args_id(new_top_cell->input_args_id()); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index bb3cda3741a..1c0cbe3a23b 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -104,6 +104,8 @@ class TopCellInfo { void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; } bool forward_already_run() const { return forward_already_run_; } void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; } + void set_is_dynamic_structure(bool is_dynamic_structure) { is_dynamic_structure_ = is_dynamic_structure; } + bool is_dynamic_structure() const { return is_dynamic_structure_; } pipeline::ResourcePtr resource() const { return resource_; } FuncGraphPtr df_builder() const { return df_builder_; } FuncGraphPtr fg() const { return fg_; } @@ -150,6 +152,7 @@ class TopCellInfo { bool is_init_kpynative_{false}; bool forward_already_run_{false}; bool need_compile_graph_{false}; + bool is_dynamic_structure_{false}; size_t op_num_{0}; size_t grad_order_{0}; pipeline::ResourcePtr resource_{nullptr}; diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_optimization.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_optimization.cc index 9b5995759fd..c807cea5bed 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_optimization.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_optimization.cc @@ -37,6 +37,22 @@ namespace device { namespace ascend { using AscendAutoMonad = mindspore::session::AscendAutoMonad; +namespace { +void RemoveUnusedValueNode(const KernelGraphPtr &graph) { + auto m = graph->manager(); + auto node_users = m->node_users(); + mindspore::HashSet<ValueNodePtr> unused_value_nodes; + for (auto &value_node :
graph->graph_value_nodes()) { + if (node_users.find(value_node) == node_users.end()) { + unused_value_nodes.insert(value_node); + } + } + for (auto &value_node : unused_value_nodes) { + graph->RemoveNodeFromGraph(value_node); + } +} +} // namespace + void AscendGraphOptimization::Reset() { MS_LOG(INFO) << "Clear Ascend Graph Optimization Resource."; memo_.clear(); @@ -60,7 +76,9 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) { OptimizeGraphWithDeviceInfo(graph); OptimizeExecutionOrder(graph); PostOptimization(graph); - // must clear memo_ which holds kernel graph after using AscendGraphOptimization class. + + RemoveUnusedValueNode(graph); + memo_.clear(); // clear and reset graph_manager_ after optimization graph_manager_ = MakeManager(); diff --git a/mindspore/ccsrc/runtime/device/device_address.h b/mindspore/ccsrc/runtime/device/device_address.h index 5c9f1235006..866a0799ad7 100644 --- a/mindspore/ccsrc/runtime/device/device_address.h +++ b/mindspore/ccsrc/runtime/device/device_address.h @@ -22,6 +22,7 @@ #include <string> #include <vector> #include <memory> +#include "ir/tensor.h" #include "ir/dtype.h" #include "ir/device_sync.h" #include "utils/shape_utils.h" @@ -107,6 +108,10 @@ class DeviceAddress : public mindspore::DeviceSync { std::string device_name() const { return device_name_; } uint32_t device_id() const { return device_id_; } + void set_from_tensor(const std::weak_ptr<tensor::Tensor> &from_tensor) { from_tensors_.emplace_back(from_tensor); } + std::vector<std::weak_ptr<tensor::Tensor>> from_tensors() const { return from_tensors_; } + void clear_from_tensors() { from_tensors_.clear(); } + virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; } KernelWithIndex GetNodeIndex() const { return node_index_.first.expired() ? KernelWithIndex{nullptr, node_index_.second} @@ -167,6 +172,7 @@ class DeviceAddress : public mindspore::DeviceSync { ShapeVector host_shape_{}; // {node, out_index} std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0}; + std::vector<std::weak_ptr<tensor::Tensor>> from_tensors_; // The device address of the node that owns the device address cannot be updated and replaced. // Application scenario: set to true when the hardware execution mode requires that ptr cannot be changed during // execution.
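The from_tensors_ bookkeeping added above pairs with the FreeMemoryByRefCount change in actor_common.cc just below: when a bprop value node's ref count reaches zero, the runtime must not free memory that a Python-side forward-output tensor still owns; it detaches the tensors instead. A self-contained sketch of that release path, with stand-in Tensor/DeviceAddress types (the real ones live in ir/tensor.h and device_address.h):

#include <memory>
#include <utility>
#include <vector>

// Stand-ins for tensor::Tensor and device::DeviceAddress.
struct Tensor {
  std::shared_ptr<void> device_address;
  void set_device_address(std::shared_ptr<void> addr) { device_address = std::move(addr); }
};

struct DeviceAddress {
  // weak_ptr deliberately does not extend the forward tensor's lifetime:
  // if Python already released the tensor, lock() yields nullptr below and
  // there is nothing left to detach.
  std::vector<std::weak_ptr<Tensor>> from_tensors;
};

void ReleaseWhenRefCountHitsZero(DeviceAddress *device_tensor) {
  if (device_tensor->from_tensors.empty()) {
    // Normal path: no forward tensor shares this memory, so it can be freed.
    // FreeMemory(device_tensor, device_context);
    return;
  }
  // The memory was donated by forward-output tensors: do not free it here.
  // Detach every still-alive tensor so it stops pointing at this address.
  for (auto &weak : device_tensor->from_tensors) {
    if (auto tensor = weak.lock()) {
      tensor->set_device_address(nullptr);
    }
  }
  device_tensor->from_tensors.clear();
  // The real code then resets the ref count (set_original_ref_count(SIZE_MAX),
  // ResetRefCount()) so the address survives into the next step.
}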
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc index 66784315df6..75cf3b460e9 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc @@ -226,7 +226,21 @@ void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext device_tensor->DecreaseRefCount(); if (device_tensor->ref_count() == 0) { if (device_tensor->GetPtr() != nullptr) { - FreeMemory(device_tensor, device_context); + auto from_tensors = device_tensor->from_tensors(); + if (from_tensors.empty()) { + FreeMemory(device_tensor, device_context); + } else { + std::for_each(from_tensors.begin(), from_tensors.end(), [](const std::weak_ptr<tensor::Tensor> &t) { + auto tensor = t.lock(); + if (tensor != nullptr) { + tensor->set_device_address(nullptr); + } + }); + device_tensor->clear_from_tensors(); + // Reset device address status + device_tensor->set_original_ref_count(SIZE_MAX); + device_tensor->ResetRefCount(); + } } device_tensor->ResetRefCount(); } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc index de75693f844..e5f814af1fe 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc @@ -304,7 +304,8 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no // Assign tensor address to input data node and set `ref_count` to `SIZE_MAX` for avoiding clean AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get()); tensor_address->SetNodeIndex(input_node, 0); - tensor_address->set_ref_count(SIZE_MAX); + tensor_address->set_original_ref_count(SIZE_MAX); + tensor_address->ResetRefCount(); } else if (device_address->GetPtr() != nullptr) { // The `device_address` may come from another previous tensor. In order to prevent pollute the device data of // previous tensor, creating a new device address for holding current input tensor data. @@ -313,7 +314,8 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no MS_EXCEPTION_IF_NULL(new_device_address); AnfAlgo::SetOutputAddr(new_device_address, 0, input_node.get()); new_device_address->SetNodeIndex(input_node, 0); - new_device_address->set_ref_count(SIZE_MAX); + new_device_address->set_original_ref_count(SIZE_MAX); + new_device_address->ResetRefCount(); } } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc index 38bd0e5cf8b..5885a0ed6d6 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc @@ -161,12 +161,10 @@ void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, cons } auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); if (output_address != nullptr && output_address->GetDeviceType() == device_context->GetDeviceType()) { - // The input of PyNative bprop graph is ValueNode. - // Setting the address to the ValueNode will lead to memory leak. - if (!graph->is_bprop()) { - AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++, - value_node.get()); - } + // We need to set tensor->device_address to ValueNode even if the tensor is a forward_output tensor + // in PyNative Bprop graph.
ValueNode device_address is necessary for GraphScheduler::Transform. + AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++, + value_node.get()); continue; } @@ -369,18 +367,6 @@ void SetSummaryNodesRefCount(const KernelGraph *graph) { } } -void UpdateRefCountForGraphOutput(const std::vector<session::KernelWithIndex> &output_with_index) { - for (const auto &item_with_index : output_with_index) { - if (!AnfAlgo::OutputAddrExist(item_with_index.first, item_with_index.second, false)) { - continue; - } - auto device_address = AnfAlgo::GetMutableOutputAddr(item_with_index.first, item_with_index.second, false); - MS_EXCEPTION_IF_NULL(device_address); - device_address->set_original_ref_count(SIZE_MAX); - device_address->ResetRefCount(); - } -} - void SetGraphInputNodeActualAbstract(const session::OpRunInfo &op_run_info, const KernelGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); if (!op_run_info.output_is_dynamic_shape && !op_run_info.input_is_dynamic_shape) { @@ -446,11 +432,13 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod MS_EXCEPTION_IF_NULL(session_); // Graph kernel does not support pynative mode now, print a warning here. graphkernel::GraphKernelFlags::GetInstance().CheckSupport(); - session_->InitAllBucket(graph, device_context); graph_id = graph->graph_id(); } else { graph_id = CompileGraphImpl(graph, device_context); } + session_->InitAllBucket(graph, device_context); + + graph->set_front_outputs(outputs); session_->DumpGraphs({graph}); @@ -559,7 +547,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get())); MS_EXCEPTION_IF_NULL(session_); - session_->InitAllBucket(graph, device_context); SetSummaryNodesRefCount(graph.get()); #ifdef ENABLE_DUMP_IR // Dump .pb graph after graph optimization. @@ -629,7 +616,6 @@ GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, bool (void)outputs_with_index.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(node, 0, false)); } - UpdateRefCountForGraphOutput(outputs_with_index); AnfAlgo::UpdateGraphValidRefPair(graph); return graph->graph_id(); } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.h b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.h index 30d9704c594..793ab3a191f 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.h @@ -183,13 +183,13 @@ class GraphCompiler { // Remove single op kernel graph cache and output nodes cache. void EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id); - private: - DISABLE_COPY_AND_ASSIGN(GraphCompiler); - // The implementation of compiling graph in Graph Mode, including optimizing graph, // setting operator info, creating kernel and transforming kernel graph to ActorSet. GraphId CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const; + private: + DISABLE_COPY_AND_ASSIGN(GraphCompiler); + // Add operators' output and input reference map to the graph.
void AddOutInRefToGraph(const KernelGraphPtr &graph) const; diff --git a/mindspore/ccsrc/runtime/pynative/graph_adapter.cc b/mindspore/ccsrc/runtime/pynative/graph_adapter.cc new file mode 100644 index 00000000000..037e55b4709 --- /dev/null +++ b/mindspore/ccsrc/runtime/pynative/graph_adapter.cc @@ -0,0 +1,200 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/pynative/graph_adapter.h" + +#include <string> +#include <memory> +#include <vector> +#include "ir/tensor.h" +#include "include/common/utils/convert_utils.h" +#include "include/common/utils/anfalgo.h" +#include "backend/common/session/anf_runtime_algorithm.h" +#include "runtime/graph_scheduler/device_tensor_store.h" + +namespace mindspore::pynative { +namespace { +constexpr auto kAttrBpropValueNodeRefCount = "bprop_value_node_ref_count"; +constexpr auto kAttrValueNodeForwardOutputFlags = "value_node_forward_output_flags"; + +tensor::TensorPtr GetTensorFromValueNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa<ValueNode>()) { + return nullptr; + } + auto value_node = node->cast<ValueNodePtr>(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + // ValueTuple is already expanded into tensors in backend. + if (!value->isa<tensor::Tensor>()) { + MS_LOG(DEBUG) << "Only need to process forward output tensor. value:" << value->ToString(); + return nullptr; + } + + auto tensor = value->cast<tensor::TensorPtr>(); + return tensor; +} +} // namespace + +void GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(const KernelGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + for (auto &value_node : graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + if (value->isa<tensor::Tensor>()) { + auto tensor = value->cast<tensor::TensorPtr>(); + MS_EXCEPTION_IF_NULL(tensor); + if (tensor->is_forward_output()) { + AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get()); + } + } + } +} + +// The device address of a graph value node needs to be released +// if the value node is an output of the forward graph in PyNative mode.
+void GraphAdapter::GenerateRefCountForBpropValueNode(const KernelGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + HashMap<std::string, size_t> tensor_counts; + auto execution_nodes = graph->execution_order(); + for (auto &node : execution_nodes) { + std::vector<session::KernelWithIndex> real_inputs; + common::AnfAlgo::GetRealInputs(node, &real_inputs); + for (auto &real_input : real_inputs) { + auto forward_output_tensor = GetTensorFromValueNode(real_input.first); + if (forward_output_tensor == nullptr || !forward_output_tensor->is_forward_output()) { + continue; + } + tensor_counts[forward_output_tensor->id()] += 1; + } + } + + std::vector<size_t> value_node_ref_count; + std::vector<bool> value_node_forward_output_flags; + for (auto &value_node : graph->graph_value_nodes()) { + auto tensor = GetTensorFromValueNode(value_node); + if (tensor == nullptr || !tensor->is_forward_output()) { + value_node_ref_count.emplace_back(SIZE_MAX); + value_node_forward_output_flags.emplace_back(false); + continue; + } + auto iter = tensor_counts.find(tensor->id()); + if (iter == tensor_counts.end()) { + // The tensor is in the bprop graph but not used. + // e.g. %1-MakeTuple(T1, T2) -> TupleGetItem(%1, 0). T2 is not used. + MS_LOG(DEBUG) << "Tensor " << tensor->ToString() << " is not found in value node"; + value_node_ref_count.emplace_back(SIZE_MAX); + value_node_forward_output_flags.emplace_back(false); + continue; + } + + value_node_ref_count.emplace_back(iter->second); + value_node_forward_output_flags.emplace_back(true); + } + graph->set_attr(kAttrBpropValueNodeRefCount, MakeValue(value_node_ref_count)); + graph->set_attr(kAttrValueNodeForwardOutputFlags, MakeValue(value_node_forward_output_flags)); +} + +void GraphAdapter::UpdateForwardOutputInBpropGraph(const KernelGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(DEBUG) << "Update start"; + auto value_node_ref_counts = GetValue<std::vector<size_t>>(graph->get_attr(kAttrBpropValueNodeRefCount)); + auto value_node_forward_output_flags = GetValue<std::vector<bool>>(graph->get_attr(kAttrValueNodeForwardOutputFlags)); + size_t value_node_size = graph->graph_value_nodes().size(); + if (value_node_ref_counts.size() != value_node_size || value_node_forward_output_flags.size() != value_node_size) { + MS_LOG(EXCEPTION) << "value_node_ref_count.size " << value_node_ref_counts.size() + << " value_node_forward_output_flags.size " << value_node_forward_output_flags.size() + << " not equal to " << value_node_size; + } + + size_t value_node_index = 0; + HashMap<device::DeviceAddressPtr, size_t> address_ref_count; + // Update ValueNode device address + for (auto &value_node : graph->graph_value_nodes()) { + auto is_forward_output = value_node_forward_output_flags[value_node_index]; + if (!is_forward_output) { + value_node_index++; + continue; + } + size_t value_node_ref_count = value_node_ref_counts[value_node_index++]; + auto tensor = GetTensorFromValueNode(value_node); + MS_EXCEPTION_IF_NULL(tensor); + auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); + if (device_address == nullptr) { + MS_LOG(WARNING) << "Forward output " << tensor->ToString() << " device address is null"; + continue; + } + + if (device_address->GetDeviceType() != device::DeviceType::kCPU) { + address_ref_count[device_address] += value_node_ref_count; + device_address->set_from_tensor(tensor); + } + + auto front_node = AnfAlgo::FetchFrontNodeByBackendNode(value_node, *graph); + runtime::DeviceTensorStore::GetInstance().Insert(front_node.get(), device_address); + } + + for (auto &[address, ref_count] : address_ref_count) { + address->set_original_ref_count(ref_count); + address->ResetRefCount();
MS_LOG(DEBUG) << "device_address " << address.get() << " ref_count " << address->ref_count(); + } + MS_LOG(DEBUG) << "Update end"; +} + +bool GraphAdapter::ReplaceBpropGraphParameter(const KernelGraphPtr &graph, + const std::vector &input_tensors) { + size_t index = 0; + bool changed = false; + for (const auto &input_node : graph->input_nodes()) { + auto params = common::AnfAlgo::GetAllOutput(input_node); + for (const auto ¶m : params) { + if (index >= input_tensors.size()) { + MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index + << ", input size: " << input_tensors.size(); + } + const auto &input_tensor = input_tensors[index++]; + MS_EXCEPTION_IF_NULL(input_tensor); + const auto &tensor_address = input_tensor->device_address(); + auto address = std::dynamic_pointer_cast(tensor_address); + if (address != nullptr) { + auto tensor_format = address->format(); + auto param_format = AnfAlgo::GetOutputFormat(param, 0); + if (tensor_format != param_format) { + // Update parameter format + auto kernel_build_info_builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_build_info_builder); + kernel_build_info_builder->SetOutputsFormat(std::vector{address->format()}); + kernel_build_info_builder->SetOutputsDeviceType(std::vector{address->type_id()}); + kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()}); + AnfAlgo::SetOutputAddr(address, 0, param.get()); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get()); + + // Update abstract + auto type_of_tensor = input_tensor->Dtype(); + auto shape_of_tensor = input_tensor->shape(); + auto abstract = std::make_shared(type_of_tensor, shape_of_tensor); + param->set_abstract(abstract); + changed = true; + } + } + } + } + return changed; +} +} // namespace mindspore::pynative diff --git a/mindspore/ccsrc/runtime/pynative/graph_adapter.h b/mindspore/ccsrc/runtime/pynative/graph_adapter.h new file mode 100644 index 00000000000..92bae50a131 --- /dev/null +++ b/mindspore/ccsrc/runtime/pynative/graph_adapter.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_ + +#include <vector> +#include "backend/common/session/kernel_graph.h" + +namespace mindspore::pynative { +class GraphAdapter { + public: + static void UpdateForwardOutputInBpropGraph(const KernelGraphPtr &graph); + static bool ReplaceBpropGraphParameter(const KernelGraphPtr &graph, + const std::vector<tensor::TensorPtr> &input_tensors); + static void GenerateRefCountForBpropValueNode(const KernelGraphPtr &graph); + static void ClearForwardOutputValueNodeDeviceAddress(const KernelGraphPtr &graph); +}; +} // namespace mindspore::pynative +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_ diff --git a/mindspore/ccsrc/runtime/pynative/run_op_helper.cc b/mindspore/ccsrc/runtime/pynative/run_op_helper.cc index bb931296f4a..08ee5a0a543 100644 --- a/mindspore/ccsrc/runtime/pynative/run_op_helper.cc +++ b/mindspore/ccsrc/runtime/pynative/run_op_helper.cc @@ -105,7 +105,6 @@ void UpdateInputNodeDeviceAddress(const std::vector<AnfNodePtr> &input_nodes, input_tensor->set_lazy_callback([]() { runtime::OpExecutor::GetInstance().Wait(); }); node_address->set_from_persistent_mem(input_tensor->is_parameter()); node_address->SetNodeIndex(input_node, 0); - UpdateRefCount(node_address.get(), true); } // The DeviceType and format of DeviceAddress is always the same after UpdateInputTensor @@ -174,7 +173,6 @@ void CopyValueNodeTensorToDevice(const ValueNodePtr &node, const device::DeviceC return; } tensor->set_device_address(node_address); - UpdateRefCount(node_address.get(), true); CopyTensorDataToDevice(tensor, node, device_context); } } diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index de9ba5d195b..7d55f6b5a46 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -43,7 +43,6 @@ FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info) kw_only_args_count_(0), hyper_param_count_(0), is_generated_(false), - is_bprop_(false), return_(nullptr), manager_(), debug_info_(std::move(debug_info)), @@ -142,7 +141,7 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { return p; } -bool FuncGraph::has_flag(const std::string &key) { +bool FuncGraph::has_flag(const std::string &key) const { auto iter = attrs_.find(key); if (iter != attrs_.cend()) { MS_EXCEPTION_IF_NULL(iter->second); diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index 7f4d46a3b6d..ff7cdc6ccc1 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -184,8 +184,6 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder { FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list); void set_is_generate(bool generated) { is_generated_ = generated; } bool is_generated() const { return is_generated_; } - void set_is_bprop(bool is_brop) { is_bprop_ = is_brop; } - bool is_bprop() const { return is_bprop_; } mindspore::HashMap<std::string, ValuePtr> &attrs() { return attrs_; } void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) { @@ -193,7 +191,7 @@ attrs_[attr.first] = attr.second; } } - bool has_flag(const std::string &key); + bool has_flag(const std::string &key) const; void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); } void erase_flag(const std::string &key) { (void)attrs_.erase(key); } @@ -425,9 +423,6 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
size_t hyper_param_count_; // Argument input list for the graph used to generate this graph. bool is_generated_; - - bool is_bprop_; - // CNode that calls 'return' primitive. // We use shared pointer to manage it. CNodePtr return_; diff --git a/tests/st/auto_parallel/cell_shard.py b/tests/st/auto_parallel/cell_shard.py index a8b18ab175c..6752a898c43 100644 --- a/tests/st/auto_parallel/cell_shard.py +++ b/tests/st/auto_parallel/cell_shard.py @@ -381,6 +381,6 @@ def test_train_feed(num_classes=65536): model = Model(net, loss_fn=loss, optimizer=opt) model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback) loss_value = np.array(parallel_callback.loss_list) - expect_out = [11.087254, 10.876551, 10.045684] + expect_out = [11.087254, 10.876551, 10.142526] print(loss_value) assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
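Taken together, the run policy this patch introduces can be summarized by the following sketch (a hypothetical condensation of RunGraphByCondition and RunGraphIntegrated above; types and members are stand-ins for the real MindRTBackend state, not a literal excerpt):

enum class RunMode { kOpByOpWithGil, kOpByOp, kWholeGraph };

struct BpropGraphTraits {
  bool contain_cut_graph;      // any graph carries kFlagsIsCutGraph
  bool is_dynamic_shape;       // any graph->is_dynamic_shape()
  bool is_dynamic_structure;   // root graph carries kFlagIsDynamicStructure
};

RunMode ChooseRunMode(const BpropGraphTraits &g, bool pynative_run_in_graph) {
  if (g.contain_cut_graph) {
    // Cut graphs call back into Python, so the GIL is kept and the bprop
    // graph is executed op by op (RunGraphBySingleOp).
    return RunMode::kOpByOpWithGil;
  }
  // Otherwise the GIL is released (ScopedLongRunning in the real code).
  if (pynative_run_in_graph && !g.is_dynamic_shape && !g.is_dynamic_structure) {
    // Static CPU/GPU bprop graph: compile once, generate value-node ref
    // counts, then run the whole actor DAG every step (RunGraphIntegrated).
    return RunMode::kWholeGraph;
  }
  return RunMode::kOpByOp;  // RunGraphBySingleOp without the GIL
}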