forked from mindspore-Ecosystem/mindspore
MindRT Ascend Devcie Context
This commit is contained in:
parent
1509d3f848
commit
d6b2a34a69
|
@ -2612,14 +2612,18 @@ std::set<FuncGraphPtr> AnfRuntimeAlgorithm::GetFuncGraphbyCallNode(const AnfNode
|
|||
|
||||
if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitch)) {
|
||||
// First input node of call is switch node.
|
||||
const auto &switch_inputs = call_input0->cast<CNodePtr>()->inputs();
|
||||
const auto &input_cnode = call_input0->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
const auto &switch_inputs = input_cnode->inputs();
|
||||
for (size_t i = kSwitchTrueBranchPos; i < switch_inputs.size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(switch_inputs[i]);
|
||||
(void)func_graphs.emplace(GetFuncGraphFromPartial(switch_inputs[i], call_depth));
|
||||
}
|
||||
} else if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitchLayer)) {
|
||||
// First input node of call is switch layer node.
|
||||
const auto &tuple_node = call_input0->cast<CNodePtr>()->input(kSwitchLayerBranchPos);
|
||||
const auto &input_cnode = call_input0->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
const auto &tuple_node = input_cnode->input(kSwitchLayerBranchPos);
|
||||
if (!AnfAlgo::CheckPrimitiveType(tuple_node, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input tuple node:" << tuple_node->DebugString()
|
||||
<< " for switch layer node:" << cnode->DebugString();
|
||||
|
|
|
@ -236,7 +236,7 @@ bool TensorNeedSync(const std::shared_ptr<KernelGraph> &kernel_graph, const AnfN
|
|||
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
|
||||
if (tensor_address != device_address) {
|
||||
if (!kernel_graph->is_dynamic_shape() && EnableDeviceCopy() && NeedMemcpyInDevice(tensor_address, device_address)) {
|
||||
auto status = device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(parameter, 0),
|
||||
auto status = device_address->AsyncDeviceToDevice(trans::GetRuntimePaddingShape(parameter, 0),
|
||||
tensor_address->GetSize(), tensor_address->type_id(),
|
||||
tensor_address->GetPtr(), tensor_address->format());
|
||||
if (!status) {
|
||||
|
@ -1830,7 +1830,7 @@ void AscendSession::UpdateOutputTensors(const VectorRef *outputs,
|
|||
if (EnableDeviceCopy() && tensor->NeedSyncDeviceToHostImmediately()) {
|
||||
auto dst_device_address = AssignExtraMemForGraphOutput(tensor, node, output_index);
|
||||
MS_EXCEPTION_IF_NULL(dst_device_address);
|
||||
if (!dst_device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(node, output_index),
|
||||
if (!dst_device_address->AsyncDeviceToDevice(trans::GetRuntimePaddingShape(node, output_index),
|
||||
address->GetSize(), address->type_id(), address->GetPtr(),
|
||||
address->format())) {
|
||||
MS_LOG(EXCEPTION) << "SyncDeviceToDevice failed!";
|
||||
|
|
|
@ -1600,6 +1600,7 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> ¶
|
|||
// for example "def f(x,y,z) {return x + y}", parameter z in unused
|
||||
auto new_parameter = CreateNewParameter(parameter, graph);
|
||||
graph_inputs->push_back(new_parameter);
|
||||
graph->FrontBackendlMapAdd(parameter, new_parameter);
|
||||
MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -1461,9 +1461,9 @@ void ClearResAtexit() {
|
|||
mindspore::RDR::ResetRecorder();
|
||||
#endif
|
||||
session::ExecutorManager::Instance().Clear();
|
||||
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
|
||||
runtime::GraphScheduler::GetInstance().Clear();
|
||||
device::DeviceContextManager::GetInstance().ClearDeviceContexts();
|
||||
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
|
||||
ad::g_k_prims.clear();
|
||||
ad::ClearKPynativeCellStaticRes();
|
||||
ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
|
||||
|
|
|
@ -180,6 +180,8 @@ void AscendDeviceAddress::BindDevice() const {
|
|||
if (!ascend_device_context->BindDeviceToCurrentThread()) {
|
||||
MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed.";
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "device name is null.";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -432,9 +434,32 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
|
|||
return sync_ok;
|
||||
}
|
||||
|
||||
bool AscendDeviceAddress::SyncDeviceToDevice(const ShapeVector &, size_t size, TypeId type, const void *src_ptr,
|
||||
bool AscendDeviceAddress::SyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
|
||||
const std::string &format) const {
|
||||
MS_LOG(INFO) << "SyncDeviceToDevice, dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
|
||||
if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) {
|
||||
return true;
|
||||
}
|
||||
BindDevice();
|
||||
if (format_ != format || type_id_ != type) {
|
||||
MS_LOG(ERROR) << "format or type is different, src(format:" << format << ", type_id:" << TypeIdLabel(type)
|
||||
<< "), dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_);
|
||||
return false;
|
||||
}
|
||||
if (size_ < size) {
|
||||
MS_LOG(ERROR) << "src size is greater than det size, src size is: " << size << ", dst size is: " << size_;
|
||||
return false;
|
||||
}
|
||||
auto ret_rt_memcpy = rtMemcpy(ptr_, size, src_ptr, size, RT_MEMCPY_DEVICE_TO_DEVICE);
|
||||
if (ret_rt_memcpy != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "SyncDeviceToDevice failed, rtMemcpy mem size [" << size << "], ret [" << ret_rt_memcpy << "]";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendDeviceAddress::AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
|
||||
const std::string &format) const {
|
||||
MS_LOG(INFO) << "AsyncDeviceToDevice, dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
|
||||
<< ", size:" << size_ << "), src(format:" << format << ", type_id:" << TypeIdLabel(type)
|
||||
<< ", size:" << size << ")";
|
||||
if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) {
|
||||
|
|
|
@ -50,6 +50,8 @@ class AscendDeviceAddress : public DeviceAddress {
|
|||
bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
|
||||
bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
|
||||
const std::string &format = "DefaultFormat") const override;
|
||||
bool AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
|
||||
const std::string &format) const override;
|
||||
bool SyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
|
||||
const std::string &format) const override;
|
||||
void ClearDeviceMemory() override;
|
||||
|
|
|
@ -301,6 +301,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
|
|||
!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
|
||||
HcclCollectiveGroup::instance().FinalizeCollective();
|
||||
}
|
||||
initialized_ = false;
|
||||
MS_LOG(INFO) << "Ascend finalize end";
|
||||
}
|
||||
|
||||
|
@ -407,7 +408,10 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
|
|||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, kAscendDevice, device_id);
|
||||
auto ascend_device_address_ptr =
|
||||
std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, kAscendDevice, device_id);
|
||||
ascend_device_address_ptr->set_is_ptr_persisted(true);
|
||||
return ascend_device_address_ptr;
|
||||
}
|
||||
|
||||
DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
||||
|
@ -415,8 +419,10 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
|
|||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node_index, kAscendDevice,
|
||||
device_id);
|
||||
auto ascend_device_address_ptr = std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id,
|
||||
node_index, kAscendDevice, device_id);
|
||||
ascend_device_address_ptr->set_is_ptr_persisted(true);
|
||||
return ascend_device_address_ptr;
|
||||
}
|
||||
|
||||
bool AscendKernelRuntime::Load(const session::KernelGraph &graph, bool is_task_sink) {
|
||||
|
|
|
@ -73,7 +73,6 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
void *GetModelStream(uint32_t graph_id) const override;
|
||||
// add for MindRT
|
||||
void ReleaseDeviceRes() override;
|
||||
void SetCurrentContext();
|
||||
|
||||
protected:
|
||||
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
||||
|
@ -90,6 +89,7 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
static bool HcclInit();
|
||||
static bool NeedDestroyHccl();
|
||||
static bool DestroyHccl();
|
||||
void SetCurrentContext();
|
||||
|
||||
void ClearGraphModelMap();
|
||||
bool GraphWithEmptyTaskList(const session::KernelGraph &graph) const;
|
||||
|
|
|
@ -148,6 +148,7 @@ void AiCoreDynamicKernel::AllocateWorkspace() {
|
|||
workspace_addr_.clear();
|
||||
for (auto size : workspaces_size_) {
|
||||
auto device_address_ptr = std::make_shared<AscendDeviceAddress>(nullptr, size, kAscendDevice, device_id);
|
||||
device_address_ptr->set_is_ptr_persisted(true);
|
||||
auto device_ptr = runtime_instance->MallocMem(MemType::kDynamicMem, size, device_address_ptr);
|
||||
if (device_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "MallocMem from memory pool failed. Node info :" << cnode->fullname_with_scope();
|
||||
|
|
|
@ -98,6 +98,8 @@ class DeviceAddress : public mindspore::DeviceSync {
|
|||
TypeId type_id() const { return type_id_; }
|
||||
bool from_mem_pool() const { return from_mem_pool_; }
|
||||
void set_from_mem_pool(bool from_mem_pool) { from_mem_pool_ = from_mem_pool; }
|
||||
bool is_ptr_persisted() const { return is_ptr_persisted_; }
|
||||
void set_is_ptr_persisted(bool is_ptr_persisted) { is_ptr_persisted_ = is_ptr_persisted; }
|
||||
void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
|
||||
virtual void set_status(DeviceAddressStatus status) {}
|
||||
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
|
||||
|
@ -134,6 +136,10 @@ class DeviceAddress : public mindspore::DeviceSync {
|
|||
ShapeVector host_shape_{};
|
||||
// {node, out_index}
|
||||
std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
|
||||
// The device address of the node that owns the device address cannot be updated and replaced.
|
||||
// application scenario: set to true when the hardware execution mode requires that ptr cannot be changed during
|
||||
// execution.
|
||||
bool is_ptr_persisted_{false};
|
||||
|
||||
// The key of device context.
|
||||
std::string device_name_{""};
|
||||
|
|
|
@ -997,6 +997,7 @@ void KernelAdjust::AssignLoopCtrlTensorMem(const session::KernelGraph &kernel_gr
|
|||
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
device_address =
|
||||
std::make_shared<device::ascend::AscendDeviceAddress>(nullptr, size, format, type_id, kAscendDevice, device_id);
|
||||
device_address->set_is_ptr_persisted(true);
|
||||
|
||||
if (runtime_instance->MallocMem(kStaticMem, size, device_address) == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cannot alloc static memory for device loop control parameter " << name
|
||||
|
|
|
@ -159,6 +159,10 @@ bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_
|
|||
} else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||
// Other device tensor copy to CPU device tensor.
|
||||
return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
|
||||
} else if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) {
|
||||
return dst_device_tensor->SyncDeviceToDevice(ShapeVector(), src_device_tensor->GetSize(),
|
||||
src_device_tensor->type_id(), src_device_tensor->GetPtr(),
|
||||
src_device_tensor->format());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid device type, src device type: " << src_device_tensor->DeviceType()
|
||||
<< ", dst device type: " << dst_device_tensor->DeviceType();
|
||||
|
|
|
@ -239,7 +239,8 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vect
|
|||
auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
|
||||
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType()) &&
|
||||
!device_address->is_ptr_persisted()) {
|
||||
AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -348,6 +348,9 @@ GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const Device
|
|||
|
||||
auto graph_id = CompileGraphImpl(root_graph, device_context);
|
||||
|
||||
// dump all graphs.
|
||||
device_context->DumpAllGraphs(all_graphs);
|
||||
|
||||
// Cache the backend graph output nodes to front nodes with output index.
|
||||
auto output = func_graph->output();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include <set>
|
||||
#include "backend/optimizer/ascend/ascend_backend_optimization.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
|
||||
#include "backend/session/ascend_auto_monad.h"
|
||||
#include "utils/context/graph_kernel_flags.h"
|
||||
#include "runtime/device/ascend/kernel_select_ascend.h"
|
||||
#include "runtime/device/kernel_adjust.h"
|
||||
|
@ -34,6 +33,17 @@
|
|||
#include "debug/dump_proto.h"
|
||||
#include "debug/data_dump/e2e_dump.h"
|
||||
#endif
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
#include "debug/tensor_load.h"
|
||||
#include "debug/debugger/proto_exporter.h"
|
||||
#else
|
||||
#include "debug/debugger/proto_exporter_stub.h"
|
||||
#endif
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
#include "debug/rdr/running_data_recorder.h"
|
||||
#include "debug/rdr/recorder_manager.h"
|
||||
#include "debug/rdr/graph_recorder.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -69,11 +79,47 @@ void Dump(const KernelGraphPtr &graph, uint32_t rank_id) {
|
|||
}
|
||||
#endif
|
||||
|
||||
void AscendDeviceContext::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const {
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto &json_parser = DumpJsonParser::GetInstance();
|
||||
json_parser.Parse();
|
||||
if (!save_graphs && !json_parser.e2e_dump_enabled() && !json_parser.async_dump_enabled() &&
|
||||
!mindspore::RecorderManager::Instance().RdrEnable()) {
|
||||
return;
|
||||
}
|
||||
for (auto &graph : all_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::string name = "graph_build." + std::to_string(graph->graph_id());
|
||||
DumpGraphParams dump_params = {true, static_cast<int>(kWholeStack)};
|
||||
(void)mindspore::RDR::RecordAnfGraph(SUBMODULE_ID, name, graph, dump_params, ".ir;.pb");
|
||||
if (save_graphs) {
|
||||
std::string file_name = "graph_build_" + std::to_string(graph->graph_id()) + ".ir";
|
||||
DumpIR(file_name, graph, true, kWholeStack);
|
||||
DumpIRProto(graph, "vm_build_" + std::to_string(graph->graph_id()));
|
||||
DumpIR("trace_code_graph", graph, true, kWholeStack);
|
||||
}
|
||||
std::string final_graph = "trace_code_graph_" + std::to_string(graph->graph_id());
|
||||
if (json_parser.e2e_dump_enabled() || json_parser.async_dump_enabled()) {
|
||||
std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id_);
|
||||
std::string target_dir = root_dir + "/graphs";
|
||||
std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
|
||||
DumpIRProtoWithSrcInfo(graph, final_graph, target_dir, kDebugWholeStack);
|
||||
DumpIR("trace_code_graph", graph, true, kWholeStack, ir_file_path);
|
||||
DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(graph->graph_id()) + ".csv", root_dir,
|
||||
graph->execution_order());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void AscendDeviceContext::Initialize() {
|
||||
MS_LOG(INFO) << "Status record: Enter Initialize...";
|
||||
if (initialized_) {
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->SetCurrentContext();
|
||||
runtime_instance_->SetContext();
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -109,8 +155,7 @@ void AscendDeviceContext::Destroy() {
|
|||
}
|
||||
MS_LOG(INFO) << "Status record: Destroy start...";
|
||||
rank_id_ = 0;
|
||||
if (runtime_instance_ != nullptr) {
|
||||
runtime_instance_->ReleaseDeviceRes();
|
||||
if (runtime_instance_) {
|
||||
runtime_instance_ = nullptr;
|
||||
}
|
||||
initialized_ = false;
|
||||
|
@ -181,9 +226,44 @@ void AscendDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph)
|
|||
MS_LOG(INFO) << "PreprocessBeforeRunGraph success.";
|
||||
return;
|
||||
}
|
||||
AssignOutputNopNodeDeviceAddress(graph);
|
||||
MS_LOG(INFO) << "PreprocessBeforeRunGraph success.";
|
||||
}
|
||||
|
||||
void AscendDeviceContext::AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
||||
for (auto output : outputs) {
|
||||
if (!output->isa<CNode>() || !AnfAlgo::IsRealKernel(output)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!opt::IsNopNode(output)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(output);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(WARNING) << "The input number of nop node :" << output->fullname_with_scope() << " is " << input_num
|
||||
<< ", not equal 1";
|
||||
continue;
|
||||
}
|
||||
|
||||
auto real_input_index = AnfAlgo::GetRealInputIndex(output, 0);
|
||||
auto pre_node_out_device_address = AnfAlgo::GetPrevNodeOutputAddr(output, real_input_index);
|
||||
MS_EXCEPTION_IF_NULL(pre_node_out_device_address);
|
||||
auto ptr = pre_node_out_device_address->GetPtr();
|
||||
auto size = pre_node_out_device_address->GetSize();
|
||||
std::string output_format = AnfAlgo::GetOutputFormat(output, 0);
|
||||
auto output_type = AnfAlgo::GetOutputDeviceDataType(output, 0);
|
||||
auto device_address = CreateDeviceAddress(const_cast<void *>(ptr), size, output_format, output_type);
|
||||
AnfAlgo::SetOutputAddr(device_address, 0, output.get());
|
||||
|
||||
AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(false), output);
|
||||
MS_LOG(INFO) << "Assign device address to output nop node " << output->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
|
||||
void AscendDeviceContext::AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const {
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->ClearGlobalIdleMem();
|
||||
|
@ -225,7 +305,7 @@ void AscendDeviceContext::LoadModel(const NotNull<KernelGraphPtr> &root_graph) c
|
|||
bool AscendDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size) const {
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->SetCurrentContext();
|
||||
runtime_instance_->SetContext();
|
||||
auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
|
||||
if (!device_ptr) {
|
||||
return false;
|
||||
|
@ -249,7 +329,7 @@ void AscendDeviceContext::FreeMemory(DeviceAddress *const &address) const {
|
|||
bool AscendDeviceContext::AllocateContinuousMemory(const std::vector<DeviceAddressPtr> &addr_list, size_t total_size,
|
||||
const std::vector<size_t> &size_list) const {
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->SetCurrentContext();
|
||||
runtime_instance_->SetContext();
|
||||
return mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list);
|
||||
}
|
||||
|
||||
|
@ -296,7 +376,7 @@ bool AscendDeviceContext::LaunchGraph(const KernelGraphPtr &graph) const {
|
|||
MS_LOG(INFO) << "Status record: start launch graph. graph id: " << graph->graph_id();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->SetCurrentContext();
|
||||
runtime_instance_->SetContext();
|
||||
device::KernelAdjust::GetInstance().LoadDeviceLoopCtrlParameters(graph);
|
||||
auto ret = ExecuteGraph(graph);
|
||||
MS_LOG(INFO) << "Status record: end launch graph. graph id: " << graph->graph_id();
|
||||
|
@ -336,7 +416,7 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr
|
|||
}
|
||||
|
||||
bool AscendDeviceContext::BindDeviceToCurrentThread() const {
|
||||
runtime_instance_->SetCurrentContext();
|
||||
runtime_instance_->SetContext();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -126,6 +126,9 @@ class AscendDeviceContext : public DeviceContext {
|
|||
// set rt_context_ to this thread to control device
|
||||
bool BindDeviceToCurrentThread() const;
|
||||
|
||||
// dump all graphs.
|
||||
void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const override;
|
||||
|
||||
private:
|
||||
// Graph loader interface
|
||||
void AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const;
|
||||
|
@ -150,6 +153,7 @@ class AscendDeviceContext : public DeviceContext {
|
|||
mutable std::set<KernelGraphPtr> memo_;
|
||||
// Using node to get it's atomics
|
||||
mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_;
|
||||
void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const;
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -51,6 +51,8 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
|
|||
OptimizeGraphWithDeviceInfo(graph);
|
||||
OptimizeExecutionOrder(graph);
|
||||
PostOptimization(graph);
|
||||
// must clear memo_ which holds kernelgraph after using AscendGraphOptimization class.
|
||||
memo_.clear();
|
||||
MS_LOG(INFO) << "Status record: end optimize graph. graph id: " << graph->graph_id();
|
||||
}
|
||||
|
||||
|
@ -71,6 +73,9 @@ void AscendGraphOptimization::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
memo_.clear();
|
||||
HardWareOptimization(graph);
|
||||
// copy child graph ref output map to father graph ref output map
|
||||
memo_.clear();
|
||||
UpdateRefOutputMap(graph);
|
||||
AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(graph));
|
||||
}
|
||||
|
||||
|
@ -112,9 +117,6 @@ void AscendGraphOptimization::OptimizeExecutionOrder(const KernelGraphPtr &graph
|
|||
|
||||
void AscendGraphOptimization::PostOptimization(const KernelGraphPtr &graph) {
|
||||
MS_LOG(INFO) << "Status record: start post optimization. graph id: " << graph->graph_id();
|
||||
// copy child graph ref output map to father graph ref output map
|
||||
memo_.clear();
|
||||
UpdateRefOutputMap(graph);
|
||||
graph->SetInputNodes();
|
||||
graph->SetOptimizerFlag();
|
||||
MS_LOG(INFO) << "Status record: end post optimization. graph id: " << graph->graph_id();
|
||||
|
|
|
@ -152,6 +152,10 @@ class DeviceContext {
|
|||
// Return collective communication object for caller to access
|
||||
CollectiveCommunicationLibPtr collective_comm_lib() const { return collective_comm_lib_; }
|
||||
|
||||
// TODO(jiaorui): will be delete
|
||||
// Dump all graphs.
|
||||
virtual void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const {}
|
||||
|
||||
protected:
|
||||
DeviceContextKey device_context_key_;
|
||||
CollectiveCommunicationLibPtr collective_comm_lib_;
|
||||
|
|
|
@ -43,6 +43,11 @@ class DeviceSync {
|
|||
const std::string &format) const {
|
||||
return true;
|
||||
}
|
||||
virtual bool AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
|
||||
const std::string &format) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual void *GetMutablePtr() const = 0;
|
||||
virtual void ClearDeviceMemory() = 0;
|
||||
|
||||
|
|
|
@ -95,7 +95,7 @@ fi
|
|||
|
||||
while read line; do
|
||||
if [ -f "${line}" ]; then
|
||||
${CLANG_FORMAT} -i "${line}"
|
||||
"${CLANG_FORMAT}" -i "${line}"
|
||||
fi
|
||||
done < "${FMT_FILE_LIST}"
|
||||
|
||||
|
|
Loading…
Reference in New Issue