From e24f534002d8702b32ac15cc9a690240486607a0 Mon Sep 17 00:00:00 2001 From: wuweikang Date: Sat, 12 Sep 2020 15:23:04 +0800 Subject: [PATCH] sync-from-trunk-to-blue-zone-0912-c75b220 --- graphengine | 2 +- mindspore/_version_check.py | 2 +- .../kernel_compiler/hccl/hccl_kernel.cc | 6 +-- .../kernel_compiler/hccl/hccl_kernel.h | 5 +- .../hccl/hcom_all_broadcast.cc | 2 +- .../kernel_compiler/hccl/hcom_all_gather.cc | 2 +- .../kernel_compiler/hccl/hcom_all_reduce.cc | 4 +- .../hccl/hcom_all_reduce_scatter.cc | 4 +- .../backend/kernel_compiler/hccl/hcom_util.cc | 20 ++++---- .../backend/kernel_compiler/hccl/hcom_util.h | 27 +++++----- .../engine/datasetops/device_queue_op.cc | 8 +-- .../device/ascend/ascend_kernel_runtime.cc | 4 +- .../device/ascend/tasksink/runtime_utils.cc | 19 +++---- .../test_wide_and_deep_auto_parallel.py | 10 ++-- tests/ut/cpp/stub/hccl/hccl_stub.cc | 50 +++++++++---------- 15 files changed, 84 insertions(+), 81 deletions(-) diff --git a/graphengine b/graphengine index b6d2dd731c5..6dcf11d26ec 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit b6d2dd731c5f841fa9e2f3fdf0815cf1ed9d5ddc +Subproject commit 6dcf11d26eca81a328c7069235c7675c557fe0c0 diff --git a/mindspore/_version_check.py b/mindspore/_version_check.py index cf4bd5d8362..a8c7b573cdd 100644 --- a/mindspore/_version_check.py +++ b/mindspore/_version_check.py @@ -121,7 +121,7 @@ class AscendEnvChecker(EnvChecker): """ascend environment check""" def __init__(self): - self.version = ["1.75.T15.0.B150"] + self.version = ["1.75.22.0.220"] atlas_fwk_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info" hisi_fwk_version = "/usr/local/Ascend/fwkacllib/version.info" if os.path.exists(atlas_fwk_version): diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc index bf948498aa1..0ec7c6c6251 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc @@ -44,14 +44,14 @@ HcclKernelFactory &HcclKernelFactory::Get() { return _this; } -HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REP_OP_SUM), root_id_(0), anf_node_(nullptr) {} +HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REDUCE_SUM), root_id_(0), anf_node_(nullptr) {} HcclKernel::~HcclKernel() { hccl_kernel_input_shape_list_.clear(); hccl_kernel_output_shape_list_.clear(); hccl_data_type_list_.clear(); hccl_count_ = 0; - op_type_ = HCCL_REP_OP_SUM; + op_type_ = HCCL_REDUCE_SUM; root_id_ = 0; input_size_list_.clear(); output_size_list_.clear(); @@ -141,7 +141,7 @@ std::vector HcclKernel::GenTask(const std::vector &inpu void *workspace_address = nullptr; const int64_t workspace_num = 0; std::vector private_def; - hcclDataType_t data_type = hccl_data_type_list_[0]; + HcclDataType data_type = hccl_data_type_list_[0]; MS_LOG(INFO) << "HCCL Task : stream_id=" << stream_id << ", ws_num=" << workspace_num << ", count=" << hccl_count_ << ", root_id=" << root_id_ << ", op_type=" << static_cast(op_type_) diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h index 330692e461f..b7fe21945be 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h @@ -26,6 +26,7 @@ #include "backend/kernel_compiler/ascend_kernel_mod.h" #include "backend/kernel_compiler/hccl/hcom_util.h" #include "hccl/hcom.h" +#include "hccl/hccl_types.h" #include "utils/ms_utils.h" namespace mindspore { @@ -44,10 +45,10 @@ class HcclKernel : public AscendKernelMod { protected: std::vector> hccl_kernel_input_shape_list_; std::vector> hccl_kernel_output_shape_list_; - std::vector hccl_data_type_list_; + std::vector hccl_data_type_list_; std::vector hccl_format_list_; uint64_t hccl_count_; - hcclRedOp_t op_type_; + HcclReduceOp op_type_; uint32_t root_id_; mutable std::vector input_size_list_; mutable std::vector output_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc index d031fed9210..3fff96d1b29 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc @@ -34,7 +34,7 @@ bool HcomAllBroadCastKernel::Launch(const std::vector &inputs, } const char *tag = "Hccl-BroadCast"; MS_EXCEPTION_IF_NULL(inputs[0]); - hcclResult_t ret = + HcclResult ret = hcom_broadcast(tag, inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_, nullptr, stream_ptr); if (ret != HCCL_SUCCESS) { MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << static_cast(ret); diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc index fc4832ae6fb..db8d2edf739 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc @@ -32,7 +32,7 @@ bool HcomAllGatherKernel::Launch(const std::vector &inputs, const st return false; } const char *tag = "Hccl-AllGather"; - hcclResult_t ret = + HcclResult ret = hcom_all_gather(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], nullptr, stream_ptr); if (ret != HCCL_SUCCESS) { MS_LOG(ERROR) << "HcomAllGatherKernelOp : hcom_all_gather fail, return: " << static_cast(ret); diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc index 4d8cd690f08..62a4868d33d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc @@ -32,8 +32,8 @@ bool HcomAllReduceKernel::Launch(const std::vector &inputs, const st return false; } const char *tag = "Hccl-AllReduce"; - hcclResult_t ret = hcom_all_reduce(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], - op_type_, nullptr, stream_ptr); + HcclResult ret = hcom_all_reduce(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], + op_type_, nullptr, stream_ptr); if (ret != HCCL_SUCCESS) { MS_LOG(ERROR) << "HcomAllReduceKernelOp : hcom_all_reduce fail, return: " << static_cast(ret); return false; diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc index 08281c6030e..08a2415eaf7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc @@ -33,8 +33,8 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector &inputs, return false; } const char *tag = "Hccl-ReduceScatter"; - hcclResult_t ret = hcom_reduce_scatter(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], - op_type_, nullptr, stream_ptr); + HcclResult ret = hcom_reduce_scatter(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], + op_type_, nullptr, stream_ptr); if (ret != HCCL_SUCCESS) { MS_LOG(ERROR) << "HcomReduceScatterOp : hcom_reduce_scatter fail, return: " << static_cast(ret); return false; diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc index 8dec3669804..4c0bb55c80f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc @@ -43,7 +43,7 @@ bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector *data_type_list) { +bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector *data_type_list) { MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(data_type_list); for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { @@ -56,14 +56,14 @@ bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vectorbegin(), data_type_list->end(), - [&type_base](hcclDataType_t type) { return type != type_base; })) { + [&type_base](HcclDataType type) { return type != type_base; })) { MS_LOG(ERROR) << "hccl have different data type"; return false; } return true; } -bool HcomUtil::GetHcclOpSize(const hcclDataType_t &data_type, const vector &shape, size_t *size) { +bool HcomUtil::GetHcclOpSize(const HcclDataType &data_type, const vector &shape, size_t *size) { MS_EXCEPTION_IF_NULL(size); size_t tmp_size = 1; uint32_t type_size = 4; @@ -81,7 +81,7 @@ bool HcomUtil::GetHcclOpSize(const hcclDataType_t &data_type, const vector &data_type_list, +bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector &data_type_list, const vector> &shape_list, uint64_t *total_count) { MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(total_count); @@ -143,7 +143,7 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector(primitive->GetAttr("op")); string hcom_op_type(hcom_op_type_get); if (hcom_op_type == "min") { - *op_type = HCCL_REP_OP_MIN; + *op_type = HCCL_REDUCE_MIN; } else if (hcom_op_type == "max") { - *op_type = HCCL_REP_OP_MAX; + *op_type = HCCL_REDUCE_MAX; } else if (hcom_op_type == "prod") { - *op_type = HCCL_REP_OP_PROD; + *op_type = HCCL_REDUCE_PROD; } else if (hcom_op_type == "sum") { - *op_type = HCCL_REP_OP_SUM; + *op_type = HCCL_REDUCE_SUM; } else { MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!"; return false; diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h index 2979fc5ed8f..3e1843561a7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h @@ -24,6 +24,7 @@ #include "ir/dtype.h" #include "hccl/base.h" #include "utils/contract.h" +#include "hccl/hccl_types.h" namespace mindspore { using std::map; @@ -36,31 +37,31 @@ constexpr auto kBroadcast = "Broadcast"; constexpr auto kReduceScatter = "ReduceScatter"; /* Correspondence between data_type and hcom data type in Ascend */ -static map CONST_OP_HCOM_DATA_TYPE_MAP = { - {TypeId::kNumberTypeFloat32, HCCL_DATA_TYPE_FLOAT}, - {TypeId::kNumberTypeFloat16, HCCL_DATA_TYPE_HALF}, +static map CONST_OP_HCOM_DATA_TYPE_MAP = { + {TypeId::kNumberTypeFloat32, HCCL_DATA_TYPE_FP32}, + {TypeId::kNumberTypeFloat16, HCCL_DATA_TYPE_FP16}, {TypeId::kNumberTypeInt8, HCCL_DATA_TYPE_INT8}, - {TypeId::kNumberTypeInt32, HCCL_DATA_TYPE_INT}, + {TypeId::kNumberTypeInt32, HCCL_DATA_TYPE_INT32}, }; /* Correspondence between data_type and occupied byte size in hcom */ -static map CONST_OP_HCOM_DATA_TYPE_SIZE_MAP = { - {HCCL_DATA_TYPE_FLOAT, sizeof(float)}, - {HCCL_DATA_TYPE_HALF, sizeof(float) / 2}, +static map CONST_OP_HCOM_DATA_TYPE_SIZE_MAP = { + {HCCL_DATA_TYPE_FP32, sizeof(float)}, + {HCCL_DATA_TYPE_FP16, sizeof(float) / 2}, {HCCL_DATA_TYPE_INT8, sizeof(int8_t)}, - {HCCL_DATA_TYPE_INT, sizeof(int32_t)}, + {HCCL_DATA_TYPE_INT32, sizeof(int32_t)}, }; class HcomUtil { public: static bool GetKernelInputShape(const AnfNodePtr &anf_node, vector> *hccl_kernel_shape_list); static bool GetKernelOutputShape(const AnfNodePtr &anf_node, vector> *hccl_kernel_shape_list); - static bool GetHcomDataType(const AnfNodePtr &anf_node, vector *data_type_list); - static bool GetHcclOpSize(const hcclDataType_t &data_type, const vector &shape, size_t *size); - static bool GetHcomTypeSize(const hcclDataType_t &data_type, uint32_t *size); - static bool GetHcomCount(const AnfNodePtr &anf_node, const vector &data_type_list, + static bool GetHcomDataType(const AnfNodePtr &anf_node, vector *data_type_list); + static bool GetHcclOpSize(const HcclDataType &data_type, const vector &shape, size_t *size); + static bool GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size); + static bool GetHcomCount(const AnfNodePtr &anf_node, const vector &data_type_list, const vector> &shape_list, uint64_t *total_count); - static bool GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type); + static bool GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type); static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id); static void GetHcomGroup(NotNull anf_node, NotNull group); }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc index 883c7e6d0ef..a18c9d3e7e6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc @@ -120,6 +120,10 @@ Status DeviceQueueOp::SendDataToAscend() { TensorRow currRow; for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) { RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow)); + while (stop_send_) { + MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal..."; + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost); if (status == TdtStatus::FAILED) { if (stop_send_) { @@ -153,10 +157,6 @@ Status DeviceQueueOp::SendDataToAscend() { } if (current_buffer->eoe() && send_epoch_end_) { TensorRow currRow; - while (stop_send_) { - MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal..."; - std::this_thread::sleep_for(std::chrono::microseconds(100)); - } auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE); if (status == TdtStatus::FAILED) { diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 1a2ea749dbe..795dc987897 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -642,7 +642,7 @@ bool AscendKernelRuntime::HcclInit() { return false; } MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str; - hcclResult_t res = hcom_init(full_path, rank_id_str.c_str()); + HcclResult res = hcom_init(full_path, rank_id_str.c_str()); free(full_path); if (res != HCCL_SUCCESS) { MS_LOG(ERROR) << "Hcom init failed, res is " << static_cast(res); @@ -658,7 +658,7 @@ bool AscendKernelRuntime::DestroyHccl() { MS_LOG(INFO) << "Hccl is not enable, no need to close."; return true; } - hcclResult_t res = hcom_destroy(); + HcclResult res = hcom_destroy(); if (res != HCCL_SUCCESS) { MS_LOG(ERROR) << "Hccl destroy failed"; return false; diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.cc b/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.cc index dba71edfd32..4b8c97689ff 100644 --- a/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.cc +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.cc @@ -20,6 +20,7 @@ #include "hccl/hcom.h" #include "utils/log_adapter.h" +#include "hccl/hccl_types.h" #include "utils/utils.h" constexpr auto kHcomBroadcast = "hcom_broadcast_"; @@ -32,7 +33,7 @@ namespace device { namespace ascend { namespace tasksink { bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) { - hcclResult_t ret = hcom_bind_model(model, stream); + HcclResult ret = hcom_bind_model(model, stream); if (ret != HCCL_SUCCESS) { MS_LOG(ERROR) << "Call hcom_bind_model failed, ret: 0x" << static_cast(ret); return false; @@ -41,7 +42,7 @@ bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) { } bool RuntimeUtils::HcomUnbindModel(rtModel_t model) { - hcclResult_t ret = hcom_unbind_model(model); + HcclResult ret = hcom_unbind_model(model); if (ret != HCCL_SUCCESS) { MS_LOG(ERROR) << "Call hcom_unbind_model failed, ret: 0x" << static_cast(ret); return false; @@ -52,14 +53,14 @@ bool RuntimeUtils::HcomUnbindModel(rtModel_t model) { bool RuntimeUtils::HcomDistribute(const std::shared_ptr &task_info, rtStream_t stream) { MS_LOG(INFO) << "hccl distribute start"; MS_EXCEPTION_IF_NULL(task_info); - hcclResult_t ret; + HcclResult ret; static uint32_t task_counter = 0; auto hccl_group = task_info->group(); if (task_info->hccl_type() == kBroadcastOpName) { // call hcom broadcast interface to run op const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0); ret = hcom_broadcast(tag_broadcast.c_str(), task_info->input_data_addr(), static_cast(task_info->count()), - static_cast(task_info->data_type()), static_cast(task_info->root_id()), + static_cast(task_info->data_type()), static_cast(task_info->root_id()), hccl_group.c_str(), stream); if (ret != HCCL_SUCCESS) { MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast(ret); @@ -69,7 +70,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr &task_info // call hcom allgather interface to run op const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0); ret = hcom_all_gather(tag_all_gather.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), - static_cast(task_info->count()), static_cast(task_info->data_type()), + static_cast(task_info->count()), static_cast(task_info->data_type()), hccl_group.c_str(), stream); if (ret != HCCL_SUCCESS) { MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret; @@ -79,8 +80,8 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr &task_info // call hcom allreduce interface to run op const string tag_all_reduce = kHcomAllReduce + std::to_string(task_counter++) + kUnderline + std::to_string(0); ret = hcom_all_reduce(tag_all_reduce.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), - static_cast(task_info->count()), static_cast(task_info->data_type()), - static_cast(task_info->op_type()), hccl_group.c_str(), stream); + static_cast(task_info->count()), static_cast(task_info->data_type()), + static_cast(task_info->op_type()), hccl_group.c_str(), stream); if (ret != HCCL_SUCCESS) { MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret; return false; @@ -90,8 +91,8 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr &task_info const string tag_reduce_scatter = kHcomReduceScatter + std::to_string(task_counter++) + kUnderline + std::to_string(0); ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), - static_cast(task_info->count()), static_cast(task_info->data_type()), - static_cast(task_info->op_type()), hccl_group.c_str(), stream); + static_cast(task_info->count()), static_cast(task_info->data_type()), + static_cast(task_info->op_type()), hccl_group.c_str(), stream); if (ret != HCCL_SUCCESS) { MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret; return false; diff --git a/tests/st/model_zoo_tests/wide_and_deep/test_wide_and_deep_auto_parallel.py b/tests/st/model_zoo_tests/wide_and_deep/test_wide_and_deep_auto_parallel.py index 7856f6803a2..c2d68f82406 100644 --- a/tests/st/model_zoo_tests/wide_and_deep/test_wide_and_deep_auto_parallel.py +++ b/tests/st/model_zoo_tests/wide_and_deep/test_wide_and_deep_auto_parallel.py @@ -13,13 +13,13 @@ # limitations under the License. # ============================================================================ import os -import pytest +# import pytest -@pytest.mark.level0 -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_single +# @pytest.mark.level0 +# @pytest.mark.platform_x86_ascend_training +# @pytest.mark.platform_arm_ascend_training +# @pytest.mark.env_single def test_wide_and_deep(): sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/run_wide_and_deep_auto_parallel.sh") diff --git a/tests/ut/cpp/stub/hccl/hccl_stub.cc b/tests/ut/cpp/stub/hccl/hccl_stub.cc index 56f62910f21..e1d5d29398a 100644 --- a/tests/ut/cpp/stub/hccl/hccl_stub.cc +++ b/tests/ut/cpp/stub/hccl/hccl_stub.cc @@ -24,97 +24,97 @@ extern "C" { #endif /* 集合通信域初始化 */ -hcclResult_t hcom_init(const char *rank_table, const char *identify) { return HCCL_SUCCESS; } +HcclResult hcom_init(const char *rank_table, const char *identify) { return HCCL_SUCCESS; } /* 解析ranktable for python */ -hcclResult_t hcom_rank_info_init(const char *rank_table, const char *identify, u32 device_id) { return HCCL_SUCCESS; } +HcclResult hcom_rank_info_init(const char *rank_table, const char *identify, u32 device_id) { return HCCL_SUCCESS; } /* 集合通信域销毁 */ -hcclResult_t hcom_destroy(void) { return HCCL_SUCCESS; } +HcclResult hcom_destroy(void) { return HCCL_SUCCESS; } /* 绑定model */ -hcclResult_t hcom_bind_model(rtModel_t model, rtStream_t stream) { return HCCL_SUCCESS; } +HcclResult hcom_bind_model(rtModel_t model, rtStream_t stream) { return HCCL_SUCCESS; } /* 绑解定model */ -hcclResult_t hcom_unbind_model(rtModel_t model) { return HCCL_SUCCESS; } +HcclResult hcom_unbind_model(rtModel_t model) { return HCCL_SUCCESS; } /* allgather功能实现 */ -hcclResult_t hcom_all_gather(const char *tag, void *inputPtr, void *outputPtr, u64 inputCount, hcclDataType_t dataType, +HcclResult hcom_all_gather(const char *tag, void *inputPtr, void *outputPtr, u64 inputCount, HcclDataType dataType, const char *group, rtStream_t stream) { return HCCL_SUCCESS; } /* allreduce功能实现 */ -hcclResult_t hcom_all_reduce(const char *tag, void *inputPtr, void *outputPtr, u64 count, hcclDataType_t dataType, - hcclRedOp_t op, const char *group, rtStream_t stream) { +HcclResult hcom_all_reduce(const char *tag, void *inputPtr, void *outputPtr, u64 count, HcclDataType dataType, + HcclReduceOp op, const char *group, rtStream_t stream) { return HCCL_SUCCESS; } /* broadcas功能实现 */ -hcclResult_t hcom_broadcast(const char *tag, void *ptr, u64 count, hcclDataType_t dataType, u32 root, const char *group, +HcclResult hcom_broadcast(const char *tag, void *ptr, u64 count, HcclDataType dataType, u32 root, const char *group, rtStream_t stream) { return HCCL_SUCCESS; } /* reduce_scatter功能实现 */ -hcclResult_t hcom_reduce_scatter(const char *tag, void *inputPtr, void *outputPtr, u64 count, hcclDataType_t dataType, - hcclRedOp_t op, const char *group, rtStream_t stream) { +HcclResult hcom_reduce_scatter(const char *tag, void *inputPtr, void *outputPtr, u64 count, HcclDataType dataType, + HcclReduceOp op, const char *group, rtStream_t stream) { return HCCL_SUCCESS; } /* 获取group内的rank个数 */ -hcclResult_t hcom_get_rank_size(const char *group, u32 *rankSize) { return HCCL_SUCCESS; } +HcclResult hcom_get_rank_size(const char *group, u32 *rankSize) { return HCCL_SUCCESS; } /* python获取上云场景内的rank个数 */ -hcclResult_t hcom_python_get_rank_size(u32 *rankSize) { return HCCL_SUCCESS; } +HcclResult hcom_python_get_rank_size(u32 *rankSize) { return HCCL_SUCCESS; } /* 获取本rank的id */ -hcclResult_t hcom_get_rank_id(const char *group, u32 *rankId) { return HCCL_SUCCESS; } +HcclResult hcom_get_rank_id(const char *group, u32 *rankId) { return HCCL_SUCCESS; } /* 获取本rank的id */ -hcclResult_t hcom_python_get_rank_id(u32 *rankId) { return HCCL_SUCCESS; } +HcclResult hcom_python_get_rank_id(u32 *rankId) { return HCCL_SUCCESS; } /* 获取本rank的id */ -hcclResult_t hcom_get_world_rank_from_group_rank(const char *group, u32 groupRank, u32 *worldRank) { +HcclResult hcom_get_world_rank_from_group_rank(const char *group, u32 groupRank, u32 *worldRank) { return HCCL_SUCCESS; } /* 获取通信域的rank个数 */ -hcclResult_t hcom_get_group_rank_from_world_rank(u32 worldRank, const char *group, u32 *groupRank) { +HcclResult hcom_get_group_rank_from_world_rank(u32 worldRank, const char *group, u32 *groupRank) { return HCCL_SUCCESS; } /* 创建group */ -hcclResult_t hcom_create_group(const char *group, u32 rankNum, u32 *rankIds) { return HCCL_SUCCESS; } +HcclResult hcom_create_group(const char *group, u32 rankNum, u32 *rankIds) { return HCCL_SUCCESS; } /* 销毁group */ -hcclResult_t hcom_destroy_group(const char *group) { return HCCL_SUCCESS; } +HcclResult hcom_destroy_group(const char *group) { return HCCL_SUCCESS; } /* 发送消息 */ -hcclResult_t hcom_send(const char *tag, void *inputPtr, u64 count, hcclDataType_t dataType, u32 destRank, u32 srTag, +HcclResult hcom_send(const char *tag, void *inputPtr, u64 count, HcclDataType dataType, u32 destRank, u32 srTag, const char *group, rtStream_t stream) { return HCCL_SUCCESS; } /* 接收消息 */ -hcclResult_t hcom_receive(const char *tag, void *outputPtr, u64 count, hcclDataType_t dataType, u32 srcRank, u32 srTag, +HcclResult hcom_receive(const char *tag, void *outputPtr, u64 count, HcclDataType dataType, u32 srcRank, u32 srTag, const char *group, rtStream_t stream) { return HCCL_SUCCESS; } /* 获取梯度参数切分方案 */ -hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum, +HcclResult hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum, u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force, OriginalGraphShapeType shapeType) { return HCCL_SUCCESS; } /* 连通性检测 */ -hcclResult_t hcom_connectivity_detection(s32 *result) { return HCCL_SUCCESS; } +HcclResult hcom_connectivity_detection(s32 *result) { return HCCL_SUCCESS; } -hcclResult_t hcom_set_split_strategy_by_index(const char *group, u32 segmentNum, const u32 *IdxList) { +HcclResult hcom_set_split_strategy_by_index(const char *group, u32 segmentNum, const u32 *IdxList) { return HCCL_SUCCESS; } -hcclResult_t hcom_set_split_strategy_by_size(const char *group, u32 segmentNum, const float *sizeList) { +HcclResult hcom_set_split_strategy_by_size(const char *group, u32 segmentNum, const float *sizeList) { return HCCL_SUCCESS; } #ifdef __cplusplus