forked from mindspore-Ecosystem/mindspore
!2313 Save profiling point data to framework
Merge pull request !2313 from caifubi/save-profiling-point
This commit is contained in:
commit
a52231440a
|
@ -14,7 +14,6 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <map>
|
||||
#include "device/ascend/profiling/reporter/graph_desc_reporter.h"
|
||||
#include "device/ascend/profiling/profiling_utils.h"
|
||||
#include "kernel/kernel.h"
|
||||
|
@ -24,6 +23,7 @@
|
|||
#include "utils/utils.h"
|
||||
#include "device/ascend/profiling/reporter/task_desc_reporter.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "device/ascend/profiling/reporter/point_reporter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -33,8 +33,9 @@ constexpr char kCustomNode[] = "PROFILING_CUSTOM_";
|
|||
constexpr char kFpStartNode[] = "PROFILING_FP_START";
|
||||
constexpr char kBpEndNode[] = "PROFILING_BP_END";
|
||||
constexpr char kIterEndNode[] = "PROFILING_ITER_END";
|
||||
std::unordered_map<uint32_t, std::vector<CNodePtr>> ProfilingUtils::graph_profiling_cnode_;
|
||||
std::unordered_map<uint32_t, std::vector<std::string>> ProfilingUtils::graph_kernel_name_;
|
||||
std::map<uint32_t, std::vector<CNodePtr>> ProfilingUtils::graph_profiling_cnode_;
|
||||
std::map<uint32_t, std::vector<std::string>> ProfilingUtils::graph_kernel_name_;
|
||||
std::map<uint32_t, std::vector<std::shared_ptr<ProfDesc>>> ProfilingUtils::graph_point_;
|
||||
uint32_t ProfilingUtils::custom_node_index_ = 1;
|
||||
|
||||
ProfilingTraceInfo ProfilingUtils::GetProfilingTraceFromEnv(NotNull<const session::KernelGraph *> graph_ptr) {
|
||||
|
@ -102,6 +103,7 @@ std::string ProfilingUtils::GetTraceBegin(const std::vector<CNodePtr> &cnode_exe
|
|||
void ProfilingUtils::GetCNodeOutputRealNode(const std::string &node_name, const std::vector<CNodePtr> &cnode_exec_order,
|
||||
NotNull<std::set<std::string> *> getnext_outputs) {
|
||||
for (const auto &cnode : cnode_exec_order) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (const auto &input : cnode->inputs()) {
|
||||
auto prev_cnode = AnfAlgo::VisitKernel(input, 0);
|
||||
if (!prev_cnode.first->isa<CNode>()) {
|
||||
|
@ -203,6 +205,17 @@ NotNull<CNodePtr> ProfilingUtils::CreateProfilingCNode(const ProfilingContent &p
|
|||
return NOT_NULL(cnode_ptr);
|
||||
}
|
||||
|
||||
void ProfilingUtils::SaveProfilingPoint(uint32_t graph_id, const std::string &node_name, uint32_t point_id) {
|
||||
std::shared_ptr<ProfDesc> prof_desc_ptr = std::make_shared<PointDesc>(node_name, point_id);
|
||||
auto iter = graph_point_.find(graph_id);
|
||||
if (iter == graph_point_.end()) {
|
||||
std::vector<std::shared_ptr<ProfDesc>> tmp_vect = {prof_desc_ptr};
|
||||
graph_point_.insert({graph_id, tmp_vect});
|
||||
} else {
|
||||
iter->second.emplace_back(prof_desc_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
void ProfilingUtils::ProfilingTraceFpStart(const mindspore::AnfNodePtr &anf_node,
|
||||
const ProfilingTraceInfo &profiling_trace_info,
|
||||
NotNull<session::KernelGraph *> graph_ptr,
|
||||
|
@ -213,6 +226,8 @@ void ProfilingUtils::ProfilingTraceFpStart(const mindspore::AnfNodePtr &anf_node
|
|||
ProfilingContent fp_profiling_content = {false, kProfilingFpStartLogId, 0};
|
||||
auto fp_profiling_node = CreateProfilingCNodeWithStream(anf_node, fp_profiling_content, graph_ptr);
|
||||
kernel_list->emplace_back(fp_profiling_node);
|
||||
// insert ProfDesc
|
||||
SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingFpStartLogId);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -244,13 +259,16 @@ void ProfilingUtils::ProfilingCustomOp(const AnfNodePtr &anf_node, const Profili
|
|||
}
|
||||
MS_LOG(INFO) << "Profiling Match CustomOp:" << anf_node->fullname_with_scope();
|
||||
// custom op profiling job start from 3.
|
||||
ProfilingContent front_profiling_content = {false, 2 * custom_node_index_ + 1, 0};
|
||||
auto custom_point_id = 2 * custom_node_index_ + 1;
|
||||
ProfilingContent front_profiling_content = {false, custom_point_id, 0};
|
||||
CNodePtr front_node = CreateProfilingCNodeWithStream(anf_node, front_profiling_content, graph_ptr);
|
||||
kernel_list->insert(kernel_list->end() - 1, front_node);
|
||||
SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), custom_point_id);
|
||||
|
||||
ProfilingContent back_profiling_content = {false, 2 * custom_node_index_ + 2, 0};
|
||||
ProfilingContent back_profiling_content = {false, custom_point_id + 1, 0};
|
||||
CNodePtr back_node = CreateProfilingCNodeWithStream(anf_node, back_profiling_content, graph_ptr);
|
||||
kernel_list->insert(kernel_list->end(), back_node);
|
||||
SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), custom_point_id + 1);
|
||||
++custom_node_index_;
|
||||
}
|
||||
|
||||
|
@ -263,6 +281,7 @@ void ProfilingUtils::ProfilingTraceBpEnd(const AnfNodePtr &anf_node, const Profi
|
|||
ProfilingContent bp_end_profiling_content = {false, kProfilingBpEndLogId, 0};
|
||||
CNodePtr bp_end_node = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr);
|
||||
kernel_list->emplace_back(bp_end_node);
|
||||
SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingBpEndLogId);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -276,6 +295,7 @@ void ProfilingUtils::ProfilingTraceEnd(const AnfNodePtr &anf_node, const Profili
|
|||
ProfilingContent bp_end_profiling_content = {true, kProfilingIterEndLogId, 0};
|
||||
CNodePtr bp_kernel_ptr = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr);
|
||||
kernel_list->emplace_back(bp_kernel_ptr);
|
||||
SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingIterEndLogId);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -325,6 +345,18 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
|
|||
GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second);
|
||||
graph_profiling_cnode_.erase(ret);
|
||||
graph_reporter.ReportData();
|
||||
|
||||
// Report profiling point
|
||||
auto point_iter = graph_point_.find(graph->graph_id());
|
||||
if (point_iter == graph_point_.end()) {
|
||||
MS_LOG(ERROR) << "Graph id not found in graph_point";
|
||||
return;
|
||||
}
|
||||
PointReporter point_reporter(context->device_id(), "vm.point");
|
||||
for (const auto &point : point_iter->second) {
|
||||
point_reporter.AddReportData(point);
|
||||
}
|
||||
point_reporter.ReportData();
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_
|
||||
#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -23,6 +24,7 @@
|
|||
#include <unordered_map>
|
||||
#include "session/kernel_graph.h"
|
||||
#include "utils/contract.h"
|
||||
#include "device/ascend/profiling/reporter/profiling_desc.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -104,7 +106,7 @@ class ProfilingUtils {
|
|||
NotNull<session::KernelGraph *> graph_ptr,
|
||||
NotNull<std::vector<mindspore::CNodePtr> *> kernel_list);
|
||||
|
||||
static std::unordered_map<uint32_t, std::vector<std::string>> graph_kernel_name() { return graph_kernel_name_; }
|
||||
static std::map<uint32_t, std::vector<std::string>> graph_kernel_name() { return graph_kernel_name_; }
|
||||
|
||||
inline static constexpr char kProfiling[] = "Profiling";
|
||||
inline static constexpr char kNotify[] = "notify";
|
||||
|
@ -126,10 +128,12 @@ class ProfilingUtils {
|
|||
NotNull<std::set<std::string> *> getnext_outputs);
|
||||
|
||||
static bool ValidComputeGraph(NotNull<const session::KernelGraph *> graph_ptr);
|
||||
static void SaveProfilingPoint(uint32_t graph_id, const std::string &node_name, uint32_t point_id);
|
||||
|
||||
// graph id --> (kernel name list)
|
||||
static std::unordered_map<uint32_t, std::vector<CNodePtr>> graph_profiling_cnode_;
|
||||
static std::unordered_map<uint32_t, std::vector<std::string>> graph_kernel_name_;
|
||||
static std::map<uint32_t, std::vector<CNodePtr>> graph_profiling_cnode_;
|
||||
static std::map<uint32_t, std::vector<std::string>> graph_kernel_name_;
|
||||
static std::map<uint32_t, std::vector<std::shared_ptr<ProfDesc>>> graph_point_;
|
||||
static uint32_t custom_node_index_;
|
||||
};
|
||||
} // namespace ascend
|
||||
|
|
|
@ -56,8 +56,8 @@ void DescReporter::ReportByLine(const std::string &data, const std::string &file
|
|||
}
|
||||
}
|
||||
|
||||
void DescReporter::ReportData() {
|
||||
for (const auto &desc : prof_desc_) {
|
||||
void DescReporter::ReportAllLine() {
|
||||
for (const auto &desc : prof_desc_list_) {
|
||||
auto data = desc->ToString();
|
||||
ReportByLine(data, file_name_);
|
||||
}
|
||||
|
|
|
@ -32,16 +32,17 @@ namespace ascend {
|
|||
class DescReporter {
|
||||
public:
|
||||
virtual ~DescReporter() = 0;
|
||||
DescReporter(int device_id, std::string file_name, std::vector<CNodePtr> cnode_list)
|
||||
: device_id_(device_id), file_name_(std::move(file_name)), cnode_list_(std::move(cnode_list)) {}
|
||||
virtual void ReportData();
|
||||
DescReporter(int device_id, std::string file_name) : device_id_(device_id), file_name_(std::move(file_name)) {}
|
||||
|
||||
virtual void ReportData() = 0;
|
||||
|
||||
protected:
|
||||
void ReportByLine(const std::string &data, const std::string &file_name) const;
|
||||
void ReportAllLine();
|
||||
|
||||
int device_id_;
|
||||
std::string file_name_;
|
||||
std::vector<CNodePtr> cnode_list_;
|
||||
std::vector<std::shared_ptr<ProfDesc>> prof_desc_;
|
||||
std::vector<std::shared_ptr<ProfDesc>> prof_desc_list_;
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace device {
|
|||
namespace ascend {
|
||||
void GraphDescReporter::ReportData() {
|
||||
for (const auto &node : cnode_list_) {
|
||||
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL) {
|
||||
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AUTO_DIFF_KERNEL) {
|
||||
MS_LOG(WARNING) << "Skip non tbe kernel";
|
||||
continue;
|
||||
}
|
||||
|
@ -57,9 +57,9 @@ void GraphDescReporter::ReportData() {
|
|||
}
|
||||
|
||||
auto graph_desc = std::make_shared<GraphDesc>(op_name, op_type, input_data_list, output_data_list);
|
||||
prof_desc_.emplace_back(graph_desc);
|
||||
prof_desc_list_.emplace_back(graph_desc);
|
||||
}
|
||||
DescReporter::ReportData();
|
||||
ReportAllLine();
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -28,9 +28,12 @@ namespace ascend {
|
|||
class GraphDescReporter : public DescReporter {
|
||||
public:
|
||||
GraphDescReporter(uint32_t device_id, const std::string &file_name, std::vector<CNodePtr> cnode_list)
|
||||
: DescReporter(device_id, file_name, std::move(cnode_list)) {}
|
||||
: DescReporter(device_id, file_name), cnode_list_(std::move(cnode_list)) {}
|
||||
~GraphDescReporter() override = default;
|
||||
void ReportData() override;
|
||||
|
||||
private:
|
||||
std::vector<CNodePtr> cnode_list_;
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "device/ascend/profiling/reporter/point_reporter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
void PointReporter::ReportData() { ReportAllLine(); }
|
||||
|
||||
void PointReporter::AddReportData(const std::shared_ptr<ProfDesc> &prof_desc) {
|
||||
prof_desc_list_.emplace_back(prof_desc);
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_
|
||||
#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "device/ascend/profiling/reporter/desc_reporter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
class PointReporter : public DescReporter {
|
||||
public:
|
||||
PointReporter(uint32_t device_id, const std::string &file_name) : DescReporter(device_id, file_name) {}
|
||||
~PointReporter() override = default;
|
||||
void ReportData() override;
|
||||
void AddReportData(const std::shared_ptr<ProfDesc> &prof_desc);
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_
|
|
@ -66,6 +66,12 @@ std::string GraphDesc::ToString() {
|
|||
return desc;
|
||||
}
|
||||
|
||||
std::string PointDesc::ToString() {
|
||||
std::string desc;
|
||||
desc.append(std::to_string(point_id_)).append(" ").append(op_name_).append("\n");
|
||||
return desc;
|
||||
}
|
||||
|
||||
std::string GraphDesc::DataShapeToString(const std::vector<size_t> &shape) {
|
||||
std::ostringstream oss;
|
||||
oss << "\"";
|
||||
|
|
|
@ -71,6 +71,16 @@ class GraphDesc : public ProfDesc {
|
|||
std::vector<DataElement> output_data_list_;
|
||||
[[nodiscard]] static std::string DataShapeToString(const std::vector<size_t> &shape);
|
||||
};
|
||||
|
||||
class PointDesc : public ProfDesc {
|
||||
public:
|
||||
PointDesc(std::string op_name, uint32_t point_id) : ProfDesc(std::move(op_name)), point_id_(point_id) {}
|
||||
~PointDesc() override = default;
|
||||
std::string ToString() override;
|
||||
|
||||
private:
|
||||
uint32_t point_id_;
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,7 +31,7 @@ void TaskDescReporter::ReportData() {
|
|||
|
||||
size_t task_index = 0;
|
||||
for (const auto &node : cnode_list_) {
|
||||
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL) {
|
||||
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AUTO_DIFF_KERNEL) {
|
||||
MS_LOG(WARNING) << "Skip non tbe kernel";
|
||||
++task_index;
|
||||
continue;
|
||||
|
@ -44,10 +44,10 @@ void TaskDescReporter::ReportData() {
|
|||
CheckStreamTaskValid(task_index, task_index);
|
||||
auto desc_ptr = std::make_shared<TaskDesc>(node->fullname_with_scope(), task_ids_[task_index],
|
||||
ascend_kernel_mod->block_dim(), stream_ids_[task_index]);
|
||||
prof_desc_.emplace_back(desc_ptr);
|
||||
prof_desc_list_.emplace_back(desc_ptr);
|
||||
++task_index;
|
||||
}
|
||||
DescReporter::ReportData();
|
||||
ReportAllLine();
|
||||
}
|
||||
|
||||
void TaskDescReporter::CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id) {
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace ascend {
|
|||
class TaskDescReporter : public DescReporter {
|
||||
public:
|
||||
TaskDescReporter(int device_id, const std::string &file_name, std::vector<CNodePtr> cnode_list)
|
||||
: DescReporter(device_id, file_name, std::move(cnode_list)) {}
|
||||
: DescReporter(device_id, file_name), cnode_list_(std::move(cnode_list)) {}
|
||||
~TaskDescReporter() override = default;
|
||||
void ReportData() override;
|
||||
void set_task_ids(const std::vector<uint32_t> &task_ids) { task_ids_ = task_ids; }
|
||||
|
@ -38,6 +38,7 @@ class TaskDescReporter : public DescReporter {
|
|||
std::vector<uint32_t> task_ids_;
|
||||
std::vector<uint32_t> stream_ids_;
|
||||
void CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id);
|
||||
std::vector<CNodePtr> cnode_list_;
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -54,13 +54,13 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
|
|||
MS_EXCEPTION_IF_NULL(task_info);
|
||||
hcclResult_t ret;
|
||||
static uint32_t task_counter = 0;
|
||||
|
||||
auto hccl_group = task_info->group();
|
||||
if (task_info->hccl_type() == kBroadcastOpName) {
|
||||
// call hcom broadcast interface to run op
|
||||
const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0);
|
||||
ret = hcom_broadcast(tag_broadcast.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
|
||||
static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
|
||||
static_cast<u32>(task_info->root_id()), task_info->group().c_str(), stream);
|
||||
static_cast<u32>(task_info->root_id()), hccl_group.c_str(), stream);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret);
|
||||
return false;
|
||||
|
@ -70,7 +70,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
|
|||
const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0);
|
||||
ret = hcom_all_gather(tag_all_gather.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
|
||||
reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
|
||||
static_cast<hcclDataType_t>(task_info->data_type()), task_info->group().c_str(), stream);
|
||||
static_cast<hcclDataType_t>(task_info->data_type()), hccl_group.c_str(), stream);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret;
|
||||
return false;
|
||||
|
@ -81,7 +81,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
|
|||
ret = hcom_all_reduce(tag_all_reduce.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
|
||||
reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
|
||||
static_cast<hcclDataType_t>(task_info->data_type()),
|
||||
static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
|
||||
static_cast<hcclRedOp_t>(task_info->op_type()), hccl_group.c_str(), stream);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret;
|
||||
return false;
|
||||
|
@ -93,7 +93,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
|
|||
ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
|
||||
reinterpret_cast<void *>(task_info->output_data_addr()),
|
||||
static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
|
||||
static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
|
||||
static_cast<hcclRedOp_t>(task_info->op_type()), hccl_group.c_str(), stream);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret;
|
||||
return false;
|
||||
|
|
Loading…
Reference in New Issue