diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index c16cca863e2..135737ac79b 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -2612,14 +2612,18 @@ std::set<FuncGraphPtr> AnfRuntimeAlgorithm::GetFuncGraphbyCallNode(const AnfNode
     if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitch)) {
       // First input node of call is switch node.
-      const auto &switch_inputs = call_input0->cast<CNodePtr>()->inputs();
+      const auto &input_cnode = call_input0->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(input_cnode);
+      const auto &switch_inputs = input_cnode->inputs();
       for (size_t i = kSwitchTrueBranchPos; i < switch_inputs.size(); ++i) {
         MS_EXCEPTION_IF_NULL(switch_inputs[i]);
         (void)func_graphs.emplace(GetFuncGraphFromPartial(switch_inputs[i], call_depth));
       }
     } else if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitchLayer)) {
       // First input node of call is switch layer node.
-      const auto &tuple_node = call_input0->cast<CNodePtr>()->input(kSwitchLayerBranchPos);
+      const auto &input_cnode = call_input0->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(input_cnode);
+      const auto &tuple_node = input_cnode->input(kSwitchLayerBranchPos);
       if (!AnfAlgo::CheckPrimitiveType(tuple_node, prim::kPrimMakeTuple)) {
         MS_LOG(EXCEPTION) << "Invalid input tuple node:" << tuple_node->DebugString()
                           << " for switch layer node:" << cnode->DebugString();
diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc
index 3b694d94934..85a57d61493 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_session.cc
@@ -236,9 +236,9 @@ bool TensorNeedSync(const std::shared_ptr<KernelGraph> &kernel_graph, const AnfN
   auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
   if (tensor_address != device_address) {
     if (!kernel_graph->is_dynamic_shape() && EnableDeviceCopy() && NeedMemcpyInDevice(tensor_address, device_address)) {
-      auto status = device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(parameter, 0),
-                                                       tensor_address->GetSize(), tensor_address->type_id(),
-                                                       tensor_address->GetPtr(), tensor_address->format());
+      auto status = device_address->AsyncDeviceToDevice(trans::GetRuntimePaddingShape(parameter, 0),
+                                                        tensor_address->GetSize(), tensor_address->type_id(),
+                                                        tensor_address->GetPtr(), tensor_address->format());
       if (!status) {
         MS_LOG(EXCEPTION) << "SyncDeviceToDevice failed.";
       }
@@ -1830,9 +1830,9 @@ void AscendSession::UpdateOutputTensors(const VectorRef *outputs,
     if (EnableDeviceCopy() && tensor->NeedSyncDeviceToHostImmediately()) {
       auto dst_device_address = AssignExtraMemForGraphOutput(tensor, node, output_index);
       MS_EXCEPTION_IF_NULL(dst_device_address);
-      if (!dst_device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(node, output_index),
-                                                  address->GetSize(), address->type_id(), address->GetPtr(),
-                                                  address->format())) {
+      if (!dst_device_address->AsyncDeviceToDevice(trans::GetRuntimePaddingShape(node, output_index),
+                                                   address->GetSize(), address->type_id(), address->GetPtr(),
+                                                   address->format())) {
         MS_LOG(EXCEPTION) << "SyncDeviceToDevice failed!";
       }
       tensor->set_sync_status(kNoNeedSync);
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index 92a4e4d30d9..3c85e8581d5 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -1600,6 +1600,7 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &para
       // for example "def f(x,y,z) {return x + y}", parameter z is unused
       auto new_parameter = CreateNewParameter(parameter, graph);
       graph_inputs->push_back(new_parameter);
+      graph->FrontBackendlMapAdd(parameter, new_parameter);
       MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
       continue;
     }
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index ccea6af2cde..6abaa5eea24 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -1461,9 +1461,9 @@ void ClearResAtexit() {
   mindspore::RDR::ResetRecorder();
 #endif
   session::ExecutorManager::Instance().Clear();
-  device::KernelRuntimeManager::Instance().ClearRuntimeResource();
   runtime::GraphScheduler::GetInstance().Clear();
   device::DeviceContextManager::GetInstance().ClearDeviceContexts();
+  device::KernelRuntimeManager::Instance().ClearRuntimeResource();
   ad::g_k_prims.clear();
   ad::ClearKPynativeCellStaticRes();
   ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
index c7a605d0a77..cb6a852c13d 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
@@ -180,6 +180,8 @@ void AscendDeviceAddress::BindDevice() const {
     if (!ascend_device_context->BindDeviceToCurrentThread()) {
       MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed.";
     }
+  } else {
+    MS_LOG(WARNING) << "device name is null.";
   }
 }
@@ -432,9 +434,32 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
   return sync_ok;
 }

-bool AscendDeviceAddress::SyncDeviceToDevice(const ShapeVector &, size_t size, TypeId type, const void *src_ptr,
+bool AscendDeviceAddress::SyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
                                              const std::string &format) const {
-  MS_LOG(INFO) << "SyncDeviceToDevice, dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
+  if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) {
+    return true;
+  }
+  BindDevice();
+  if (format_ != format || type_id_ != type) {
+    MS_LOG(ERROR) << "format or type is different, src(format:" << format << ", type_id:" << TypeIdLabel(type)
+                  << "), dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_);
+    return false;
+  }
+  if (size_ < size) {
+    MS_LOG(ERROR) << "src size is greater than dst size, src size is: " << size << ", dst size is: " << size_;
+    return false;
+  }
+  auto ret_rt_memcpy = rtMemcpy(ptr_, size, src_ptr, size, RT_MEMCPY_DEVICE_TO_DEVICE);
+  if (ret_rt_memcpy != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "SyncDeviceToDevice failed, rtMemcpy mem size [" << size << "], ret [" << ret_rt_memcpy << "]";
+    return false;
+  }
+  return true;
+}
+
+bool AscendDeviceAddress::AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
+                                              const std::string &format) const {
+  MS_LOG(INFO) << "AsyncDeviceToDevice, dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
                << ", size:" << size_ << "), src(format:" << format << ", type_id:" << TypeIdLabel(type)
                << ", size:" << size << ")";
   if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) {
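For orientation, the sketch below shows how a call site chooses between the two entry points defined above. It is not part of the patch; the helper name CopyBetweenDeviceAddresses is hypothetical, and only accessors that already appear in these hunks (GetSize, type_id, GetPtr, format) are used.

    // Hypothetical caller-side helper, mirroring the call pattern in ascend_session.cc:
    // pick the blocking copy when the result is needed immediately, otherwise the
    // stream-ordered variant that completes later on the device stream.
    bool CopyBetweenDeviceAddresses(const DeviceAddressPtr &dst, const DeviceAddressPtr &src,
                                    const ShapeVector &host_shape, bool need_result_now) {
      if (need_result_now) {
        return dst->SyncDeviceToDevice(host_shape, src->GetSize(), src->type_id(), src->GetPtr(), src->format());
      }
      return dst->AsyncDeviceToDevice(host_shape, src->GetSize(), src->type_id(), src->GetPtr(), src->format());
    }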
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h
index 5a82e3185cf..629d48251b3 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h
@@ -50,6 +50,8 @@ class AscendDeviceAddress : public DeviceAddress {
   bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
   bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
                         const std::string &format = "DefaultFormat") const override;
+  bool AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
+                           const std::string &format) const override;
   bool SyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
                           const std::string &format) const override;
   void ClearDeviceMemory() override;
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
index 5881c5d33bb..2c5ee91ea59 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
@@ -301,6 +301,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
       !context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
     HcclCollectiveGroup::instance().FinalizeCollective();
   }
+  initialized_ = false;
   MS_LOG(INFO) << "Ascend finalize end";
 }
@@ -407,7 +408,10 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
-  return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, kAscendDevice, device_id);
+  auto ascend_device_address_ptr =
+    std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, kAscendDevice, device_id);
+  ascend_device_address_ptr->set_is_ptr_persisted(true);
+  return ascend_device_address_ptr;
 }

 DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
@@ -415,8 +419,10 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
-  return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node_index, kAscendDevice,
-                                               device_id);
+  auto ascend_device_address_ptr = std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id,
+                                                                         node_index, kAscendDevice, device_id);
+  ascend_device_address_ptr->set_is_ptr_persisted(true);
+  return ascend_device_address_ptr;
 }

 bool AscendKernelRuntime::Load(const session::KernelGraph &graph, bool is_task_sink) {
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
index c6f8e6a2beb..6a3022c470f 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
@@ -73,7 +73,6 @@ class AscendKernelRuntime : public KernelRuntime {
   void *GetModelStream(uint32_t graph_id) const override;
   // add for MindRT
   void ReleaseDeviceRes() override;
-  void SetCurrentContext();

  protected:
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
@@ -90,6 +89,7 @@ class AscendKernelRuntime : public KernelRuntime {
   static bool HcclInit();
   static bool NeedDestroyHccl();
   static bool DestroyHccl();
+  void SetCurrentContext();
   void ClearGraphModelMap();
   bool GraphWithEmptyTaskList(const session::KernelGraph &graph) const;
diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.cc
index a3a69e3d10a..190ff1657c8 100644
--- a/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.cc
@@ -148,6 +148,7 @@ void AiCoreDynamicKernel::AllocateWorkspace() {
   workspace_addr_.clear();
   for (auto size : workspaces_size_) {
     auto device_address_ptr = std::make_shared<AscendDeviceAddress>(nullptr, size, kAscendDevice, device_id);
+    device_address_ptr->set_is_ptr_persisted(true);
     auto device_ptr = runtime_instance->MallocMem(MemType::kDynamicMem, size, device_address_ptr);
     if (device_ptr == nullptr) {
       MS_LOG(EXCEPTION) << "MallocMem from memory pool failed. Node info :" << cnode->fullname_with_scope();
diff --git a/mindspore/ccsrc/runtime/device/device_address.h b/mindspore/ccsrc/runtime/device/device_address.h
index 2f59d6a14bc..1099afc590a 100644
--- a/mindspore/ccsrc/runtime/device/device_address.h
+++ b/mindspore/ccsrc/runtime/device/device_address.h
@@ -98,6 +98,8 @@ class DeviceAddress : public mindspore::DeviceSync {
   TypeId type_id() const { return type_id_; }
   bool from_mem_pool() const { return from_mem_pool_; }
   void set_from_mem_pool(bool from_mem_pool) { from_mem_pool_ = from_mem_pool; }
+  bool is_ptr_persisted() const { return is_ptr_persisted_; }
+  void set_is_ptr_persisted(bool is_ptr_persisted) { is_ptr_persisted_ = is_ptr_persisted; }
   void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
   virtual void set_status(DeviceAddressStatus status) {}
   virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
@@ -134,6 +136,10 @@
   ShapeVector host_shape_{};
   // {node, out_index}
   std::pair<AnfNodePtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
+  // The device address of the node that owns the device address cannot be updated and replaced.
+  // application scenario: set to true when the hardware execution mode requires that ptr cannot be changed during
+  // execution.
+  bool is_ptr_persisted_{false};
   // The key of device context.
   std::string device_name_{""};
diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.cc b/mindspore/ccsrc/runtime/device/kernel_adjust.cc
index a21200b0716..26920c5b240 100644
--- a/mindspore/ccsrc/runtime/device/kernel_adjust.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_adjust.cc
@@ -997,6 +997,7 @@ void KernelAdjust::AssignLoopCtrlTensorMem(const session::KernelGraph &kernel_gr
   auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
   device_address =
     std::make_shared<AscendDeviceAddress>(nullptr, size, format, type_id, kAscendDevice, device_id);
+  device_address->set_is_ptr_persisted(true);
   if (runtime_instance->MallocMem(kStaticMem, size, device_address) == nullptr) {
     MS_LOG(EXCEPTION) << "Cannot alloc static memory for device loop control parameter " << name
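A minimal consumer-side sketch of the new flag follows; the helper name ReplaceOutputAddressIfAllowed and the single-output assumption are illustrative only, while the AnfAlgo calls mirror the ones used in data_prepare_actor.cc further down.

    // Hypothetical helper (not in the patch): rebind a node's output address only when
    // the existing address is not pinned by the hardware execution mode.
    void ReplaceOutputAddressIfAllowed(const AnfNodePtr &node, const DeviceAddressPtr &candidate) {
      auto current = AnfAlgo::GetMutableOutputAddr(node, 0, false);
      MS_EXCEPTION_IF_NULL(current);
      if (current->is_ptr_persisted()) {
        // The device ptr must stay stable while the loaded model executes, so keep it.
        return;
      }
      AnfAlgo::SetOutputAddr(candidate, 0, node.get());
    }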
diff --git a/mindspore/ccsrc/runtime/framework/actor/actor_common.cc b/mindspore/ccsrc/runtime/framework/actor/actor_common.cc
index 09942783166..1f6a0d94c3d 100644
--- a/mindspore/ccsrc/runtime/framework/actor/actor_common.cc
+++ b/mindspore/ccsrc/runtime/framework/actor/actor_common.cc
@@ -159,6 +159,10 @@ bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_
   } else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
     // Other device tensor copy to CPU device tensor.
     return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
+  } else if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) {
+    return dst_device_tensor->SyncDeviceToDevice(ShapeVector(), src_device_tensor->GetSize(),
+                                                 src_device_tensor->type_id(), src_device_tensor->GetPtr(),
+                                                 src_device_tensor->format());
   } else {
     MS_LOG(ERROR) << "Invalid device type, src device type: " << src_device_tensor->DeviceType()
                   << ", dst device type: " << dst_device_tensor->DeviceType();
diff --git a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc
index 63772a447d6..65fef26579e 100644
--- a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc
+++ b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc
@@ -239,7 +239,8 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<TensorPtr
     auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
     auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
     MS_EXCEPTION_IF_NULL(device_address);
-    if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
+    if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType()) &&
+        !device_address->is_ptr_persisted()) {
       AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
     }
   }
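To summarize the dispatch that the new branch enables, here is a condensed paraphrase of Copy(); only the branches visible in the hunk above are reproduced, and the remaining branches of the real function are elided.

    // Condensed paraphrase, not the literal function body.
    bool CopySketch(const DeviceTensor *dst, const DeviceTensor *src, size_t copy_size) {
      if (dst->DeviceType() == device::DeviceAddressType::kCPU) {
        return src->SyncDeviceToHost(copy_size, dst->GetMutablePtr());  // device -> host
      }
      if (dst->DeviceType() == src->DeviceType()) {
        // New: same-device-type copies route through the destination's SyncDeviceToDevice.
        return dst->SyncDeviceToDevice(ShapeVector(), src->GetSize(), src->type_id(), src->GetPtr(), src->format());
      }
      MS_LOG(ERROR) << "Invalid device type, src device type: " << src->DeviceType()
                    << ", dst device type: " << dst->DeviceType();
      return false;
    }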
diff --git a/mindspore/ccsrc/runtime/framework/graph_compiler.cc b/mindspore/ccsrc/runtime/framework/graph_compiler.cc
index c0e8a2dad75..9a8f13f3d09 100644
--- a/mindspore/ccsrc/runtime/framework/graph_compiler.cc
+++ b/mindspore/ccsrc/runtime/framework/graph_compiler.cc
@@ -348,6 +348,9 @@ GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const Device
   auto graph_id = CompileGraphImpl(root_graph, device_context);

+  // dump all graphs.
+  device_context->DumpAllGraphs(all_graphs);
+
   // Cache the backend graph output nodes to front nodes with output index.
   auto output = func_graph->output();
   MS_EXCEPTION_IF_NULL(output);
diff --git a/mindspore/ccsrc/runtime/hardware/ascend/ascend_device_context.cc b/mindspore/ccsrc/runtime/hardware/ascend/ascend_device_context.cc
index b5b6c56b205..6e2b5e03112 100644
--- a/mindspore/ccsrc/runtime/hardware/ascend/ascend_device_context.cc
+++ b/mindspore/ccsrc/runtime/hardware/ascend/ascend_device_context.cc
@@ -19,7 +19,6 @@
 #include
 #include "backend/optimizer/ascend/ascend_backend_optimization.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
-#include "backend/session/ascend_auto_monad.h"
 #include "utils/context/graph_kernel_flags.h"
 #include "runtime/device/ascend/kernel_select_ascend.h"
 #include "runtime/device/kernel_adjust.h"
@@ -34,6 +33,17 @@
 #include "debug/dump_proto.h"
 #include "debug/data_dump/e2e_dump.h"
 #endif
+#ifdef ENABLE_DEBUGGER
+#include "debug/tensor_load.h"
+#include "debug/debugger/proto_exporter.h"
+#else
+#include "debug/debugger/proto_exporter_stub.h"
+#endif
+#ifdef ENABLE_DUMP_IR
+#include "debug/rdr/running_data_recorder.h"
+#include "debug/rdr/recorder_manager.h"
+#include "debug/rdr/graph_recorder.h"
+#endif

 namespace mindspore {
 namespace device {
@@ -69,11 +79,47 @@ void Dump(const KernelGraphPtr &graph, uint32_t rank_id) {
 }
 #endif

+void AscendDeviceContext::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const {
+#ifdef ENABLE_DUMP_IR
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+  auto &json_parser = DumpJsonParser::GetInstance();
+  json_parser.Parse();
+  if (!save_graphs && !json_parser.e2e_dump_enabled() && !json_parser.async_dump_enabled() &&
+      !mindspore::RecorderManager::Instance().RdrEnable()) {
+    return;
+  }
+  for (auto &graph : all_graphs) {
+    MS_EXCEPTION_IF_NULL(graph);
+    std::string name = "graph_build." + std::to_string(graph->graph_id());
+    DumpGraphParams dump_params = {true, static_cast<size_t>(kWholeStack)};
+    (void)mindspore::RDR::RecordAnfGraph(SUBMODULE_ID, name, graph, dump_params, ".ir;.pb");
+    if (save_graphs) {
+      std::string file_name = "graph_build_" + std::to_string(graph->graph_id()) + ".ir";
+      DumpIR(file_name, graph, true, kWholeStack);
+      DumpIRProto(graph, "vm_build_" + std::to_string(graph->graph_id()));
+      DumpIR("trace_code_graph", graph, true, kWholeStack);
+    }
+    std::string final_graph = "trace_code_graph_" + std::to_string(graph->graph_id());
+    if (json_parser.e2e_dump_enabled() || json_parser.async_dump_enabled()) {
+      std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id_);
+      std::string target_dir = root_dir + "/graphs";
+      std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
+      DumpIRProtoWithSrcInfo(graph, final_graph, target_dir, kDebugWholeStack);
+      DumpIR("trace_code_graph", graph, true, kWholeStack, ir_file_path);
+      DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(graph->graph_id()) + ".csv", root_dir,
+                        graph->execution_order());
+    }
+  }
+#endif
+}
+
 void AscendDeviceContext::Initialize() {
   MS_LOG(INFO) << "Status record: Enter Initialize...";
   if (initialized_) {
     MS_EXCEPTION_IF_NULL(runtime_instance_);
-    runtime_instance_->SetCurrentContext();
+    runtime_instance_->SetContext();
     return;
   }
@@ -109,8 +155,7 @@ void AscendDeviceContext::Destroy() {
   }
   MS_LOG(INFO) << "Status record: Destroy start...";
   rank_id_ = 0;
-  if (runtime_instance_ != nullptr) {
-    runtime_instance_->ReleaseDeviceRes();
+  if (runtime_instance_) {
     runtime_instance_ = nullptr;
   }
   initialized_ = false;
@@ -181,9 +226,44 @@ void AscendDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph)
     MS_LOG(INFO) << "PreprocessBeforeRunGraph success.";
     return;
   }
+  AssignOutputNopNodeDeviceAddress(graph);
   MS_LOG(INFO) << "PreprocessBeforeRunGraph success.";
 }

+void AscendDeviceContext::AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
+  for (auto output : outputs) {
+    if (!output->isa<CNode>() || !AnfAlgo::IsRealKernel(output)) {
+      continue;
+    }
+
+    if (!opt::IsNopNode(output)) {
+      continue;
+    }
+
+    size_t input_num = AnfAlgo::GetInputTensorNum(output);
+    if (input_num != 1) {
+      MS_LOG(WARNING) << "The input number of nop node :" << output->fullname_with_scope() << " is " << input_num
+                      << ", not equal 1";
+      continue;
+    }
+
+    auto real_input_index = AnfAlgo::GetRealInputIndex(output, 0);
+    auto pre_node_out_device_address = AnfAlgo::GetPrevNodeOutputAddr(output, real_input_index);
+    MS_EXCEPTION_IF_NULL(pre_node_out_device_address);
+    auto ptr = pre_node_out_device_address->GetPtr();
+    auto size = pre_node_out_device_address->GetSize();
+    std::string output_format = AnfAlgo::GetOutputFormat(output, 0);
+    auto output_type = AnfAlgo::GetOutputDeviceDataType(output, 0);
+    auto device_address = CreateDeviceAddress(const_cast<void *>(ptr), size, output_format, output_type);
+    AnfAlgo::SetOutputAddr(device_address, 0, output.get());
+
+    AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(false), output);
+    MS_LOG(INFO) << "Assign device address to output nop node " << output->fullname_with_scope();
+  }
+}
+
 void AscendDeviceContext::AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const {
   MS_EXCEPTION_IF_NULL(runtime_instance_);
   runtime_instance_->ClearGlobalIdleMem();
@@ -225,7 +305,7 @@ void AscendDeviceContext::LoadModel(const NotNull<KernelGraphPtr> &root_graph) c
 bool AscendDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size) const {
   MS_EXCEPTION_IF_NULL(address);
   MS_EXCEPTION_IF_NULL(runtime_instance_);
-  runtime_instance_->SetCurrentContext();
+  runtime_instance_->SetContext();
   auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
   if (!device_ptr) {
     return false;
   }
@@ -249,7 +329,7 @@ void AscendDeviceContext::FreeMemory(DeviceAddress *const &address) const {
   MS_EXCEPTION_IF_NULL(runtime_instance_);

 bool AscendDeviceContext::AllocateContinuousMemory(const std::vector<DeviceAddressPtr> &addr_list, size_t total_size,
                                                    const std::vector<size_t> &size_list) const {
   MS_EXCEPTION_IF_NULL(runtime_instance_);
-  runtime_instance_->SetCurrentContext();
+  runtime_instance_->SetContext();
   return mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list);
 }
@@ -296,7 +376,7 @@ bool AscendDeviceContext::LaunchGraph(const KernelGraphPtr &graph) const {
   MS_LOG(INFO) << "Status record: start launch graph. graph id: " << graph->graph_id();
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(runtime_instance_);
-  runtime_instance_->SetCurrentContext();
+  runtime_instance_->SetContext();
   device::KernelAdjust::GetInstance().LoadDeviceLoopCtrlParameters(graph);
   auto ret = ExecuteGraph(graph);
   MS_LOG(INFO) << "Status record: end launch graph. graph id: " << graph->graph_id();
@@ -336,7 +416,7 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector
-  runtime_instance_->SetCurrentContext();
+  runtime_instance_->SetContext();
   return true;
 }
diff --git a/mindspore/ccsrc/runtime/hardware/ascend/ascend_device_context.h b/mindspore/ccsrc/runtime/hardware/ascend/ascend_device_context.h
index 12609e08998..a6a013fb138 100644
--- a/mindspore/ccsrc/runtime/hardware/ascend/ascend_device_context.h
+++ b/mindspore/ccsrc/runtime/hardware/ascend/ascend_device_context.h
@@ -126,6 +126,9 @@ class AscendDeviceContext : public DeviceContext {
   // set rt_context_ to this thread to control device
   bool BindDeviceToCurrentThread() const;

+  // dump all graphs.
+  void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const override;
+
  private:
   // Graph loader interface
   void AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const;
@@ -150,6 +153,7 @@
   mutable std::set<KernelGraphPtr> memo_;
   // Using node to get it's atomics
   mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_;
+  void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const;
 };
 }  // namespace ascend
 }  // namespace device
diff --git a/mindspore/ccsrc/runtime/hardware/ascend/ascend_graph_optimization.cc b/mindspore/ccsrc/runtime/hardware/ascend/ascend_graph_optimization.cc
index 29aa65f5167..4c4c7fd0f0a 100644
--- a/mindspore/ccsrc/runtime/hardware/ascend/ascend_graph_optimization.cc
+++ b/mindspore/ccsrc/runtime/hardware/ascend/ascend_graph_optimization.cc
@@ -51,6 +51,8 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
   OptimizeGraphWithDeviceInfo(graph);
   OptimizeExecutionOrder(graph);
   PostOptimization(graph);
+  // must clear memo_ which holds kernelgraph after using AscendGraphOptimization class.
+  memo_.clear();
graph id: " << graph->graph_id(); } @@ -71,6 +73,9 @@ void AscendGraphOptimization::OptimizeGraphWithDeviceInfo(const KernelGraphPtr & MS_EXCEPTION_IF_NULL(graph); memo_.clear(); HardWareOptimization(graph); + // copy child graph ref output map to father graph ref output map + memo_.clear(); + UpdateRefOutputMap(graph); AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(graph)); } @@ -112,9 +117,6 @@ void AscendGraphOptimization::OptimizeExecutionOrder(const KernelGraphPtr &graph void AscendGraphOptimization::PostOptimization(const KernelGraphPtr &graph) { MS_LOG(INFO) << "Status record: start post optimization. graph id: " << graph->graph_id(); - // copy child graph ref output map to father graph ref output map - memo_.clear(); - UpdateRefOutputMap(graph); graph->SetInputNodes(); graph->SetOptimizerFlag(); MS_LOG(INFO) << "Status record: end post optimization. graph id: " << graph->graph_id(); diff --git a/mindspore/ccsrc/runtime/hardware/device_context.h b/mindspore/ccsrc/runtime/hardware/device_context.h index 24c9acb1d89..87f61c0c0f6 100644 --- a/mindspore/ccsrc/runtime/hardware/device_context.h +++ b/mindspore/ccsrc/runtime/hardware/device_context.h @@ -152,6 +152,10 @@ class DeviceContext { // Return collective communication object for caller to access CollectiveCommunicationLibPtr collective_comm_lib() const { return collective_comm_lib_; } + // TODO(jiaorui): will be delete + // Dump all graphs. + virtual void DumpAllGraphs(const std::vector &all_graphs) const {} + protected: DeviceContextKey device_context_key_; CollectiveCommunicationLibPtr collective_comm_lib_; diff --git a/mindspore/core/ir/device_sync.h b/mindspore/core/ir/device_sync.h index 4f54f03e70e..f80197c58cb 100644 --- a/mindspore/core/ir/device_sync.h +++ b/mindspore/core/ir/device_sync.h @@ -43,6 +43,11 @@ class DeviceSync { const std::string &format) const { return true; } + virtual bool AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr, + const std::string &format) const { + return true; + } + virtual void *GetMutablePtr() const = 0; virtual void ClearDeviceMemory() = 0; diff --git a/scripts/format_source_code.sh b/scripts/format_source_code.sh index 8fb9759b130..d9d273c5131 100755 --- a/scripts/format_source_code.sh +++ b/scripts/format_source_code.sh @@ -95,7 +95,7 @@ fi while read line; do if [ -f "${line}" ]; then - ${CLANG_FORMAT} -i "${line}" + "${CLANG_FORMAT}" -i "${line}" fi done < "${FMT_FILE_LIST}"