Support PyNative Data Parallel

caifubi 2020-12-01 11:14:50 +08:00
parent b273a46c53
commit 04d8cd5d8b
7 changed files with 203 additions and 7 deletions

@@ -1 +1 @@
Subproject commit 191dc747993dec992eceb1ebfcd8afc3dcd35acc
Subproject commit 38a40dd232346e9a47850e237259ea6f43eeb35b

@@ -0,0 +1,66 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/hccl/hccl_context.h"
#include "utils/log_adapter.h"
#include "hccl/hccl.h"
constexpr auto kHcclConfigFile = "MINDSPORE_HCCL_CONFIG_PATH";
namespace mindspore {
namespace kernel {
std::string GetRankId() {
  // std::getenv returns nullptr when RANK_ID is unset; constructing a std::string
  // from nullptr is undefined behavior, so check before converting.
  auto rank_id_env = std::getenv("RANK_ID");
  if (rank_id_env == nullptr) {
    MS_LOG(ERROR) << "Get hccl rank id failed, please set env RANK_ID";
    return "";
  }
  return std::string(rank_id_env);
}
bool HcclContext::InitHccl() {
if (hccl_comm_ != nullptr) {
return true;
}
auto config_file = std::getenv(kHcclConfigFile);
if (config_file == nullptr) {
MS_LOG(ERROR) << "Get hccl config file failed";
return false;
}
  auto rank_id_str = GetRankId();
  if (rank_id_str.empty()) {
    return false;
  }
  rank_id_ = std::stoi(rank_id_str);
auto hccl_result = HcclCommInitClusterInfo(config_file, rank_id_, &hccl_comm_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclCommInitClusterInfo failed, ret:" << hccl_result;
return false;
}
MS_LOG(INFO) << "HcclCommInitClusterInfo success";
return true;
}
bool HcclContext::Finalize() {
if (hccl_comm_ == nullptr) {
return true;
}
auto hccl_result = HcclCommDestroy(hccl_comm_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclComm destroy failed, ret:" << hccl_result;
return false;
}
return true;
}
} // namespace kernel
} // namespace mindspore

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_CONTEXT_H_
#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_CONTEXT_H_
#include <string>
#include "hccl/hccl_types.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
class HcclContext {
public:
static HcclContext &GetInstance() {
static HcclContext instance;
return instance;
}
bool InitHccl();
bool Finalize();
HcclComm hccl_comm() { return hccl_comm_; }
private:
HcclContext() = default;
~HcclContext() = default;
DISABLE_COPY_AND_ASSIGN(HcclContext);
HcclComm hccl_comm_{nullptr};
int rank_id_{0};
uint32_t device_id_{0};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCCL_CONTEXT_H_
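
A minimal usage sketch of the singleton declared above, assuming MINDSPORE_HCCL_CONFIG_PATH and RANK_ID are set in the environment. The wrapper function is illustrative only and is not part of this commit; AscendKernelRuntime drives the same calls further down in this diff.

#include "backend/kernel_compiler/hccl/hccl_context.h"

// Illustrative lifecycle: init once, hand the comm to collective launches, finalize on teardown.
bool RunWithHccl() {
  auto &hccl_ctx = mindspore::kernel::HcclContext::GetInstance();
  if (!hccl_ctx.InitHccl()) {  // reads MINDSPORE_HCCL_CONFIG_PATH and RANK_ID
    return false;
  }
  HcclComm comm = hccl_ctx.hccl_comm();  // passed to HcclAllReduce and friends
  (void)comm;                            // ... launch collectives on a stream here ...
  return hccl_ctx.Finalize();            // destroys the communicator
}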

@@ -17,13 +17,27 @@
#include "backend/kernel_compiler/hccl/hcom_all_reduce.h"
#include <memory>
#include "utils/ms_context.h"
#include "backend/kernel_compiler/hccl/hccl_context.h"
#include "external/hccl/hccl.h"
namespace mindspore {
namespace kernel {
bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> & /*inputs*/,
const std::vector<AddressPtr> & /*workspace*/,
const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) {
MS_LOG(INFO) << "HcomAllReduce launch";
bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
MS_LOG(INFO) << "HcclAllReduce launch";
if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "AllReduce input output size must be 1";
return false;
}
MS_EXCEPTION_IF_NULL(inputs[0]);
MS_EXCEPTION_IF_NULL(outputs[0]);
MS_EXCEPTION_IF_NULL(stream_ptr);
auto hccl_result = HcclAllReduce(inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_,
HcclContext::GetInstance().hccl_comm(), stream_ptr);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclAllReduce faled, ret:" << hccl_result;
return false;
}
return true;
}
} // namespace kernel
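
The stubbed entry points added to hccl.h at the bottom of this diff (HcclBroadcast, HcclReduceScatter, HcclAllGather) take the same comm handle, so other collective kernels can reuse this launch pattern. Below is a hedged broadcast sketch, not part of this commit; the free function and its root_rank parameter are illustrative only.

#include "backend/kernel_compiler/hccl/hccl_context.h"
#include "external/hccl/hccl.h"
#include "utils/log_adapter.h"

// Hedged sketch: broadcast through the HcclContext comm, mirroring HcomAllReduceKernel::Launch above.
bool LaunchBroadcastSketch(void *buf, uint64_t count, HcclDataType data_type, uint32_t root_rank, void *stream_ptr) {
  auto hccl_result = HcclBroadcast(buf, count, data_type, root_rank,
                                   mindspore::kernel::HcclContext::GetInstance().hccl_comm(), stream_ptr);
  if (hccl_result != HCCL_SUCCESS) {
    MS_LOG(ERROR) << "HcclBroadcast failed, ret:" << hccl_result;
    return false;
  }
  return true;
}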

@@ -60,6 +60,7 @@
#include "utils/config_manager.h"
#include "runtime/device/ascend/profiling/reporter/op_name_task_stream_reporter.h"
#include "runtime/hccl_adapter/hccl_adapter.h"
#include "backend/kernel_compiler/hccl/hccl_context.h"
using ge::model_runner::ModelRunner;
using mindspore::device::ascend::ProfilingManager;
@@ -801,6 +802,11 @@ bool AscendKernelRuntime::ResetDevice() {
stream_ = nullptr;
}
if (!DestroySingleOpHccl()) {
MS_LOG(ERROR) << "Destroy hccl failed";
return false;
}
if (rt_context_ != nullptr) {
auto ret = rtCtxDestroy(rt_context_);
if (ret != RT_ERROR_NONE) {
@@ -818,6 +824,10 @@ bool AscendKernelRuntime::ResetDevice() {
bool AscendKernelRuntime::HcclInit() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
MS_LOG(INFO) << "PyNative hccl init";
return kernel::HcclContext::GetInstance().InitHccl();
}
if (!context::IsTsdOpened(context_ptr)) {
MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open";
}
@@ -850,9 +860,31 @@ bool AscendKernelRuntime::HcclInit() {
return true;
}
bool AscendKernelRuntime::DestroySingleOpHccl() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
return true;
}
if (!NeedDestroyHccl()) {
MS_LOG(INFO) << "Hccl is not enable, no need to close.";
return true;
}
if (!kernel::HcclContext::GetInstance().Finalize()) {
MS_LOG(ERROR) << "Hccl finalize failed";
return false;
}
MS_LOG(INFO) << "Hccl destroy successful.";
context_ptr->set_param<bool>(MS_CTX_ENABLE_HCCL, false);
return true;
}
bool AscendKernelRuntime::DestroyHccl() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return true;
}
if (!NeedDestroyHccl()) {
MS_LOG(INFO) << "Hccl is not enable, no need to close.";
return true;
@@ -861,13 +893,11 @@ bool AscendKernelRuntime::DestroyHccl() {
if (!HcclExecutorManager::GetInstance().Finalize()) {
MS_LOG(ERROR) << "Dynamic Shape Hccl Finalize Failed";
}
bool res = hccl::FinalizeHccl();
if (!res) {
MS_LOG(ERROR) << "Hccl destroy failed";
return false;
}
MS_LOG(INFO) << "Hccl destroy successful.";
context_ptr->set_param<bool>(MS_CTX_ENABLE_HCCL, false);
return true;
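
Taken together, the runtime changes above split HCCL lifecycle handling by execution mode: PyNative goes through the new HcclContext singleton (the early return in HcclInit and the new DestroySingleOpHccl), while graph mode keeps the existing hccl adapter path in DestroyHccl. A hedged summary sketch, not actual AscendKernelRuntime code:

#include "backend/kernel_compiler/hccl/hccl_context.h"

// Hedged summary: which init/teardown path runs in which execution mode after this commit.
bool InitHcclForMode(bool is_pynative_mode) {
  if (is_pynative_mode) {
    // PyNative: HcclInit() delegates to the new singleton.
    return mindspore::kernel::HcclContext::GetInstance().InitHccl();
  }
  return true;  // graph mode: unchanged path via hcom / the hccl adapter
}

bool DestroyHcclForMode(bool is_pynative_mode) {
  if (is_pynative_mode) {
    // PyNative: DestroySingleOpHccl() finalizes the HcclContext communicator.
    return mindspore::kernel::HcclContext::GetInstance().Finalize();
  }
  return true;  // graph mode: DestroyHccl() still calls hccl::FinalizeHccl()
}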

@@ -70,6 +70,7 @@ class AscendKernelRuntime : public KernelRuntime {
bool HcclInit();
bool NeedDestroyHccl();
bool DestroyHccl();
bool DestroySingleOpHccl();
void InnerSetContext();
void ClearGraphModelMap();

@@ -18,6 +18,7 @@
/* Declarations of HCCL basic data types */
#include "hccl/hcom.h"
#include "hccl/hccl.h"
#ifdef __cplusplus
extern "C" {
@@ -117,6 +118,43 @@ HcclResult hcom_set_split_strategy_by_index(const char *group, u32 segmentNum, c
HcclResult hcom_set_split_strategy_by_size(const char *group, u32 segmentNum, const float *sizeList) {
return HCCL_SUCCESS;
}
HcclResult HcclCommInitClusterInfo(const char *clusterInfo, uint32_t rank, HcclComm *comm) {
return HCCL_SUCCESS;
}
HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo) {
return HCCL_SUCCESS;
}
HcclResult HcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm) {
return HCCL_SUCCESS;
}
HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op,
HcclComm comm, aclrtStream stream) {
return HCCL_SUCCESS;
}
HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm,
aclrtStream stream) {
return HCCL_SUCCESS;
}
HcclResult HcclReduceScatter(void *sendBuf, void *recvBuf, uint64_t recvCount, HcclDataType dataType,
HcclReduceOp op, HcclComm comm, aclrtStream stream) {
return HCCL_SUCCESS;
}
HcclResult HcclAllGather(void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, HcclComm comm,
aclrtStream stream) {
return HCCL_SUCCESS;
}
HcclResult HcclCommDestroy(HcclComm comm) {
return HCCL_SUCCESS;
}
#ifdef __cplusplus
}
#endif