forked from mindspore-Ecosystem/mindspore
!26007 Ascend Device Context for MindRT
Merge pull request !26007 from hwjiaorui/ascend_context
This commit is contained in:
commit
1bfedfcb7f
|
@ -2563,14 +2563,18 @@ std::set<FuncGraphPtr> AnfRuntimeAlgorithm::GetFuncGraphbyCallNode(const AnfNode
|
||||||
|
|
||||||
if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitch)) {
|
if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitch)) {
|
||||||
// First input node of call is switch node.
|
// First input node of call is switch node.
|
||||||
const auto &switch_inputs = call_input0->cast<CNodePtr>()->inputs();
|
const auto &input_cnode = call_input0->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||||
|
const auto &switch_inputs = input_cnode->inputs();
|
||||||
for (size_t i = kSwitchTrueBranchPos; i < switch_inputs.size(); ++i) {
|
for (size_t i = kSwitchTrueBranchPos; i < switch_inputs.size(); ++i) {
|
||||||
MS_EXCEPTION_IF_NULL(switch_inputs[i]);
|
MS_EXCEPTION_IF_NULL(switch_inputs[i]);
|
||||||
(void)func_graphs.emplace(GetFuncGraphFromPartial(switch_inputs[i], call_depth));
|
(void)func_graphs.emplace(GetFuncGraphFromPartial(switch_inputs[i], call_depth));
|
||||||
}
|
}
|
||||||
} else if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitchLayer)) {
|
} else if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitchLayer)) {
|
||||||
// First input node of call is switch layer node.
|
// First input node of call is switch layer node.
|
||||||
const auto &tuple_node = call_input0->cast<CNodePtr>()->input(kSwitchLayerBranchPos);
|
const auto &input_cnode = call_input0->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||||
|
const auto &tuple_node = input_cnode->input(kSwitchLayerBranchPos);
|
||||||
if (!AnfAlgo::CheckPrimitiveType(tuple_node, prim::kPrimMakeTuple)) {
|
if (!AnfAlgo::CheckPrimitiveType(tuple_node, prim::kPrimMakeTuple)) {
|
||||||
MS_LOG(EXCEPTION) << "Invalid input tuple node:" << tuple_node->DebugString()
|
MS_LOG(EXCEPTION) << "Invalid input tuple node:" << tuple_node->DebugString()
|
||||||
<< " for switch layer node:" << cnode->DebugString();
|
<< " for switch layer node:" << cnode->DebugString();
|
||||||
|
|
|
@ -236,9 +236,9 @@ bool TensorNeedSync(const std::shared_ptr<KernelGraph> &kernel_graph, const AnfN
|
||||||
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
|
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
|
||||||
if (tensor_address != device_address) {
|
if (tensor_address != device_address) {
|
||||||
if (!kernel_graph->is_dynamic_shape() && EnableDeviceCopy() && NeedMemcpyInDevice(tensor_address, device_address)) {
|
if (!kernel_graph->is_dynamic_shape() && EnableDeviceCopy() && NeedMemcpyInDevice(tensor_address, device_address)) {
|
||||||
auto status = device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(parameter, 0),
|
auto status = device_address->AsyncDeviceToDevice(trans::GetRuntimePaddingShape(parameter, 0),
|
||||||
tensor_address->GetSize(), tensor_address->type_id(),
|
tensor_address->GetSize(), tensor_address->type_id(),
|
||||||
tensor_address->GetPtr(), tensor_address->format());
|
tensor_address->GetPtr(), tensor_address->format());
|
||||||
if (!status) {
|
if (!status) {
|
||||||
MS_LOG(EXCEPTION) << "SyncDeviceToDevice failed.";
|
MS_LOG(EXCEPTION) << "SyncDeviceToDevice failed.";
|
||||||
}
|
}
|
||||||
|
@ -1830,9 +1830,9 @@ void AscendSession::UpdateOutputTensors(const VectorRef *outputs,
|
||||||
if (EnableDeviceCopy() && tensor->NeedSyncDeviceToHostImmediately()) {
|
if (EnableDeviceCopy() && tensor->NeedSyncDeviceToHostImmediately()) {
|
||||||
auto dst_device_address = AssignExtraMemForGraphOutput(tensor, node, output_index);
|
auto dst_device_address = AssignExtraMemForGraphOutput(tensor, node, output_index);
|
||||||
MS_EXCEPTION_IF_NULL(dst_device_address);
|
MS_EXCEPTION_IF_NULL(dst_device_address);
|
||||||
if (!dst_device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(node, output_index),
|
if (!dst_device_address->AsyncDeviceToDevice(trans::GetRuntimePaddingShape(node, output_index),
|
||||||
address->GetSize(), address->type_id(), address->GetPtr(),
|
address->GetSize(), address->type_id(), address->GetPtr(),
|
||||||
address->format())) {
|
address->format())) {
|
||||||
MS_LOG(EXCEPTION) << "SyncDeviceToDevice failed!";
|
MS_LOG(EXCEPTION) << "SyncDeviceToDevice failed!";
|
||||||
}
|
}
|
||||||
tensor->set_sync_status(kNoNeedSync);
|
tensor->set_sync_status(kNoNeedSync);
|
||||||
|
|
|
@ -1604,6 +1604,7 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> ¶
|
||||||
// for example "def f(x,y,z) {return x + y}", parameter z in unused
|
// for example "def f(x,y,z) {return x + y}", parameter z in unused
|
||||||
auto new_parameter = CreateNewParameter(parameter, graph);
|
auto new_parameter = CreateNewParameter(parameter, graph);
|
||||||
graph_inputs->push_back(new_parameter);
|
graph_inputs->push_back(new_parameter);
|
||||||
|
graph->FrontBackendlMapAdd(parameter, new_parameter);
|
||||||
MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
|
MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1461,9 +1461,9 @@ void ClearResAtexit() {
|
||||||
mindspore::RDR::ResetRecorder();
|
mindspore::RDR::ResetRecorder();
|
||||||
#endif
|
#endif
|
||||||
session::ExecutorManager::Instance().Clear();
|
session::ExecutorManager::Instance().Clear();
|
||||||
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
|
|
||||||
runtime::GraphScheduler::GetInstance().Clear();
|
runtime::GraphScheduler::GetInstance().Clear();
|
||||||
device::DeviceContextManager::GetInstance().ClearDeviceContexts();
|
device::DeviceContextManager::GetInstance().ClearDeviceContexts();
|
||||||
|
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
|
||||||
ad::g_k_prims.clear();
|
ad::g_k_prims.clear();
|
||||||
ad::ClearKPynativeCellStaticRes();
|
ad::ClearKPynativeCellStaticRes();
|
||||||
ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
|
ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
|
||||||
|
|
|
@ -180,6 +180,8 @@ void AscendDeviceAddress::BindDevice() const {
|
||||||
if (!ascend_device_context->BindDeviceToCurrentThread()) {
|
if (!ascend_device_context->BindDeviceToCurrentThread()) {
|
||||||
MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed.";
|
MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed.";
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(WARNING) << "device name is null.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -432,9 +434,32 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
|
||||||
return sync_ok;
|
return sync_ok;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AscendDeviceAddress::SyncDeviceToDevice(const ShapeVector &, size_t size, TypeId type, const void *src_ptr,
|
bool AscendDeviceAddress::SyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
|
||||||
const std::string &format) const {
|
const std::string &format) const {
|
||||||
MS_LOG(INFO) << "SyncDeviceToDevice, dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
|
if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
BindDevice();
|
||||||
|
if (format_ != format || type_id_ != type) {
|
||||||
|
MS_LOG(ERROR) << "format or type is different, src(format:" << format << ", type_id:" << TypeIdLabel(type)
|
||||||
|
<< "), dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (size_ < size) {
|
||||||
|
MS_LOG(ERROR) << "src size is greater than det size, src size is: " << size << ", dst size is: " << size_;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto ret_rt_memcpy = rtMemcpy(ptr_, size, src_ptr, size, RT_MEMCPY_DEVICE_TO_DEVICE);
|
||||||
|
if (ret_rt_memcpy != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(ERROR) << "SyncDeviceToDevice failed, rtMemcpy mem size [" << size << "], ret [" << ret_rt_memcpy << "]";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AscendDeviceAddress::AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
|
||||||
|
const std::string &format) const {
|
||||||
|
MS_LOG(INFO) << "AsyncDeviceToDevice, dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
|
||||||
<< ", size:" << size_ << "), src(format:" << format << ", type_id:" << TypeIdLabel(type)
|
<< ", size:" << size_ << "), src(format:" << format << ", type_id:" << TypeIdLabel(type)
|
||||||
<< ", size:" << size << ")";
|
<< ", size:" << size << ")";
|
||||||
if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) {
|
if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) {
|
||||||
|
|
|
@ -50,6 +50,8 @@ class AscendDeviceAddress : public DeviceAddress {
|
||||||
bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
|
bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
|
||||||
bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
|
bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
|
||||||
const std::string &format = "DefaultFormat") const override;
|
const std::string &format = "DefaultFormat") const override;
|
||||||
|
bool AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
|
||||||
|
const std::string &format) const override;
|
||||||
bool SyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
|
bool SyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
|
||||||
const std::string &format) const override;
|
const std::string &format) const override;
|
||||||
void ClearDeviceMemory() override;
|
void ClearDeviceMemory() override;
|
||||||
|
|
|
@ -301,6 +301,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
|
||||||
!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
|
!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
|
||||||
HcclCollectiveGroup::instance().FinalizeCollective();
|
HcclCollectiveGroup::instance().FinalizeCollective();
|
||||||
}
|
}
|
||||||
|
initialized_ = false;
|
||||||
MS_LOG(INFO) << "Ascend finalize end";
|
MS_LOG(INFO) << "Ascend finalize end";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -407,7 +408,10 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
|
||||||
auto ms_context = MsContext::GetInstance();
|
auto ms_context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(ms_context);
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||||
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, kAscendDevice, device_id);
|
auto ascend_device_address_ptr =
|
||||||
|
std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, kAscendDevice, device_id);
|
||||||
|
ascend_device_address_ptr->set_is_ptr_persisted(true);
|
||||||
|
return ascend_device_address_ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
||||||
|
@ -415,8 +419,10 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
|
||||||
auto ms_context = MsContext::GetInstance();
|
auto ms_context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(ms_context);
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||||
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node_index, kAscendDevice,
|
auto ascend_device_address_ptr = std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id,
|
||||||
device_id);
|
node_index, kAscendDevice, device_id);
|
||||||
|
ascend_device_address_ptr->set_is_ptr_persisted(true);
|
||||||
|
return ascend_device_address_ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AscendKernelRuntime::Load(const session::KernelGraph &graph, bool is_task_sink) {
|
bool AscendKernelRuntime::Load(const session::KernelGraph &graph, bool is_task_sink) {
|
||||||
|
|
|
@ -73,7 +73,6 @@ class AscendKernelRuntime : public KernelRuntime {
|
||||||
void *GetModelStream(uint32_t graph_id) const override;
|
void *GetModelStream(uint32_t graph_id) const override;
|
||||||
// add for MindRT
|
// add for MindRT
|
||||||
void ReleaseDeviceRes() override;
|
void ReleaseDeviceRes() override;
|
||||||
void SetCurrentContext();
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
||||||
|
@ -90,6 +89,7 @@ class AscendKernelRuntime : public KernelRuntime {
|
||||||
static bool HcclInit();
|
static bool HcclInit();
|
||||||
static bool NeedDestroyHccl();
|
static bool NeedDestroyHccl();
|
||||||
static bool DestroyHccl();
|
static bool DestroyHccl();
|
||||||
|
void SetCurrentContext();
|
||||||
|
|
||||||
void ClearGraphModelMap();
|
void ClearGraphModelMap();
|
||||||
bool GraphWithEmptyTaskList(const session::KernelGraph &graph) const;
|
bool GraphWithEmptyTaskList(const session::KernelGraph &graph) const;
|
||||||
|
|
|
@ -150,6 +150,7 @@ void AiCoreDynamicKernel::AllocateWorkspace() {
|
||||||
workspace_addr_.clear();
|
workspace_addr_.clear();
|
||||||
for (auto size : workspaces_size_) {
|
for (auto size : workspaces_size_) {
|
||||||
auto device_address_ptr = std::make_shared<AscendDeviceAddress>(nullptr, size, kAscendDevice, device_id);
|
auto device_address_ptr = std::make_shared<AscendDeviceAddress>(nullptr, size, kAscendDevice, device_id);
|
||||||
|
device_address_ptr->set_is_ptr_persisted(true);
|
||||||
auto device_ptr = runtime_instance->MallocMem(MemType::kDynamicMem, size, device_address_ptr);
|
auto device_ptr = runtime_instance->MallocMem(MemType::kDynamicMem, size, device_address_ptr);
|
||||||
if (device_ptr == nullptr) {
|
if (device_ptr == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "MallocMem from memory pool failed. Node info :" << cnode->fullname_with_scope();
|
MS_LOG(EXCEPTION) << "MallocMem from memory pool failed. Node info :" << cnode->fullname_with_scope();
|
||||||
|
|
|
@ -98,6 +98,8 @@ class DeviceAddress : public mindspore::DeviceSync {
|
||||||
TypeId type_id() const { return type_id_; }
|
TypeId type_id() const { return type_id_; }
|
||||||
bool from_mem_pool() const { return from_mem_pool_; }
|
bool from_mem_pool() const { return from_mem_pool_; }
|
||||||
void set_from_mem_pool(bool from_mem_pool) { from_mem_pool_ = from_mem_pool; }
|
void set_from_mem_pool(bool from_mem_pool) { from_mem_pool_ = from_mem_pool; }
|
||||||
|
bool is_ptr_persisted() const { return is_ptr_persisted_; }
|
||||||
|
void set_is_ptr_persisted(bool is_ptr_persisted) { is_ptr_persisted_ = is_ptr_persisted; }
|
||||||
void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
|
void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
|
||||||
virtual void set_status(DeviceAddressStatus status) {}
|
virtual void set_status(DeviceAddressStatus status) {}
|
||||||
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
|
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
|
||||||
|
@ -134,6 +136,10 @@ class DeviceAddress : public mindspore::DeviceSync {
|
||||||
ShapeVector host_shape_{};
|
ShapeVector host_shape_{};
|
||||||
// {node, out_index}
|
// {node, out_index}
|
||||||
std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
|
std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
|
||||||
|
// The device address of the node that owns the device address cannot be updated and replaced.
|
||||||
|
// application scenario: set to true when the hardware execution mode requires that ptr cannot be changed during
|
||||||
|
// execution.
|
||||||
|
bool is_ptr_persisted_{false};
|
||||||
|
|
||||||
// The key of device context.
|
// The key of device context.
|
||||||
std::string device_name_{""};
|
std::string device_name_{""};
|
||||||
|
|
|
@ -997,6 +997,7 @@ void KernelAdjust::AssignLoopCtrlTensorMem(const session::KernelGraph &kernel_gr
|
||||||
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||||
device_address =
|
device_address =
|
||||||
std::make_shared<device::ascend::AscendDeviceAddress>(nullptr, size, format, type_id, kAscendDevice, device_id);
|
std::make_shared<device::ascend::AscendDeviceAddress>(nullptr, size, format, type_id, kAscendDevice, device_id);
|
||||||
|
device_address->set_is_ptr_persisted(true);
|
||||||
|
|
||||||
if (runtime_instance->MallocMem(kStaticMem, size, device_address) == nullptr) {
|
if (runtime_instance->MallocMem(kStaticMem, size, device_address) == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Cannot alloc static memory for device loop control parameter " << name
|
MS_LOG(EXCEPTION) << "Cannot alloc static memory for device loop control parameter " << name
|
||||||
|
|
|
@ -149,6 +149,10 @@ bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_
|
||||||
} else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
} else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||||
// Other device tensor copy to CPU device tensor.
|
// Other device tensor copy to CPU device tensor.
|
||||||
return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
|
return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
|
||||||
|
} else if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) {
|
||||||
|
return dst_device_tensor->SyncDeviceToDevice(ShapeVector(), src_device_tensor->GetSize(),
|
||||||
|
src_device_tensor->type_id(), src_device_tensor->GetPtr(),
|
||||||
|
src_device_tensor->format());
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "Invalid device type, src device type: " << src_device_tensor->DeviceType()
|
MS_LOG(ERROR) << "Invalid device type, src device type: " << src_device_tensor->DeviceType()
|
||||||
<< ", dst device type: " << dst_device_tensor->DeviceType();
|
<< ", dst device type: " << dst_device_tensor->DeviceType();
|
||||||
|
|
|
@ -239,7 +239,8 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vect
|
||||||
auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
|
auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
|
||||||
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
|
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
|
||||||
MS_EXCEPTION_IF_NULL(device_address);
|
MS_EXCEPTION_IF_NULL(device_address);
|
||||||
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
|
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType()) &&
|
||||||
|
!device_address->is_ptr_persisted()) {
|
||||||
AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
|
AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -348,6 +348,9 @@ GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const Device
|
||||||
|
|
||||||
auto graph_id = CompileGraphImpl(root_graph, device_context);
|
auto graph_id = CompileGraphImpl(root_graph, device_context);
|
||||||
|
|
||||||
|
// dump all graphs.
|
||||||
|
device_context->DumpAllGraphs(all_graphs);
|
||||||
|
|
||||||
// Cache the backend graph output nodes to front nodes with output index.
|
// Cache the backend graph output nodes to front nodes with output index.
|
||||||
auto output = func_graph->output();
|
auto output = func_graph->output();
|
||||||
MS_EXCEPTION_IF_NULL(output);
|
MS_EXCEPTION_IF_NULL(output);
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include "backend/optimizer/ascend/ascend_backend_optimization.h"
|
#include "backend/optimizer/ascend/ascend_backend_optimization.h"
|
||||||
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
|
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
|
||||||
#include "backend/session/ascend_auto_monad.h"
|
|
||||||
#include "utils/context/graph_kernel_flags.h"
|
#include "utils/context/graph_kernel_flags.h"
|
||||||
#include "runtime/device/ascend/kernel_select_ascend.h"
|
#include "runtime/device/ascend/kernel_select_ascend.h"
|
||||||
#include "runtime/device/kernel_adjust.h"
|
#include "runtime/device/kernel_adjust.h"
|
||||||
|
@ -34,6 +33,17 @@
|
||||||
#include "debug/dump_proto.h"
|
#include "debug/dump_proto.h"
|
||||||
#include "debug/data_dump/e2e_dump.h"
|
#include "debug/data_dump/e2e_dump.h"
|
||||||
#endif
|
#endif
|
||||||
|
#ifdef ENABLE_DEBUGGER
|
||||||
|
#include "debug/tensor_load.h"
|
||||||
|
#include "debug/debugger/proto_exporter.h"
|
||||||
|
#else
|
||||||
|
#include "debug/debugger/proto_exporter_stub.h"
|
||||||
|
#endif
|
||||||
|
#ifdef ENABLE_DUMP_IR
|
||||||
|
#include "debug/rdr/running_data_recorder.h"
|
||||||
|
#include "debug/rdr/recorder_manager.h"
|
||||||
|
#include "debug/rdr/graph_recorder.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
|
@ -69,11 +79,47 @@ void Dump(const KernelGraphPtr &graph, uint32_t rank_id) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
void AscendDeviceContext::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const {
|
||||||
|
#ifdef ENABLE_DUMP_IR
|
||||||
|
auto context_ptr = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||||
|
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||||
|
auto &json_parser = DumpJsonParser::GetInstance();
|
||||||
|
json_parser.Parse();
|
||||||
|
if (!save_graphs && !json_parser.e2e_dump_enabled() && !json_parser.async_dump_enabled() &&
|
||||||
|
!mindspore::RecorderManager::Instance().RdrEnable()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (auto &graph : all_graphs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
std::string name = "graph_build." + std::to_string(graph->graph_id());
|
||||||
|
DumpGraphParams dump_params = {true, static_cast<int>(kWholeStack)};
|
||||||
|
(void)mindspore::RDR::RecordAnfGraph(SUBMODULE_ID, name, graph, dump_params, ".ir;.pb");
|
||||||
|
if (save_graphs) {
|
||||||
|
std::string file_name = "graph_build_" + std::to_string(graph->graph_id()) + ".ir";
|
||||||
|
DumpIR(file_name, graph, true, kWholeStack);
|
||||||
|
DumpIRProto(graph, "vm_build_" + std::to_string(graph->graph_id()));
|
||||||
|
DumpIR("trace_code_graph", graph, true, kWholeStack);
|
||||||
|
}
|
||||||
|
std::string final_graph = "trace_code_graph_" + std::to_string(graph->graph_id());
|
||||||
|
if (json_parser.e2e_dump_enabled() || json_parser.async_dump_enabled()) {
|
||||||
|
std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id_);
|
||||||
|
std::string target_dir = root_dir + "/graphs";
|
||||||
|
std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
|
||||||
|
DumpIRProtoWithSrcInfo(graph, final_graph, target_dir, kDebugWholeStack);
|
||||||
|
DumpIR("trace_code_graph", graph, true, kWholeStack, ir_file_path);
|
||||||
|
DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(graph->graph_id()) + ".csv", root_dir,
|
||||||
|
graph->execution_order());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
void AscendDeviceContext::Initialize() {
|
void AscendDeviceContext::Initialize() {
|
||||||
MS_LOG(INFO) << "Status record: Enter Initialize...";
|
MS_LOG(INFO) << "Status record: Enter Initialize...";
|
||||||
if (initialized_) {
|
if (initialized_) {
|
||||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||||
runtime_instance_->SetCurrentContext();
|
runtime_instance_->SetContext();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,8 +155,7 @@ void AscendDeviceContext::Destroy() {
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Status record: Destroy start...";
|
MS_LOG(INFO) << "Status record: Destroy start...";
|
||||||
rank_id_ = 0;
|
rank_id_ = 0;
|
||||||
if (runtime_instance_ != nullptr) {
|
if (runtime_instance_) {
|
||||||
runtime_instance_->ReleaseDeviceRes();
|
|
||||||
runtime_instance_ = nullptr;
|
runtime_instance_ = nullptr;
|
||||||
}
|
}
|
||||||
initialized_ = false;
|
initialized_ = false;
|
||||||
|
@ -181,9 +226,44 @@ void AscendDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph)
|
||||||
MS_LOG(INFO) << "PreprocessBeforeRunGraph success.";
|
MS_LOG(INFO) << "PreprocessBeforeRunGraph success.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
AssignOutputNopNodeDeviceAddress(graph);
|
||||||
MS_LOG(INFO) << "PreprocessBeforeRunGraph success.";
|
MS_LOG(INFO) << "PreprocessBeforeRunGraph success.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AscendDeviceContext::AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
auto outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
||||||
|
for (auto output : outputs) {
|
||||||
|
if (!output->isa<CNode>() || !AnfAlgo::IsRealKernel(output)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!opt::IsNopNode(output)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t input_num = AnfAlgo::GetInputTensorNum(output);
|
||||||
|
if (input_num != 1) {
|
||||||
|
MS_LOG(WARNING) << "The input number of nop node :" << output->fullname_with_scope() << " is " << input_num
|
||||||
|
<< ", not equal 1";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto real_input_index = AnfAlgo::GetRealInputIndex(output, 0);
|
||||||
|
auto pre_node_out_device_address = AnfAlgo::GetPrevNodeOutputAddr(output, real_input_index);
|
||||||
|
MS_EXCEPTION_IF_NULL(pre_node_out_device_address);
|
||||||
|
auto ptr = pre_node_out_device_address->GetPtr();
|
||||||
|
auto size = pre_node_out_device_address->GetSize();
|
||||||
|
std::string output_format = AnfAlgo::GetOutputFormat(output, 0);
|
||||||
|
auto output_type = AnfAlgo::GetOutputDeviceDataType(output, 0);
|
||||||
|
auto device_address = CreateDeviceAddress(const_cast<void *>(ptr), size, output_format, output_type);
|
||||||
|
AnfAlgo::SetOutputAddr(device_address, 0, output.get());
|
||||||
|
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(false), output);
|
||||||
|
MS_LOG(INFO) << "Assign device address to output nop node " << output->fullname_with_scope();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void AscendDeviceContext::AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const {
|
void AscendDeviceContext::AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const {
|
||||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||||
runtime_instance_->ClearGlobalIdleMem();
|
runtime_instance_->ClearGlobalIdleMem();
|
||||||
|
@ -225,7 +305,7 @@ void AscendDeviceContext::LoadModel(const NotNull<KernelGraphPtr> &root_graph) c
|
||||||
bool AscendDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size) const {
|
bool AscendDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size) const {
|
||||||
MS_EXCEPTION_IF_NULL(address);
|
MS_EXCEPTION_IF_NULL(address);
|
||||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||||
runtime_instance_->SetCurrentContext();
|
runtime_instance_->SetContext();
|
||||||
auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
|
auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
|
||||||
if (!device_ptr) {
|
if (!device_ptr) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -249,7 +329,7 @@ void AscendDeviceContext::FreeMemory(DeviceAddress *const &address) const {
|
||||||
bool AscendDeviceContext::AllocateContinuousMemory(const std::vector<DeviceAddressPtr> &addr_list, size_t total_size,
|
bool AscendDeviceContext::AllocateContinuousMemory(const std::vector<DeviceAddressPtr> &addr_list, size_t total_size,
|
||||||
const std::vector<size_t> &size_list) const {
|
const std::vector<size_t> &size_list) const {
|
||||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||||
runtime_instance_->SetCurrentContext();
|
runtime_instance_->SetContext();
|
||||||
return mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list);
|
return mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -296,7 +376,7 @@ bool AscendDeviceContext::LaunchGraph(const KernelGraphPtr &graph) const {
|
||||||
MS_LOG(INFO) << "Status record: start launch graph. graph id: " << graph->graph_id();
|
MS_LOG(INFO) << "Status record: start launch graph. graph id: " << graph->graph_id();
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||||
runtime_instance_->SetCurrentContext();
|
runtime_instance_->SetContext();
|
||||||
device::KernelAdjust::GetInstance().LoadDeviceLoopCtrlParameters(graph);
|
device::KernelAdjust::GetInstance().LoadDeviceLoopCtrlParameters(graph);
|
||||||
auto ret = ExecuteGraph(graph);
|
auto ret = ExecuteGraph(graph);
|
||||||
MS_LOG(INFO) << "Status record: end launch graph. graph id: " << graph->graph_id();
|
MS_LOG(INFO) << "Status record: end launch graph. graph id: " << graph->graph_id();
|
||||||
|
@ -336,7 +416,7 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AscendDeviceContext::BindDeviceToCurrentThread() const {
|
bool AscendDeviceContext::BindDeviceToCurrentThread() const {
|
||||||
runtime_instance_->SetCurrentContext();
|
runtime_instance_->SetContext();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -126,6 +126,9 @@ class AscendDeviceContext : public DeviceContext {
|
||||||
// set rt_context_ to this thread to control device
|
// set rt_context_ to this thread to control device
|
||||||
bool BindDeviceToCurrentThread() const;
|
bool BindDeviceToCurrentThread() const;
|
||||||
|
|
||||||
|
// dump all graphs.
|
||||||
|
void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Graph loader interface
|
// Graph loader interface
|
||||||
void AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const;
|
void AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const;
|
||||||
|
@ -150,6 +153,7 @@ class AscendDeviceContext : public DeviceContext {
|
||||||
mutable std::set<KernelGraphPtr> memo_;
|
mutable std::set<KernelGraphPtr> memo_;
|
||||||
// Using node to get it's atomics
|
// Using node to get it's atomics
|
||||||
mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_;
|
mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_;
|
||||||
|
void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const;
|
||||||
};
|
};
|
||||||
} // namespace ascend
|
} // namespace ascend
|
||||||
} // namespace device
|
} // namespace device
|
||||||
|
|
|
@ -51,6 +51,8 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) {
|
||||||
OptimizeGraphWithDeviceInfo(graph);
|
OptimizeGraphWithDeviceInfo(graph);
|
||||||
OptimizeExecutionOrder(graph);
|
OptimizeExecutionOrder(graph);
|
||||||
PostOptimization(graph);
|
PostOptimization(graph);
|
||||||
|
// must clear memo_ which holds kernelgraph after using AscendGraphOptimization class.
|
||||||
|
memo_.clear();
|
||||||
MS_LOG(INFO) << "Status record: end optimize graph. graph id: " << graph->graph_id();
|
MS_LOG(INFO) << "Status record: end optimize graph. graph id: " << graph->graph_id();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,6 +73,9 @@ void AscendGraphOptimization::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
memo_.clear();
|
memo_.clear();
|
||||||
HardWareOptimization(graph);
|
HardWareOptimization(graph);
|
||||||
|
// copy child graph ref output map to father graph ref output map
|
||||||
|
memo_.clear();
|
||||||
|
UpdateRefOutputMap(graph);
|
||||||
AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(graph));
|
AnfAlgo::InsertMakeTupleForOutput(NOT_NULL(graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,9 +117,6 @@ void AscendGraphOptimization::OptimizeExecutionOrder(const KernelGraphPtr &graph
|
||||||
|
|
||||||
void AscendGraphOptimization::PostOptimization(const KernelGraphPtr &graph) {
|
void AscendGraphOptimization::PostOptimization(const KernelGraphPtr &graph) {
|
||||||
MS_LOG(INFO) << "Status record: start post optimization. graph id: " << graph->graph_id();
|
MS_LOG(INFO) << "Status record: start post optimization. graph id: " << graph->graph_id();
|
||||||
// copy child graph ref output map to father graph ref output map
|
|
||||||
memo_.clear();
|
|
||||||
UpdateRefOutputMap(graph);
|
|
||||||
graph->SetInputNodes();
|
graph->SetInputNodes();
|
||||||
graph->SetOptimizerFlag();
|
graph->SetOptimizerFlag();
|
||||||
MS_LOG(INFO) << "Status record: end post optimization. graph id: " << graph->graph_id();
|
MS_LOG(INFO) << "Status record: end post optimization. graph id: " << graph->graph_id();
|
||||||
|
|
|
@ -152,6 +152,10 @@ class DeviceContext {
|
||||||
// Return collective communication object for caller to access
|
// Return collective communication object for caller to access
|
||||||
CollectiveCommunicationLibPtr collective_comm_lib() const { return collective_comm_lib_; }
|
CollectiveCommunicationLibPtr collective_comm_lib() const { return collective_comm_lib_; }
|
||||||
|
|
||||||
|
// TODO(jiaorui): will be delete
|
||||||
|
// Dump all graphs.
|
||||||
|
virtual void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const {}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
DeviceContextKey device_context_key_;
|
DeviceContextKey device_context_key_;
|
||||||
CollectiveCommunicationLibPtr collective_comm_lib_;
|
CollectiveCommunicationLibPtr collective_comm_lib_;
|
||||||
|
|
|
@ -43,6 +43,11 @@ class DeviceSync {
|
||||||
const std::string &format) const {
|
const std::string &format) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
virtual bool AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
|
||||||
|
const std::string &format) const {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
virtual void *GetMutablePtr() const = 0;
|
virtual void *GetMutablePtr() const = 0;
|
||||||
virtual void ClearDeviceMemory() = 0;
|
virtual void ClearDeviceMemory() = 0;
|
||||||
|
|
||||||
|
|
|
@ -95,7 +95,7 @@ fi
|
||||||
|
|
||||||
while read line; do
|
while read line; do
|
||||||
if [ -f "${line}" ]; then
|
if [ -f "${line}" ]; then
|
||||||
${CLANG_FORMAT} -i "${line}"
|
"${CLANG_FORMAT}" -i "${line}"
|
||||||
fi
|
fi
|
||||||
done < "${FMT_FILE_LIST}"
|
done < "${FMT_FILE_LIST}"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue