workspace of comm op can be reused
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
d831aba239
commit
b7e5f956e5
|
@ -1 +1 @@
|
|||
Subproject commit f65be61197ed36dfc9dc10b91b58bf93835fa27b
|
||||
Subproject commit 40e5c42a12c4daa1530e8db9d006d5b3be5b378f
|
|
@ -46,7 +46,7 @@ std::string MsOpNameToHcomOpType(const std::string &ms_op_type) {
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) {
|
||||
void HcclKernelFactory::Register(const std::string &name, HcclKernelCreater &&fun) {
|
||||
hcclKernelMap_.emplace(name, std::move(fun));
|
||||
}
|
||||
|
||||
|
@ -99,7 +99,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
|
|||
if (op_name_ == kReceive) {
|
||||
auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_);
|
||||
if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
|
||||
MS_LOG(ERROR) << "HcomDataType cann't support Current Ascend Data Type : " << receive_type_;
|
||||
MS_LOG(ERROR) << "HcomDataType cannot support Current Ascend Data Type : " << receive_type_;
|
||||
return false;
|
||||
}
|
||||
hccl_data_type_list_.emplace_back(iter->second);
|
||||
|
@ -180,9 +180,17 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
|
|||
return output_size_list_;
|
||||
}
|
||||
|
||||
const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||
const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
|
||||
if (!workspace_size_list_.empty() || hccl_data_type_list_.empty()) {
|
||||
return workspace_size_list_;
|
||||
}
|
||||
|
||||
std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_, hccl_data_type_list_[0]));
|
||||
return workspace_size_list_;
|
||||
}
|
||||
|
||||
std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
|
||||
std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_);
|
||||
if (hccl_type == kReceive) {
|
||||
|
@ -221,10 +229,19 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
|
|||
MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret;
|
||||
}
|
||||
|
||||
void *workspace_addr = nullptr;
|
||||
if (task.workspace_size != 0) {
|
||||
if (workspace.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node_->DebugString() << " is empty";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(workspace.at(0));
|
||||
workspace_addr = workspace.at(0)->addr;
|
||||
}
|
||||
|
||||
results.emplace_back(std::make_shared<HcclTaskInfo>(
|
||||
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, task.workspace_size,
|
||||
task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type,
|
||||
group_, NeedDump()));
|
||||
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, workspace_addr,
|
||||
task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_,
|
||||
op_type_, data_type, group_, NeedDump()));
|
||||
}
|
||||
|
||||
return results;
|
||||
|
|
|
@ -68,7 +68,7 @@ class HcclKernelFactory {
|
|||
|
||||
public:
|
||||
static HcclKernelFactory &Get();
|
||||
void Registe(const string &name, HcclKernelCreater &&fun);
|
||||
void Register(const string &name, HcclKernelCreater &&fun);
|
||||
static std::shared_ptr<HcclKernel> Get(const string &name);
|
||||
|
||||
private:
|
||||
|
@ -78,7 +78,7 @@ class HcclKernelFactory {
|
|||
class _HcclKernelRegister {
|
||||
public:
|
||||
_HcclKernelRegister(const string &name, HcclKernelCreater &&fun) {
|
||||
HcclKernelFactory::Get().Registe(name, std::move(fun));
|
||||
HcclKernelFactory::Get().Register(name, std::move(fun));
|
||||
}
|
||||
~_HcclKernelRegister() = default;
|
||||
};
|
||||
|
|
|
@ -433,6 +433,7 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
|
|||
void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
|
||||
AssignCommunicationNodeInputMem(type, node);
|
||||
AssignCommunicationNodeOutputMem(type, node);
|
||||
AssignWorkSpaceMem(type, node);
|
||||
}
|
||||
|
||||
void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
|
||||
|
|
|
@ -99,7 +99,7 @@ bool FinalizeHccl() {
|
|||
if (ops_kernel_info_store != nullptr) {
|
||||
auto ret = ops_kernel_info_store->Finalize();
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Destory info store failed, ret = " << ret;
|
||||
MS_LOG(ERROR) << "Destroy info store failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ bool FinalizeHccl() {
|
|||
if (ops_kernel_builder != nullptr) {
|
||||
auto ret = ops_kernel_builder->Finalize();
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Destory builder failed, ret = " << ret;
|
||||
MS_LOG(ERROR) << "Destroy builder failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -151,7 +151,30 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTask
|
|||
return true;
|
||||
}
|
||||
|
||||
bool CalcOpRunningParam(const AnfNodePtr &node) { return true; }
|
||||
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) {
|
||||
MS_EXCEPTION_IF_NULL(ops_kernel_builder);
|
||||
MS_LOG(INFO) << "Start calc workspace size for hccl node " << node->DebugString() << " ,dtype is " << datatype;
|
||||
auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype);
|
||||
MS_EXCEPTION_IF_NULL(ge_node);
|
||||
auto op = ge_node->GetOpDesc();
|
||||
MS_EXCEPTION_IF_NULL(op);
|
||||
|
||||
MS_LOG(INFO) << "Start to call CalcOpRunningParam";
|
||||
ge::Status ret = ops_kernel_builder->CalcOpRunningParam(*ge_node);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto workspace_sizes = op->GetWorkspaceBytes();
|
||||
if (workspace_sizes.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected workspace size " << workspace_sizes.size();
|
||||
}
|
||||
int64_t workspace_size = workspace_sizes[0];
|
||||
MS_LOG(INFO) << "Node " << node->DebugString() << " workspace size is " << workspace_size;
|
||||
ge_graph.reset();
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
void *GetHcclOpsKernelInfoStore() { return ops_kernel_info_store.get(); }
|
||||
|
||||
|
|
|
@ -23,21 +23,18 @@
|
|||
#include "mindspore/core/ir/anf.h"
|
||||
#include "hccl/hccl_types.h"
|
||||
|
||||
#define MS_API __attribute__((visibility("default")))
|
||||
|
||||
namespace mindspore::hccl {
|
||||
struct MS_API HcclTaskInfo {
|
||||
struct HcclTaskInfo {
|
||||
std::string private_def;
|
||||
int64_t workspace_size;
|
||||
int64_t stream_num;
|
||||
};
|
||||
|
||||
MS_API bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
|
||||
MS_API bool FinalizeHccl();
|
||||
MS_API bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists);
|
||||
MS_API bool CalcOpRunningParam(const AnfNodePtr &node);
|
||||
MS_API void *GetHcclOpsKernelInfoStore();
|
||||
MS_API std::string GetHcclType(const AnfNodePtr &node);
|
||||
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
|
||||
bool FinalizeHccl();
|
||||
bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists);
|
||||
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype);
|
||||
void *GetHcclOpsKernelInfoStore();
|
||||
std::string GetHcclType(const AnfNodePtr &node);
|
||||
} // namespace mindspore::hccl
|
||||
#undef MS_API
|
||||
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H
|
||||
|
|
|
@ -64,7 +64,7 @@ namespace hccl {
|
|||
bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; }
|
||||
bool FinalizeHccl() { return true; }
|
||||
bool GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) { return true; }
|
||||
bool CalcOpRunningParam(const AnfNodePtr &) { return true; }
|
||||
int64_t CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) { return 0; }
|
||||
void *GetHcclOpsKernelInfoStore() { return nullptr; }
|
||||
std::string GetHcclType(const AnfNodePtr &) { return ""; }
|
||||
} // namespace hccl
|
||||
|
|
Loading…
Reference in New Issue