!38516 Add HCCL Globalworkspace
Merge pull request !38516 from archer2049/master
This commit is contained in:
commit
005212c4cb
|
@ -17,6 +17,7 @@
|
|||
#include "plugin/device/ascend/hal/device/ascend_memory_adapter.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include "ir/func_graph.h"
|
||||
#include "runtime/mem.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
|
@ -33,6 +34,7 @@ constexpr size_t kExtraReservedMemory = 10485760; // 10mb
|
|||
constexpr double kHalfRatio = 0.5;
|
||||
// The Ascend max available device memory is 32GB.
|
||||
constexpr float kAscendMaxDeviceMemory = 32;
|
||||
constexpr uint64_t kOverflowAddrSize = 512;
|
||||
|
||||
size_t AscendMemAdapter::GetRoundDownAlignSize(size_t input_size) {
|
||||
return (input_size / kAscendMemAlignSize) * kAscendMemAlignSize;
|
||||
|
@ -174,6 +176,20 @@ uint8_t *AscendMemAdapter::MallocDynamicDevMem(size_t size, const std::string &t
|
|||
return memory_block_ptr;
|
||||
}
|
||||
|
||||
uint8_t *AscendMemAdapter::MallocOverflowMem(const CNodePtr &kernel) {
|
||||
std::lock_guard<std::mutex> locker(overflow_mutex_);
|
||||
auto funcGraph = kernel->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(funcGraph);
|
||||
if (overflow_memory_info_map_.find(funcGraph->ToString()) != overflow_memory_info_map_.cend()) {
|
||||
return overflow_memory_info_map_.find(funcGraph->ToString())->second;
|
||||
} else {
|
||||
auto overflow_memory_ptr = MallocStaticDevMem(kOverflowAddrSize, "overflow memory ptr");
|
||||
MS_EXCEPTION_IF_NULL(overflow_memory_ptr);
|
||||
overflow_memory_info_map_.insert({funcGraph->ToString(), overflow_memory_ptr});
|
||||
return overflow_memory_ptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Rewinds the dynamic-memory offset back to zero.
void AscendMemAdapter::ResetDynamicMemory() {
  cur_dynamic_mem_offset_ = 0;
}
|
||||
|
||||
std::string AscendMemAdapter::DevMemStatistics() const {
|
||||
|
|
|
@ -22,11 +22,11 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include "utils/ms_context.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
|
||||
class AscendMemAdapter {
|
||||
public:
|
||||
static AscendMemAdapter &GetInstance() {
|
||||
|
@ -39,6 +39,7 @@ class AscendMemAdapter {
|
|||
|
||||
uint8_t *MallocStaticDevMem(size_t size, const std::string &tag = "");
|
||||
uint8_t *MallocDynamicDevMem(size_t size, const std::string &tag = "");
|
||||
uint8_t *MallocOverflowMem(const CNodePtr &kernel);
|
||||
// Release hook for static device memory: the pointer argument is ignored and
// true is always returned.
bool FreeStaticDevMem(void *) const {
  return true;
}
|
||||
void ResetDynamicMemory();
|
||||
|
||||
|
@ -73,6 +74,9 @@ class AscendMemAdapter {
|
|||
// Support multi-thread.
|
||||
std::mutex mutex_;
|
||||
|
||||
// Guards overflow_memory_info_map_ (taken in MallocOverflowMem).
|
||||
std::mutex overflow_mutex_;
|
||||
|
||||
// rts Memory INFO
|
||||
size_t device_hbm_total_size_{0};
|
||||
size_t device_hbm_free_size_{0};
|
||||
|
@ -90,6 +94,9 @@ class AscendMemAdapter {
|
|||
uint64_t static_mem_offset_{0};
|
||||
std::vector<std::shared_ptr<MemoryBlock>> static_memory_block_list_;
|
||||
static size_t GetRoundDownAlignSize(size_t input_size);
|
||||
|
||||
// overflow memory info, key is the owning func graph's ToString(), val is memory ptr
|
||||
mindspore::HashMap<std::string, uint8_t *> overflow_memory_info_map_;
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -82,6 +82,9 @@ void HcclTask::Distribute() {
|
|||
ge_task.kernelHcclInfo[0].dataType = static_cast<int32_t>(task_info_->data_type());
|
||||
ge_task.kernelHcclInfo[0].opType = static_cast<int32_t>(task_info_->op_type());
|
||||
ge_task.kernelHcclInfo[0].rootId = task_info_->root_id();
|
||||
if (!task_info_->global_workspace_addr().empty()) {
|
||||
ge_task.kernelHcclInfo[0].global_workspace_addr = task_info_->global_workspace_addr();
|
||||
}
|
||||
|
||||
std::vector<rtStream_t> secondary_stream_list;
|
||||
std::transform(secondary_stream_list_.begin(), secondary_stream_list_.end(),
|
||||
|
|
|
@ -262,6 +262,11 @@ class HcclTaskInfo : public TaskInfo {
|
|||
// Returns the stored op type value.
int64_t op_type() const {
  return op_type_;
}
|
||||
// Returns the stored data type value.
int64_t data_type() const {
  return data_type_;
}
|
||||
// Returns a read-only reference to the stored group string.
const std::string &group() const {
  return group_;
}
|
||||
// Returns a read-only reference to the stored global workspace address list.
const std::vector<void *> &global_workspace_addr() const {
  return global_workspace_addr_;
}
|
||||
|
||||
void SetGlobalWorkspaceAddr(const std::vector<void *> &global_workspace_addr) {
|
||||
this->global_workspace_addr_ = global_workspace_addr;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string hccl_type_;
|
||||
|
@ -277,6 +282,8 @@ class HcclTaskInfo : public TaskInfo {
|
|||
int64_t op_type_;
|
||||
int64_t data_type_;
|
||||
std::string group_;
|
||||
// HCCL global workspace addresses (carries the overflow memory address)
|
||||
std::vector<void *> global_workspace_addr_;
|
||||
};
|
||||
|
||||
class ProfilerTraceTaskInfo : public TaskInfo {
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "runtime/device/kernel_runtime.h"
|
||||
#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h"
|
||||
#include "plugin/device/ascend/hal/device/distribute/ascend_collective.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_memory_adapter.h"
|
||||
|
||||
using HcclTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::HcclTaskInfo>;
|
||||
using mindspore::ge::model_runner::HcclTaskInfo;
|
||||
|
@ -286,11 +287,19 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
|
|||
workspace_addr = workspace.at(0)->addr;
|
||||
}
|
||||
|
||||
results.emplace_back(
|
||||
std::vector<void *> global_workspace_addr;
|
||||
auto overflow_memory_ptr =
|
||||
device::ascend::AscendMemAdapter::GetInstance().MallocOverflowMem(anf_node_.lock()->cast<CNodePtr>());
|
||||
MS_EXCEPTION_IF_NULL(overflow_memory_ptr);
|
||||
global_workspace_addr.push_back(reinterpret_cast<void *>(overflow_memory_ptr));
|
||||
|
||||
HcclTaskInfoPtr hcclTaskInfo =
|
||||
std::make_shared<HcclTaskInfo>(unique_name_, stream_id, hccl::HcclAdapter::GetHcclType(anf_node), input_data_addr,
|
||||
output_data_addr, workspace_addr, task.workspace_size, task.stream_num,
|
||||
private_def, hccl::HcclAdapter::GetInstance().GetHcclOpsKernelInfoStore(),
|
||||
hccl_count_, root_id_, op_type_, data_type, group_, NeedDump()));
|
||||
hccl_count_, root_id_, op_type_, data_type, group_, NeedDump());
|
||||
hcclTaskInfo->SetGlobalWorkspaceAddr(global_workspace_addr);
|
||||
results.emplace_back(hcclTaskInfo);
|
||||
}
|
||||
|
||||
return results;
|
||||
|
|
Loading…
Reference in New Issue