!38516 Add HCCL Globalworkspace

Merge pull request !38516 from archer2049/master
This commit is contained in:
i-robot 2022-07-22 01:25:39 +00:00 committed by Gitee
commit 005212c4cb
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 45 additions and 3 deletions

View File

@ -17,6 +17,7 @@
#include "plugin/device/ascend/hal/device/ascend_memory_adapter.h"
#include <algorithm>
#include "ir/func_graph.h"
#include "runtime/mem.h"
#include "utils/ms_context.h"
#include "utils/convert_utils_base.h"
@ -33,6 +34,7 @@ constexpr size_t kExtraReservedMemory = 10485760; // 10mb
constexpr double kHalfRatio = 0.5;
// The Ascend max available device memory is 32GB.
constexpr float kAscendMaxDeviceMemory = 32;
constexpr uint64_t kOverflowAddrSize = 512;
size_t AscendMemAdapter::GetRoundDownAlignSize(size_t input_size) {
return (input_size / kAscendMemAlignSize) * kAscendMemAlignSize;
@ -174,6 +176,20 @@ uint8_t *AscendMemAdapter::MallocDynamicDevMem(size_t size, const std::string &t
return memory_block_ptr;
}
// Returns the overflow buffer associated with the func graph that owns `kernel`.
//
// Buffers are cached in overflow_memory_info_map_, keyed by the func graph's ToString() result.
// On the first request for a graph, kOverflowAddrSize bytes are obtained through
// MallocStaticDevMem and remembered; later requests for the same key return the cached pointer.
// `kernel`, its func graph and the freshly allocated pointer are each checked with
// MS_EXCEPTION_IF_NULL.
uint8_t *AscendMemAdapter::MallocOverflowMem(const CNodePtr &kernel) {
  // Callers may hand in the result of a cast that failed, so validate before dereferencing.
  MS_EXCEPTION_IF_NULL(kernel);
  std::lock_guard<std::mutex> locker(overflow_mutex_);
  auto funcGraph = kernel->func_graph();
  MS_EXCEPTION_IF_NULL(funcGraph);
  // Build the key once; it is needed for both the lookup and the insertion.
  const std::string graph_key = funcGraph->ToString();
  const auto iter = overflow_memory_info_map_.find(graph_key);
  if (iter != overflow_memory_info_map_.cend()) {
    return iter->second;
  }
  auto overflow_memory_ptr = MallocStaticDevMem(kOverflowAddrSize, "overflow memory ptr");
  MS_EXCEPTION_IF_NULL(overflow_memory_ptr);
  overflow_memory_info_map_.emplace(graph_key, overflow_memory_ptr);
  return overflow_memory_ptr;
}
// Rewinds the dynamic memory offset back to zero.
void AscendMemAdapter::ResetDynamicMemory() {
  cur_dynamic_mem_offset_ = 0;
}
std::string AscendMemAdapter::DevMemStatistics() const {

View File

@ -22,11 +22,11 @@
#include <memory>
#include <vector>
#include "utils/ms_context.h"
#include "ir/anf.h"
namespace mindspore {
namespace device {
namespace ascend {
class AscendMemAdapter {
public:
static AscendMemAdapter &GetInstance() {
@ -39,6 +39,7 @@ class AscendMemAdapter {
uint8_t *MallocStaticDevMem(size_t size, const std::string &tag = "");
uint8_t *MallocDynamicDevMem(size_t size, const std::string &tag = "");
uint8_t *MallocOverflowMem(const CNodePtr &kernel);
// Accepts any pointer and unconditionally returns true; this body releases nothing.
bool FreeStaticDevMem(void *) const {
  return true;
}
void ResetDynamicMemory();
@ -73,6 +74,9 @@ class AscendMemAdapter {
// Support multi-thread.
std::mutex mutex_;
// Support overflow case.
std::mutex overflow_mutex_;
// rts Memory INFO
size_t device_hbm_total_size_{0};
size_t device_hbm_free_size_{0};
@ -90,6 +94,9 @@ class AscendMemAdapter {
uint64_t static_mem_offset_{0};
std::vector<std::shared_ptr<MemoryBlock>> static_memory_block_list_;
static size_t GetRoundDownAlignSize(size_t input_size);
// overflow memory info, key is the owning func graph's ToString() result, val is the memory ptr
mindspore::HashMap<std::string, uint8_t *> overflow_memory_info_map_;
};
} // namespace ascend
} // namespace device

View File

@ -82,6 +82,9 @@ void HcclTask::Distribute() {
ge_task.kernelHcclInfo[0].dataType = static_cast<int32_t>(task_info_->data_type());
ge_task.kernelHcclInfo[0].opType = static_cast<int32_t>(task_info_->op_type());
ge_task.kernelHcclInfo[0].rootId = task_info_->root_id();
if (!task_info_->global_workspace_addr().empty()) {
ge_task.kernelHcclInfo[0].global_workspace_addr = task_info_->global_workspace_addr();
}
std::vector<rtStream_t> secondary_stream_list;
std::transform(secondary_stream_list_.begin(), secondary_stream_list_.end(),

View File

@ -262,6 +262,11 @@ class HcclTaskInfo : public TaskInfo {
// Returns the stored op type value.
int64_t op_type() const {
  return op_type_;
}
// Returns the stored data type value.
int64_t data_type() const {
  return data_type_;
}
// Returns a const reference to the stored group string.
const std::string &group() const {
  return group_;
}
// Returns a const reference to the stored global workspace address list.
const std::vector<void *> &global_workspace_addr() const {
  return global_workspace_addr_;
}
void SetGlobalWorkspaceAddr(const std::vector<void *> &global_workspace_addr) {
this->global_workspace_addr_ = global_workspace_addr;
}
private:
std::string hccl_type_;
@ -277,6 +282,8 @@ class HcclTaskInfo : public TaskInfo {
int64_t op_type_;
int64_t data_type_;
std::string group_;
// hccl global overflow addr
std::vector<void *> global_workspace_addr_;
};
class ProfilerTraceTaskInfo : public TaskInfo {

View File

@ -24,6 +24,7 @@
#include "runtime/device/kernel_runtime.h"
#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h"
#include "plugin/device/ascend/hal/device/distribute/ascend_collective.h"
#include "plugin/device/ascend/hal/device/ascend_memory_adapter.h"
using HcclTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::HcclTaskInfo>;
using mindspore::ge::model_runner::HcclTaskInfo;
@ -286,11 +287,19 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
workspace_addr = workspace.at(0)->addr;
}
results.emplace_back(
std::vector<void *> global_workspace_addr;
auto overflow_memory_ptr =
device::ascend::AscendMemAdapter::GetInstance().MallocOverflowMem(anf_node_.lock()->cast<CNodePtr>());
MS_EXCEPTION_IF_NULL(overflow_memory_ptr);
global_workspace_addr.push_back(reinterpret_cast<void *>(overflow_memory_ptr));
HcclTaskInfoPtr hcclTaskInfo =
std::make_shared<HcclTaskInfo>(unique_name_, stream_id, hccl::HcclAdapter::GetHcclType(anf_node), input_data_addr,
output_data_addr, workspace_addr, task.workspace_size, task.stream_num,
private_def, hccl::HcclAdapter::GetInstance().GetHcclOpsKernelInfoStore(),
hccl_count_, root_id_, op_type_, data_type, group_, NeedDump()));
hccl_count_, root_id_, op_type_, data_type, group_, NeedDump());
hcclTaskInfo->SetGlobalWorkspaceAddr(global_workspace_addr);
results.emplace_back(hcclTaskInfo);
}
return results;