diff --git a/graphengine b/graphengine
index 191dc747993..38a40dd2323 160000
--- a/graphengine
+++ b/graphengine
@@ -1 +1 @@
-Subproject commit 191dc747993dec992eceb1ebfcd8afc3dcd35acc
+Subproject commit 38a40dd232346e9a47850e237259ea6f43eeb35b
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.cc
new file mode 100644
index 00000000000..ed1f79e37b6
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.cc
@@ -0,0 +1,93 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/hccl/hccl_context.h"
+#include <cstdlib>
+#include <exception>
+#include <string>
+#include "utils/log_adapter.h"
+#include "hccl/hccl.h"
+
+constexpr auto kHcclConfigFile = "MINDSPORE_HCCL_CONFIG_PATH";
+
+namespace mindspore {
+namespace kernel {
+// Returns the value of environment variable RANK_ID, or an empty string (after logging an error) when the
+// variable is unset or empty.
+std::string GetRankId() {
+  // std::getenv returns nullptr for an unset variable; a std::string must not be built from nullptr.
+  const char *rank_id_env = std::getenv("RANK_ID");
+  std::string rank_id_str = (rank_id_env == nullptr) ? "" : rank_id_env;
+  if (rank_id_str.empty()) {
+    MS_LOG(ERROR) << "Get hccl rankid failed, please set env RANK_ID";
+  }
+  return rank_id_str;
+}
+
+// Creates the HCCL communicator from the cluster file named by MINDSPORE_HCCL_CONFIG_PATH and the rank given by
+// RANK_ID. Returns true when a communicator already exists or was created, false on any failure.
+bool HcclContext::InitHccl() {
+  if (hccl_comm_ != nullptr) {
+    return true;
+  }
+  auto config_file = std::getenv(kHcclConfigFile);
+  if (config_file == nullptr) {
+    MS_LOG(ERROR) << "Get hccl config file failed";
+    return false;
+  }
+  const std::string rank_id_str = GetRankId();
+  if (rank_id_str.empty()) {
+    return false;
+  }
+  // std::stoi throws on non-numeric or out-of-range text; treat that as an init failure.
+  int rank_id = 0;
+  try {
+    rank_id = std::stoi(rank_id_str);
+  } catch (const std::exception &e) {
+    MS_LOG(ERROR) << "Invalid RANK_ID: " << rank_id_str << ", error: " << e.what();
+    return false;
+  }
+  if (rank_id < 0) {
+    MS_LOG(ERROR) << "Invalid RANK_ID: " << rank_id_str << ", it must not be negative";
+    return false;
+  }
+  rank_id_ = rank_id;
+
+  auto hccl_result = HcclCommInitClusterInfo(config_file, rank_id_, &hccl_comm_);
+  if (hccl_result != HCCL_SUCCESS) {
+    MS_LOG(ERROR) << "HcclCommInitClusterInfo failed, ret:" << hccl_result;
+    return false;
+  }
+  MS_LOG(INFO) << "HcclCommInitClusterInfo success";
+  return true;
+}
+
+// Destroys the HCCL communicator if one exists. Returns false only when HcclCommDestroy reports an error.
+bool HcclContext::Finalize() {
+  if (hccl_comm_ == nullptr) {
+    return true;
+  }
+  auto hccl_result = HcclCommDestroy(hccl_comm_);
+  if (hccl_result != HCCL_SUCCESS) {
+    MS_LOG(ERROR) << "HcclComm destroy failed, ret:" << hccl_result;
+    return false;
+  }
+  // Clear the destroyed handle: a later Finalize is then a no-op and a later InitHccl creates a new communicator.
+  hccl_comm_ = nullptr;
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.h
new file mode 100644
index 00000000000..9ca54e7e5ab
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_context.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_CONTEXT_H_
+#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_CONTEXT_H_
+
+#include <string>
+#include "hccl/hccl_types.h"
+#include "utils/ms_utils.h"
+
+namespace mindspore {
+namespace kernel {
+// Singleton holding the HCCL communicator handle; InitHccl creates it and Finalize destroys it.
+class HcclContext {
+ public:
+  // Returns the single shared instance.
+  static HcclContext &GetInstance() {
+    static HcclContext instance;
+    return instance;
+  }
+
+  // Creates the communicator if none is held; returns false on failure.
+  bool InitHccl();
+  // Destroys the communicator if one is held; returns false on failure.
+  bool Finalize();
+  // Returns the communicator handle, or nullptr when no communicator is held.
+  HcclComm hccl_comm() { return hccl_comm_; }
+
+ private:
+  HcclContext() = default;
+  ~HcclContext() = default;
+  DISABLE_COPY_AND_ASSIGN(HcclContext);
+  HcclComm hccl_comm_{nullptr};
+  int rank_id_{0};
+  uint32_t device_id_{0};
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_CONTEXT_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc
index 7a1dae02d10..20e03b682fa 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc
@@ -17,13 +17,28 @@
 #include "backend/kernel_compiler/hccl/hcom_all_reduce.h"
 #include <memory>
 #include "utils/ms_context.h"
+#include "backend/kernel_compiler/hccl/hccl_context.h"
+#include "external/hccl/hccl.h"
 
 namespace mindspore {
 namespace kernel {
-bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> & /*inputs*/,
-                                 const std::vector<AddressPtr> & /*workspace*/,
-                                 const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) {
-  MS_LOG(INFO) << "HcomAllReduce launch";
+// Runs HcclAllReduce from inputs[0] to outputs[0] on stream_ptr; expects exactly one input and one output.
+bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
+                                 const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  MS_LOG(INFO) << "HcclAllReduce launch";
+  if (inputs.size() != 1 || outputs.size() != 1) {
+    MS_LOG(ERROR) << "AllReduce input output size must be 1";
+    return false;
+  }
+  MS_EXCEPTION_IF_NULL(inputs[0]);
+  MS_EXCEPTION_IF_NULL(outputs[0]);
+  MS_EXCEPTION_IF_NULL(stream_ptr);
+  auto hccl_result = HcclAllReduce(inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_,
+                                   HcclContext::GetInstance().hccl_comm(), stream_ptr);
+  if (hccl_result != HCCL_SUCCESS) {
+    MS_LOG(ERROR) << "HcclAllReduce failed, ret:" << hccl_result;
+    return false;
+  }
   return true;
 }
 }  // namespace kernel
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
index c1d19886717..4daf1969ddd 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
@@ -60,6 +60,7 @@
 #include "utils/config_manager.h"
 #include "runtime/device/ascend/profiling/reporter/op_name_task_stream_reporter.h"
 #include "runtime/hccl_adapter/hccl_adapter.h"
+#include "backend/kernel_compiler/hccl/hccl_context.h"
 
 using ge::model_runner::ModelRunner;
 using mindspore::device::ascend::ProfilingManager;
@@ -801,6 +802,11 @@ bool AscendKernelRuntime::ResetDevice() {
     stream_ = nullptr;
   }
 
+  if (!DestroySingleOpHccl()) {
+    MS_LOG(ERROR) << "Destroy hccl failed";
+    return false;
+  }
+
   if (rt_context_ != nullptr) {
     auto ret = rtCtxDestroy(rt_context_);
     if (ret != RT_ERROR_NONE) {
@@ -818,6 +824,10 @@ bool AscendKernelRuntime::ResetDevice() {
 bool AscendKernelRuntime::HcclInit() {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
+  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
+    MS_LOG(INFO) << "PyNative hccl init";
+    return kernel::HcclContext::GetInstance().InitHccl();
+  }
   if (!context::IsTsdOpened(context_ptr)) {
     MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open";
   }
@@ -850,9 +860,33 @@ bool AscendKernelRuntime::HcclInit() {
   return true;
 }
 
+// Finalizes the HcclContext communicator used in PyNative mode. Does nothing in other execution modes or when
+// NeedDestroyHccl() is false. Returns false only when HcclContext::Finalize fails.
+bool AscendKernelRuntime::DestroySingleOpHccl() {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
+    return true;
+  }
+  if (!NeedDestroyHccl()) {
+    MS_LOG(INFO) << "Hccl is not enable, no need to close.";
+    return true;
+  }
+  if (!kernel::HcclContext::GetInstance().Finalize()) {
+    MS_LOG(ERROR) << "Hccl finalize failed";
+    return false;
+  }
+  MS_LOG(INFO) << "Hccl destroy successful.";
+  context_ptr->set_param<bool>(MS_CTX_ENABLE_HCCL, false);
+  return true;
+}
+
 bool AscendKernelRuntime::DestroyHccl() {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
+  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
+    return true;
+  }
   if (!NeedDestroyHccl()) {
     MS_LOG(INFO) << "Hccl is not enable, no need to close.";
     return true;
@@ -861,13 +895,11 @@ bool AscendKernelRuntime::DestroyHccl() {
   if (!HcclExecutorManager::GetInstance().Finalize()) {
     MS_LOG(ERROR) << "Dynamic Shape Hccl Finalize Failed";
   }
-
   bool res = hccl::FinalizeHccl();
   if (!res) {
     MS_LOG(ERROR) << "Hccl destroy failed";
     return false;
   }
-
   MS_LOG(INFO) << "Hccl destroy successful.";
   context_ptr->set_param<bool>(MS_CTX_ENABLE_HCCL, false);
   return true;
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
index c7822040bec..4bbea9db8e7 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
@@ -70,6 +70,7 @@ class AscendKernelRuntime : public KernelRuntime {
   bool HcclInit();
   bool NeedDestroyHccl();
   bool DestroyHccl();
+  bool DestroySingleOpHccl();
   void InnerSetContext();
 
   void ClearGraphModelMap();
diff --git a/tests/ut/cpp/stub/hccl/hccl_stub.cc b/tests/ut/cpp/stub/hccl/hccl_stub.cc
index e1d5d29398a..9601a1fdb9c 100644
--- a/tests/ut/cpp/stub/hccl/hccl_stub.cc
+++ b/tests/ut/cpp/stub/hccl/hccl_stub.cc
@@ -18,6 +18,7 @@
 
 /* HCCL基础数据类型声明 */
 #include "hccl/hcom.h"
+#include "hccl/hccl.h"
 
 #ifdef __cplusplus
 extern "C" {
@@ -117,6 +118,43 @@ HcclResult hcom_set_split_strategy_by_index(const char *group, u32 segmentNum, c
 HcclResult hcom_set_split_strategy_by_size(const char *group, u32 segmentNum, const float *sizeList) {
   return HCCL_SUCCESS;
 }
+
+HcclResult HcclCommInitClusterInfo(const char *clusterInfo, uint32_t rank, HcclComm *comm) {
+  return HCCL_SUCCESS;
+}
+
+HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo) {
+  return HCCL_SUCCESS;
+}
+
+HcclResult HcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm) {
+  return HCCL_SUCCESS;
+}
+
+HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op,
+                         HcclComm comm, aclrtStream stream) {
+  return HCCL_SUCCESS;
+}
+
+HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm,
+                         aclrtStream stream) {
+  return HCCL_SUCCESS;
+}
+
+HcclResult HcclReduceScatter(void *sendBuf, void *recvBuf, uint64_t recvCount, HcclDataType dataType,
+                             HcclReduceOp op, HcclComm comm, aclrtStream stream) {
+  return HCCL_SUCCESS;
+}
+
+HcclResult HcclAllGather(void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, HcclComm comm,
+                         aclrtStream stream) {
+  return HCCL_SUCCESS;
+}
+
+HcclResult HcclCommDestroy(HcclComm comm) {
+  return HCCL_SUCCESS;
+}
+
 #ifdef __cplusplus
 }
 #endif