sync-from-trunk-to-blue-zone-0912-c75b220

This commit is contained in:
wuweikang 2020-09-12 15:23:04 +08:00
parent 61bdcb71a6
commit e24f534002
15 changed files with 84 additions and 81 deletions

@ -1 +1 @@
Subproject commit b6d2dd731c5f841fa9e2f3fdf0815cf1ed9d5ddc Subproject commit 6dcf11d26eca81a328c7069235c7675c557fe0c0

View File

@ -121,7 +121,7 @@ class AscendEnvChecker(EnvChecker):
"""ascend environment check""" """ascend environment check"""
def __init__(self): def __init__(self):
self.version = ["1.75.T15.0.B150"] self.version = ["1.75.22.0.220"]
atlas_fwk_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info" atlas_fwk_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info"
hisi_fwk_version = "/usr/local/Ascend/fwkacllib/version.info" hisi_fwk_version = "/usr/local/Ascend/fwkacllib/version.info"
if os.path.exists(atlas_fwk_version): if os.path.exists(atlas_fwk_version):

View File

@ -44,14 +44,14 @@ HcclKernelFactory &HcclKernelFactory::Get() {
return _this; return _this;
} }
HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REP_OP_SUM), root_id_(0), anf_node_(nullptr) {} HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REDUCE_SUM), root_id_(0), anf_node_(nullptr) {}
HcclKernel::~HcclKernel() { HcclKernel::~HcclKernel() {
hccl_kernel_input_shape_list_.clear(); hccl_kernel_input_shape_list_.clear();
hccl_kernel_output_shape_list_.clear(); hccl_kernel_output_shape_list_.clear();
hccl_data_type_list_.clear(); hccl_data_type_list_.clear();
hccl_count_ = 0; hccl_count_ = 0;
op_type_ = HCCL_REP_OP_SUM; op_type_ = HCCL_REDUCE_SUM;
root_id_ = 0; root_id_ = 0;
input_size_list_.clear(); input_size_list_.clear();
output_size_list_.clear(); output_size_list_.clear();
@ -141,7 +141,7 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
void *workspace_address = nullptr; void *workspace_address = nullptr;
const int64_t workspace_num = 0; const int64_t workspace_num = 0;
std::vector<uint8_t> private_def; std::vector<uint8_t> private_def;
hcclDataType_t data_type = hccl_data_type_list_[0]; HcclDataType data_type = hccl_data_type_list_[0];
MS_LOG(INFO) << "HCCL Task : stream_id=" << stream_id << ", ws_num=" << workspace_num << ", count=" << hccl_count_ MS_LOG(INFO) << "HCCL Task : stream_id=" << stream_id << ", ws_num=" << workspace_num << ", count=" << hccl_count_
<< ", root_id=" << root_id_ << ", op_type=" << static_cast<int>(op_type_) << ", root_id=" << root_id_ << ", op_type=" << static_cast<int>(op_type_)

View File

@ -26,6 +26,7 @@
#include "backend/kernel_compiler/ascend_kernel_mod.h" #include "backend/kernel_compiler/ascend_kernel_mod.h"
#include "backend/kernel_compiler/hccl/hcom_util.h" #include "backend/kernel_compiler/hccl/hcom_util.h"
#include "hccl/hcom.h" #include "hccl/hcom.h"
#include "hccl/hccl_types.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
namespace mindspore { namespace mindspore {
@ -44,10 +45,10 @@ class HcclKernel : public AscendKernelMod {
protected: protected:
std::vector<std::vector<size_t>> hccl_kernel_input_shape_list_; std::vector<std::vector<size_t>> hccl_kernel_input_shape_list_;
std::vector<std::vector<size_t>> hccl_kernel_output_shape_list_; std::vector<std::vector<size_t>> hccl_kernel_output_shape_list_;
std::vector<hcclDataType_t> hccl_data_type_list_; std::vector<HcclDataType> hccl_data_type_list_;
std::vector<std::string> hccl_format_list_; std::vector<std::string> hccl_format_list_;
uint64_t hccl_count_; uint64_t hccl_count_;
hcclRedOp_t op_type_; HcclReduceOp op_type_;
uint32_t root_id_; uint32_t root_id_;
mutable std::vector<size_t> input_size_list_; mutable std::vector<size_t> input_size_list_;
mutable std::vector<size_t> output_size_list_; mutable std::vector<size_t> output_size_list_;

View File

@ -34,7 +34,7 @@ bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs,
} }
const char *tag = "Hccl-BroadCast"; const char *tag = "Hccl-BroadCast";
MS_EXCEPTION_IF_NULL(inputs[0]); MS_EXCEPTION_IF_NULL(inputs[0]);
hcclResult_t ret = HcclResult ret =
hcom_broadcast(tag, inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_, nullptr, stream_ptr); hcom_broadcast(tag, inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_, nullptr, stream_ptr);
if (ret != HCCL_SUCCESS) { if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << static_cast<int>(ret); MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << static_cast<int>(ret);

View File

@ -32,7 +32,7 @@ bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const st
return false; return false;
} }
const char *tag = "Hccl-AllGather"; const char *tag = "Hccl-AllGather";
hcclResult_t ret = HcclResult ret =
hcom_all_gather(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], nullptr, stream_ptr); hcom_all_gather(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], nullptr, stream_ptr);
if (ret != HCCL_SUCCESS) { if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcomAllGatherKernelOp : hcom_all_gather fail, return: " << static_cast<int>(ret); MS_LOG(ERROR) << "HcomAllGatherKernelOp : hcom_all_gather fail, return: " << static_cast<int>(ret);

View File

@ -32,8 +32,8 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st
return false; return false;
} }
const char *tag = "Hccl-AllReduce"; const char *tag = "Hccl-AllReduce";
hcclResult_t ret = hcom_all_reduce(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], HcclResult ret = hcom_all_reduce(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0],
op_type_, nullptr, stream_ptr); op_type_, nullptr, stream_ptr);
if (ret != HCCL_SUCCESS) { if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcomAllReduceKernelOp : hcom_all_reduce fail, return: " << static_cast<int>(ret); MS_LOG(ERROR) << "HcomAllReduceKernelOp : hcom_all_reduce fail, return: " << static_cast<int>(ret);
return false; return false;

View File

@ -33,8 +33,8 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs,
return false; return false;
} }
const char *tag = "Hccl-ReduceScatter"; const char *tag = "Hccl-ReduceScatter";
hcclResult_t ret = hcom_reduce_scatter(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], HcclResult ret = hcom_reduce_scatter(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0],
op_type_, nullptr, stream_ptr); op_type_, nullptr, stream_ptr);
if (ret != HCCL_SUCCESS) { if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcomReduceScatterOp : hcom_reduce_scatter fail, return: " << static_cast<int>(ret); MS_LOG(ERROR) << "HcomReduceScatterOp : hcom_reduce_scatter fail, return: " << static_cast<int>(ret);
return false; return false;

View File

@ -43,7 +43,7 @@ bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<si
return true; return true;
} }
bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<hcclDataType_t> *data_type_list) { bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(data_type_list); MS_EXCEPTION_IF_NULL(data_type_list);
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
@ -56,14 +56,14 @@ bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<hcclDataType_t
} }
auto type_base = *(std::begin(*data_type_list)); auto type_base = *(std::begin(*data_type_list));
if (std::any_of(data_type_list->begin(), data_type_list->end(), if (std::any_of(data_type_list->begin(), data_type_list->end(),
[&type_base](hcclDataType_t type) { return type != type_base; })) { [&type_base](HcclDataType type) { return type != type_base; })) {
MS_LOG(ERROR) << "hccl have different data type"; MS_LOG(ERROR) << "hccl have different data type";
return false; return false;
} }
return true; return true;
} }
bool HcomUtil::GetHcclOpSize(const hcclDataType_t &data_type, const vector<size_t> &shape, size_t *size) { bool HcomUtil::GetHcclOpSize(const HcclDataType &data_type, const vector<size_t> &shape, size_t *size) {
MS_EXCEPTION_IF_NULL(size); MS_EXCEPTION_IF_NULL(size);
size_t tmp_size = 1; size_t tmp_size = 1;
uint32_t type_size = 4; uint32_t type_size = 4;
@ -81,7 +81,7 @@ bool HcomUtil::GetHcclOpSize(const hcclDataType_t &data_type, const vector<size_
return true; return true;
} }
bool HcomUtil::GetHcomTypeSize(const hcclDataType_t &data_type, uint32_t *size) { bool HcomUtil::GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size) {
MS_EXCEPTION_IF_NULL(size); MS_EXCEPTION_IF_NULL(size);
auto iter = CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.find(data_type); auto iter = CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.find(data_type);
if (iter == CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.end()) { if (iter == CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.end()) {
@ -92,7 +92,7 @@ bool HcomUtil::GetHcomTypeSize(const hcclDataType_t &data_type, uint32_t *size)
return true; return true;
} }
bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<hcclDataType_t> &data_type_list, bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataType> &data_type_list,
const vector<vector<size_t>> &shape_list, uint64_t *total_count) { const vector<vector<size_t>> &shape_list, uint64_t *total_count) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(total_count); MS_EXCEPTION_IF_NULL(total_count);
@ -143,7 +143,7 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<hcclDataTyp
return true; return true;
} }
bool HcomUtil::GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type) { bool HcomUtil::GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(op_type); MS_EXCEPTION_IF_NULL(op_type);
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
@ -155,13 +155,13 @@ bool HcomUtil::GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_
auto hcom_op_type_get = GetValue<const char *>(primitive->GetAttr("op")); auto hcom_op_type_get = GetValue<const char *>(primitive->GetAttr("op"));
string hcom_op_type(hcom_op_type_get); string hcom_op_type(hcom_op_type_get);
if (hcom_op_type == "min") { if (hcom_op_type == "min") {
*op_type = HCCL_REP_OP_MIN; *op_type = HCCL_REDUCE_MIN;
} else if (hcom_op_type == "max") { } else if (hcom_op_type == "max") {
*op_type = HCCL_REP_OP_MAX; *op_type = HCCL_REDUCE_MAX;
} else if (hcom_op_type == "prod") { } else if (hcom_op_type == "prod") {
*op_type = HCCL_REP_OP_PROD; *op_type = HCCL_REDUCE_PROD;
} else if (hcom_op_type == "sum") { } else if (hcom_op_type == "sum") {
*op_type = HCCL_REP_OP_SUM; *op_type = HCCL_REDUCE_SUM;
} else { } else {
MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!"; MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!";
return false; return false;

View File

@ -24,6 +24,7 @@
#include "ir/dtype.h" #include "ir/dtype.h"
#include "hccl/base.h" #include "hccl/base.h"
#include "utils/contract.h" #include "utils/contract.h"
#include "hccl/hccl_types.h"
namespace mindspore { namespace mindspore {
using std::map; using std::map;
@ -36,31 +37,31 @@ constexpr auto kBroadcast = "Broadcast";
constexpr auto kReduceScatter = "ReduceScatter"; constexpr auto kReduceScatter = "ReduceScatter";
/* Correspondence between data_type and hcom data type in Ascend */ /* Correspondence between data_type and hcom data type in Ascend */
static map<int64_t, hcclDataType_t> CONST_OP_HCOM_DATA_TYPE_MAP = { static map<int64_t, HcclDataType> CONST_OP_HCOM_DATA_TYPE_MAP = {
{TypeId::kNumberTypeFloat32, HCCL_DATA_TYPE_FLOAT}, {TypeId::kNumberTypeFloat32, HCCL_DATA_TYPE_FP32},
{TypeId::kNumberTypeFloat16, HCCL_DATA_TYPE_HALF}, {TypeId::kNumberTypeFloat16, HCCL_DATA_TYPE_FP16},
{TypeId::kNumberTypeInt8, HCCL_DATA_TYPE_INT8}, {TypeId::kNumberTypeInt8, HCCL_DATA_TYPE_INT8},
{TypeId::kNumberTypeInt32, HCCL_DATA_TYPE_INT}, {TypeId::kNumberTypeInt32, HCCL_DATA_TYPE_INT32},
}; };
/* Correspondence between data_type and occupied byte size in hcom */ /* Correspondence between data_type and occupied byte size in hcom */
static map<hcclDataType_t, uint32_t> CONST_OP_HCOM_DATA_TYPE_SIZE_MAP = { static map<HcclDataType, uint32_t> CONST_OP_HCOM_DATA_TYPE_SIZE_MAP = {
{HCCL_DATA_TYPE_FLOAT, sizeof(float)}, {HCCL_DATA_TYPE_FP32, sizeof(float)},
{HCCL_DATA_TYPE_HALF, sizeof(float) / 2}, {HCCL_DATA_TYPE_FP16, sizeof(float) / 2},
{HCCL_DATA_TYPE_INT8, sizeof(int8_t)}, {HCCL_DATA_TYPE_INT8, sizeof(int8_t)},
{HCCL_DATA_TYPE_INT, sizeof(int32_t)}, {HCCL_DATA_TYPE_INT32, sizeof(int32_t)},
}; };
class HcomUtil { class HcomUtil {
public: public:
static bool GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list); static bool GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list);
static bool GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list); static bool GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list);
static bool GetHcomDataType(const AnfNodePtr &anf_node, vector<hcclDataType_t> *data_type_list); static bool GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list);
static bool GetHcclOpSize(const hcclDataType_t &data_type, const vector<size_t> &shape, size_t *size); static bool GetHcclOpSize(const HcclDataType &data_type, const vector<size_t> &shape, size_t *size);
static bool GetHcomTypeSize(const hcclDataType_t &data_type, uint32_t *size); static bool GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size);
static bool GetHcomCount(const AnfNodePtr &anf_node, const vector<hcclDataType_t> &data_type_list, static bool GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataType> &data_type_list,
const vector<vector<size_t>> &shape_list, uint64_t *total_count); const vector<vector<size_t>> &shape_list, uint64_t *total_count);
static bool GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type); static bool GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type);
static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id); static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id);
static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group); static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group);
}; };

View File

@ -120,6 +120,10 @@ Status DeviceQueueOp::SendDataToAscend() {
TensorRow currRow; TensorRow currRow;
for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) { for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) {
RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow)); RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow));
while (stop_send_) {
MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
std::this_thread::sleep_for(std::chrono::microseconds(100));
}
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost); auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost);
if (status == TdtStatus::FAILED) { if (status == TdtStatus::FAILED) {
if (stop_send_) { if (stop_send_) {
@ -153,10 +157,6 @@ Status DeviceQueueOp::SendDataToAscend() {
} }
if (current_buffer->eoe() && send_epoch_end_) { if (current_buffer->eoe() && send_epoch_end_) {
TensorRow currRow; TensorRow currRow;
while (stop_send_) {
MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
std::this_thread::sleep_for(std::chrono::microseconds(100));
}
auto status = auto status =
tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE); tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
if (status == TdtStatus::FAILED) { if (status == TdtStatus::FAILED) {

View File

@ -642,7 +642,7 @@ bool AscendKernelRuntime::HcclInit() {
return false; return false;
} }
MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str; MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str;
hcclResult_t res = hcom_init(full_path, rank_id_str.c_str()); HcclResult res = hcom_init(full_path, rank_id_str.c_str());
free(full_path); free(full_path);
if (res != HCCL_SUCCESS) { if (res != HCCL_SUCCESS) {
MS_LOG(ERROR) << "Hcom init failed, res is " << static_cast<int>(res); MS_LOG(ERROR) << "Hcom init failed, res is " << static_cast<int>(res);
@ -658,7 +658,7 @@ bool AscendKernelRuntime::DestroyHccl() {
MS_LOG(INFO) << "Hccl is not enable, no need to close."; MS_LOG(INFO) << "Hccl is not enable, no need to close.";
return true; return true;
} }
hcclResult_t res = hcom_destroy(); HcclResult res = hcom_destroy();
if (res != HCCL_SUCCESS) { if (res != HCCL_SUCCESS) {
MS_LOG(ERROR) << "Hccl destroy failed"; MS_LOG(ERROR) << "Hccl destroy failed";
return false; return false;

View File

@ -20,6 +20,7 @@
#include "hccl/hcom.h" #include "hccl/hcom.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "hccl/hccl_types.h"
#include "utils/utils.h" #include "utils/utils.h"
constexpr auto kHcomBroadcast = "hcom_broadcast_"; constexpr auto kHcomBroadcast = "hcom_broadcast_";
@ -32,7 +33,7 @@ namespace device {
namespace ascend { namespace ascend {
namespace tasksink { namespace tasksink {
bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) { bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) {
hcclResult_t ret = hcom_bind_model(model, stream); HcclResult ret = hcom_bind_model(model, stream);
if (ret != HCCL_SUCCESS) { if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "Call hcom_bind_model failed, ret: 0x" << static_cast<int>(ret); MS_LOG(ERROR) << "Call hcom_bind_model failed, ret: 0x" << static_cast<int>(ret);
return false; return false;
@ -41,7 +42,7 @@ bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) {
} }
bool RuntimeUtils::HcomUnbindModel(rtModel_t model) { bool RuntimeUtils::HcomUnbindModel(rtModel_t model) {
hcclResult_t ret = hcom_unbind_model(model); HcclResult ret = hcom_unbind_model(model);
if (ret != HCCL_SUCCESS) { if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "Call hcom_unbind_model failed, ret: 0x" << static_cast<int>(ret); MS_LOG(ERROR) << "Call hcom_unbind_model failed, ret: 0x" << static_cast<int>(ret);
return false; return false;
@ -52,14 +53,14 @@ bool RuntimeUtils::HcomUnbindModel(rtModel_t model) {
bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info, rtStream_t stream) { bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info, rtStream_t stream) {
MS_LOG(INFO) << "hccl distribute start"; MS_LOG(INFO) << "hccl distribute start";
MS_EXCEPTION_IF_NULL(task_info); MS_EXCEPTION_IF_NULL(task_info);
hcclResult_t ret; HcclResult ret;
static uint32_t task_counter = 0; static uint32_t task_counter = 0;
auto hccl_group = task_info->group(); auto hccl_group = task_info->group();
if (task_info->hccl_type() == kBroadcastOpName) { if (task_info->hccl_type() == kBroadcastOpName) {
// call hcom broadcast interface to run op // call hcom broadcast interface to run op
const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0); const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0);
ret = hcom_broadcast(tag_broadcast.c_str(), task_info->input_data_addr(), static_cast<u64>(task_info->count()), ret = hcom_broadcast(tag_broadcast.c_str(), task_info->input_data_addr(), static_cast<u64>(task_info->count()),
static_cast<hcclDataType_t>(task_info->data_type()), static_cast<u32>(task_info->root_id()), static_cast<HcclDataType>(task_info->data_type()), static_cast<u32>(task_info->root_id()),
hccl_group.c_str(), stream); hccl_group.c_str(), stream);
if (ret != HCCL_SUCCESS) { if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret); MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret);
@ -69,7 +70,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
// call hcom allgather interface to run op // call hcom allgather interface to run op
const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0); const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0);
ret = hcom_all_gather(tag_all_gather.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), ret = hcom_all_gather(tag_all_gather.c_str(), task_info->input_data_addr(), task_info->output_data_addr(),
static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()), static_cast<u64>(task_info->count()), static_cast<HcclDataType>(task_info->data_type()),
hccl_group.c_str(), stream); hccl_group.c_str(), stream);
if (ret != HCCL_SUCCESS) { if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret; MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret;
@ -79,8 +80,8 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
// call hcom allreduce interface to run op // call hcom allreduce interface to run op
const string tag_all_reduce = kHcomAllReduce + std::to_string(task_counter++) + kUnderline + std::to_string(0); const string tag_all_reduce = kHcomAllReduce + std::to_string(task_counter++) + kUnderline + std::to_string(0);
ret = hcom_all_reduce(tag_all_reduce.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), ret = hcom_all_reduce(tag_all_reduce.c_str(), task_info->input_data_addr(), task_info->output_data_addr(),
static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()), static_cast<u64>(task_info->count()), static_cast<HcclDataType>(task_info->data_type()),
static_cast<hcclRedOp_t>(task_info->op_type()), hccl_group.c_str(), stream); static_cast<HcclReduceOp>(task_info->op_type()), hccl_group.c_str(), stream);
if (ret != HCCL_SUCCESS) { if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret; MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret;
return false; return false;
@ -90,8 +91,8 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
const string tag_reduce_scatter = const string tag_reduce_scatter =
kHcomReduceScatter + std::to_string(task_counter++) + kUnderline + std::to_string(0); kHcomReduceScatter + std::to_string(task_counter++) + kUnderline + std::to_string(0);
ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), task_info->input_data_addr(), task_info->output_data_addr(),
static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()), static_cast<u64>(task_info->count()), static_cast<HcclDataType>(task_info->data_type()),
static_cast<hcclRedOp_t>(task_info->op_type()), hccl_group.c_str(), stream); static_cast<HcclReduceOp>(task_info->op_type()), hccl_group.c_str(), stream);
if (ret != HCCL_SUCCESS) { if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret; MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret;
return false; return false;

View File

@ -13,13 +13,13 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import os import os
import pytest # import pytest
@pytest.mark.level0 # @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training # @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training # @pytest.mark.platform_arm_ascend_training
@pytest.mark.env_single # @pytest.mark.env_single
def test_wide_and_deep(): def test_wide_and_deep():
sh_path = os.path.split(os.path.realpath(__file__))[0] sh_path = os.path.split(os.path.realpath(__file__))[0]
ret = os.system(f"sh {sh_path}/run_wide_and_deep_auto_parallel.sh") ret = os.system(f"sh {sh_path}/run_wide_and_deep_auto_parallel.sh")

View File

@ -24,97 +24,97 @@ extern "C" {
#endif #endif
/* 集合通信域初始化 */ /* 集合通信域初始化 */
hcclResult_t hcom_init(const char *rank_table, const char *identify) { return HCCL_SUCCESS; } HcclResult hcom_init(const char *rank_table, const char *identify) { return HCCL_SUCCESS; }
/* 解析ranktable for python */ /* 解析ranktable for python */
hcclResult_t hcom_rank_info_init(const char *rank_table, const char *identify, u32 device_id) { return HCCL_SUCCESS; } HcclResult hcom_rank_info_init(const char *rank_table, const char *identify, u32 device_id) { return HCCL_SUCCESS; }
/* 集合通信域销毁 */ /* 集合通信域销毁 */
hcclResult_t hcom_destroy(void) { return HCCL_SUCCESS; } HcclResult hcom_destroy(void) { return HCCL_SUCCESS; }
/* 绑定model */ /* 绑定model */
hcclResult_t hcom_bind_model(rtModel_t model, rtStream_t stream) { return HCCL_SUCCESS; } HcclResult hcom_bind_model(rtModel_t model, rtStream_t stream) { return HCCL_SUCCESS; }
/* 解绑定model */ /* 解绑定model */
hcclResult_t hcom_unbind_model(rtModel_t model) { return HCCL_SUCCESS; } HcclResult hcom_unbind_model(rtModel_t model) { return HCCL_SUCCESS; }
/* allgather功能实现 */ /* allgather功能实现 */
hcclResult_t hcom_all_gather(const char *tag, void *inputPtr, void *outputPtr, u64 inputCount, hcclDataType_t dataType, HcclResult hcom_all_gather(const char *tag, void *inputPtr, void *outputPtr, u64 inputCount, HcclDataType dataType,
const char *group, rtStream_t stream) { const char *group, rtStream_t stream) {
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
/* allreduce功能实现 */ /* allreduce功能实现 */
hcclResult_t hcom_all_reduce(const char *tag, void *inputPtr, void *outputPtr, u64 count, hcclDataType_t dataType, HcclResult hcom_all_reduce(const char *tag, void *inputPtr, void *outputPtr, u64 count, HcclDataType dataType,
hcclRedOp_t op, const char *group, rtStream_t stream) { HcclReduceOp op, const char *group, rtStream_t stream) {
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
/* broadcast功能实现 */ /* broadcast功能实现 */
hcclResult_t hcom_broadcast(const char *tag, void *ptr, u64 count, hcclDataType_t dataType, u32 root, const char *group, HcclResult hcom_broadcast(const char *tag, void *ptr, u64 count, HcclDataType dataType, u32 root, const char *group,
rtStream_t stream) { rtStream_t stream) {
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
/* reduce_scatter功能实现 */ /* reduce_scatter功能实现 */
hcclResult_t hcom_reduce_scatter(const char *tag, void *inputPtr, void *outputPtr, u64 count, hcclDataType_t dataType, HcclResult hcom_reduce_scatter(const char *tag, void *inputPtr, void *outputPtr, u64 count, HcclDataType dataType,
hcclRedOp_t op, const char *group, rtStream_t stream) { HcclReduceOp op, const char *group, rtStream_t stream) {
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
/* 获取group内的rank个数 */ /* 获取group内的rank个数 */
hcclResult_t hcom_get_rank_size(const char *group, u32 *rankSize) { return HCCL_SUCCESS; } HcclResult hcom_get_rank_size(const char *group, u32 *rankSize) { return HCCL_SUCCESS; }
/* python获取上云场景内的rank个数 */ /* python获取上云场景内的rank个数 */
hcclResult_t hcom_python_get_rank_size(u32 *rankSize) { return HCCL_SUCCESS; } HcclResult hcom_python_get_rank_size(u32 *rankSize) { return HCCL_SUCCESS; }
/* 获取本rank的id */ /* 获取本rank的id */
hcclResult_t hcom_get_rank_id(const char *group, u32 *rankId) { return HCCL_SUCCESS; } HcclResult hcom_get_rank_id(const char *group, u32 *rankId) { return HCCL_SUCCESS; }
/* 获取本rank的id */ /* 获取本rank的id */
hcclResult_t hcom_python_get_rank_id(u32 *rankId) { return HCCL_SUCCESS; } HcclResult hcom_python_get_rank_id(u32 *rankId) { return HCCL_SUCCESS; }
/* 根据group内的rank id获取对应的world rank id */ /* 根据group内的rank id获取对应的world rank id */
hcclResult_t hcom_get_world_rank_from_group_rank(const char *group, u32 groupRank, u32 *worldRank) { HcclResult hcom_get_world_rank_from_group_rank(const char *group, u32 groupRank, u32 *worldRank) {
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
/* 根据world rank id获取对应的group内rank id */ /* 根据world rank id获取对应的group内rank id */
hcclResult_t hcom_get_group_rank_from_world_rank(u32 worldRank, const char *group, u32 *groupRank) { HcclResult hcom_get_group_rank_from_world_rank(u32 worldRank, const char *group, u32 *groupRank) {
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
/* 创建group */ /* 创建group */
hcclResult_t hcom_create_group(const char *group, u32 rankNum, u32 *rankIds) { return HCCL_SUCCESS; } HcclResult hcom_create_group(const char *group, u32 rankNum, u32 *rankIds) { return HCCL_SUCCESS; }
/* 销毁group */ /* 销毁group */
hcclResult_t hcom_destroy_group(const char *group) { return HCCL_SUCCESS; } HcclResult hcom_destroy_group(const char *group) { return HCCL_SUCCESS; }
/* 发送消息 */ /* 发送消息 */
hcclResult_t hcom_send(const char *tag, void *inputPtr, u64 count, hcclDataType_t dataType, u32 destRank, u32 srTag, HcclResult hcom_send(const char *tag, void *inputPtr, u64 count, HcclDataType dataType, u32 destRank, u32 srTag,
const char *group, rtStream_t stream) { const char *group, rtStream_t stream) {
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
/* 接收消息 */ /* 接收消息 */
hcclResult_t hcom_receive(const char *tag, void *outputPtr, u64 count, hcclDataType_t dataType, u32 srcRank, u32 srTag, HcclResult hcom_receive(const char *tag, void *outputPtr, u64 count, HcclDataType dataType, u32 srcRank, u32 srTag,
const char *group, rtStream_t stream) { const char *group, rtStream_t stream) {
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
/* 获取梯度参数切分方案 */ /* 获取梯度参数切分方案 */
hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum, HcclResult hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum,
u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force, u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force,
OriginalGraphShapeType shapeType) { OriginalGraphShapeType shapeType) {
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
/* 连通性检测 */ /* 连通性检测 */
hcclResult_t hcom_connectivity_detection(s32 *result) { return HCCL_SUCCESS; } HcclResult hcom_connectivity_detection(s32 *result) { return HCCL_SUCCESS; }
hcclResult_t hcom_set_split_strategy_by_index(const char *group, u32 segmentNum, const u32 *IdxList) { HcclResult hcom_set_split_strategy_by_index(const char *group, u32 segmentNum, const u32 *IdxList) {
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
hcclResult_t hcom_set_split_strategy_by_size(const char *group, u32 segmentNum, const float *sizeList) { HcclResult hcom_set_split_strategy_by_size(const char *group, u32 segmentNum, const float *sizeList) {
return HCCL_SUCCESS; return HCCL_SUCCESS;
} }
#ifdef __cplusplus #ifdef __cplusplus