forked from mindspore-Ecosystem/mindspore
support host reduce scatter and mpi config
commit 6034f9c1e2, parent b096383386

build.sh:

@@ -49,7 +49,7 @@ usage()
     echo "    -Q Enable dump memory, default off"
     echo "    -D Enable dumping of function graph ir, default on"
     echo "    -z Compile dataset & mindrecord, default on"
-    echo "    -M Enable MPI and NCCL for GPU training, default on"
+    echo "    -M Enable MPI and NCCL for GPU training, gpu default on"
     echo "    -V Specify the minimum required cuda version, default CUDA 9.2"
     echo "    -I Compile predict, default off"
     echo "    -K Compile with AKG, default off"

Device kernel CMakeLists.txt (cpu/ and gpu/ kernel sources):

@@ -14,17 +14,19 @@ endif ()
 if (ENABLE_CPU)
     file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
-    if (ENABLE_MPI)
-        # _ms_mpi
-        set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
-                PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
-        pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
-        target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
-    else ()
+    if (NOT ENABLE_MPI)
         list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
     endif ()
 endif ()
 
+if (ENABLE_MPI)
+    # _ms_mpi
+    set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
+            PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
+    pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
+    target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
+endif ()
+
 # gpu
 if (ENABLE_GPU)
     file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc" "gpu/*.cu")
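
The pybind11_add_module(_ms_mpi ...) call above is what builds the mindspore._ms_mpi extension imported by the reduce-scatter test at the end of this diff. A minimal sanity check, assuming MindSpore was built with MPI enabled (-M on) and the script is launched under mpirun:

    # get_rank_id() is the binding the reduce-scatter test below relies on.
    import mindspore._ms_mpi as mpi

    print("host MPI rank:", mpi.get_rank_id())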

device/ascend/ascend_kernel_runtime.cc:

@@ -25,6 +25,7 @@
 #include "device/ascend/ascend_device_address.h"
 #include "device/cpu/mpi/mpi_adapter.h"
 #include "utils/context/ms_context.h"
+#include "utils/mpi/mpi_config.h"
 #include "device/ascend/profiling/profiling_manager.h"
 #include "hccl/hcom.h"
 #include "common/trans.h"
@@ -510,19 +511,35 @@ bool AscendKernelRuntime::HcclInit() {
     MS_LOG(ERROR) << "file path " << config_path_str << " does not exist";
     return false;
   }
+  const char *identify = nullptr;
 #ifdef ENABLE_MPI
-  int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
-  const char *offset = std::getenv("RANK_OFFSET");
-  if (offset != nullptr) {
-    int rank_offset = std::stoi(offset);
-    rank_id += rank_offset;
-  }
-  const char *identify = reinterpret_cast<const char *>(std::to_string(rank_id).c_str());
+  std::string rank_id_tmp;
+  auto mpi_config_ptr = MpiConfig::GetInstance();
+  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
+  if (mpi_config_ptr->enable_mpi()) {
+    int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
+    const char *offset = std::getenv("RANK_OFFSET");
+    if (offset != nullptr) {
+      try {
+        int rank_offset = std::stoi(offset);
+        rank_id += rank_offset;
+      } catch (std::invalid_argument) {
+        MS_LOG(EXCEPTION) << "stoi invalid argument:" << offset;
+      } catch (std::out_of_range) {
+        MS_LOG(EXCEPTION) << "stoi out_of_range:" << offset;
+      }
+    }
+    rank_id_tmp = std::to_string(rank_id);
+    identify = rank_id_tmp.c_str();
+  } else {
+    identify = std::getenv("RANK_ID");
+  }
 #else
-  const char *identify = std::getenv("RANK_ID");
+  identify = std::getenv("RANK_ID");
 #endif
   if (identify == nullptr) {
     MS_LOG(ERROR) << "get hccl rankid failed, please set env RANK_ID";
     free(full_path);
     return false;
   }
   MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << identify;
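
The new branch above resolves the HCCL rank string from the MPI rank (plus an optional RANK_OFFSET) only when MpiConfig says MPI is enabled, and falls back to the RANK_ID environment variable otherwise. A Python paraphrase of that control flow, as a sketch only (the function name is illustrative, not MindSpore API):

    import os

    def resolve_hccl_rank_id(mpi_rank_id=None):
        """Paraphrase of the branch above: when MPI is enabled, use the MPI rank
        plus an optional RANK_OFFSET; otherwise fall back to the RANK_ID env var."""
        if mpi_rank_id is not None:              # mpi_config_ptr->enable_mpi() branch
            offset = os.getenv("RANK_OFFSET")
            if offset is not None:
                try:
                    mpi_rank_id += int(offset)   # the std::stoi plus try/catch above
                except ValueError as err:
                    raise RuntimeError("invalid RANK_OFFSET: " + offset) from err
            return str(mpi_rank_id)
        return os.getenv("RANK_ID")              # may be None; caller reports the error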

device/cpu/mpi/mpi_adapter.cc:

@@ -16,6 +16,7 @@
 #include "device/cpu/mpi/mpi_adapter.h"
 #include <algorithm>
+#include "utils/mpi/mpi_config.h"
 #include "utils/log_adapter.h"
 
 namespace mindspore {
@@ -35,6 +36,20 @@ MPI_Op GetMpiOp(const std::string &op_type) {
   MS_LOG(EXCEPTION) << "unsupport op_type:" << op_type;
   return MPI_SUM;
 }
+
+int GetScatterIndex(int rankid, const std::vector<int> &ranks_group) {
+  int scatter_index = -1;
+  for (size_t i = 0; i < ranks_group.size(); ++i) {
+    if (ranks_group[i] == rankid) {
+      scatter_index = static_cast<int>(i);
+      break;
+    }
+  }
+  if (scatter_index == -1) {
+    MS_LOG(EXCEPTION) << "process rankid " << rankid << " does not in the input rank group!";
+  }
+  return scatter_index;
+}
 }  // namespace
 
 MPIAdapter::MPIAdapter() : rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) { Init(); }
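
GetScatterIndex is just the position of the calling rank inside ranks_group; in Python terms (illustration only):

    ranks_group = [0, 1, 2]                  # hypothetical rank group
    scatter_index = ranks_group.index(1)     # rank 1 keeps chunk 1; ValueError if the rank is absent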

@@ -65,6 +80,11 @@ void MPIAdapter::Init() {
   if (init) {
     return;
   }
+  auto mpi_config_ptr = MpiConfig::GetInstance();
+  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
+  if (!mpi_config_ptr->enable_mpi()) {
+    MS_LOG(EXCEPTION) << "MPI is disabled now!Please enable mpi with mpi config first.";
+  }
   int init_flag = 0;
   if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
     MS_LOG(EXCEPTION) << "Check mpi initialized fail!";

@@ -123,7 +143,7 @@ MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
   return group;
 }
 
-bool MPIAdapter::ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
+bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                                const std::string &op_type) {
   if (ranks_group.empty()) {
     MS_LOG(ERROR) << "input rank group is empty!";

@@ -159,6 +179,51 @@ bool MPIAdapter::ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group,
   return result;
 }
 
+bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
+                                             const std::string &op_type, float *output) {
+  int scatter_index = GetScatterIndex(rank_id_, ranks_group);
+  auto group = AddGroup(ranks_group);
+  if (group == MPI_GROUP_NULL) {
+    MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
+  }
+  MPI_Comm comm;
+  MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
+  if (comm == MPI_COMM_NULL) {
+    MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
+  }
+
+  MPI_Win window;
+  auto ret = MPI_Win_create(input, data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
+  if (ret != MPI_SUCCESS) {
+    MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
+    return false;
+  }
+  MPI_Win_fence(0, window);
+  for (size_t i = 0; i < ranks_group.size(); ++i) {
+    int remote_rank = ranks_group[i];
+    if (rank_id_ == remote_rank) {
+      continue;
+    }
+    auto op = GetMpiOp(op_type);
+    ret = MPI_Accumulate(input + i * data_num, data_num, MPI_FLOAT, remote_rank, i * data_num, data_num, MPI_FLOAT, op,
+                         window);
+    if (ret != MPI_SUCCESS) {
+      MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
+    }
+  }
+  MPI_Win_fence(0, window);
+  if (output != nullptr) {
+    auto data_size = data_num * sizeof(float);
+    auto copy_ret = memcpy_s(output, data_size, input + scatter_index * data_num, data_size);
+    if (copy_ret != 0) {
+      MS_LOG(EXCEPTION) << "copy output memory fail!";
+    }
+  }
+  MPI_Win_free(&window);
+  MPI_Comm_free(&comm);
+  return true;
+}
+
 bool MPIAdapter::AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
   if (ranks_group.empty()) {
     MS_LOG(ERROR) << "input rank group is empty!";
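
The one-sided pattern above (MPI_Win_create, a fenced loop of MPI_Accumulate calls, then memcpy_s at the caller's scatter index) leaves each rank holding, at its own chunk of the window, the reduction of that chunk across the whole group. A NumPy sketch of the resulting values for the sum case, with no MPI involved and illustrative names only:

    import numpy as np

    def simulate_reduce_scatter(inputs, ranks_group, data_num):
        """inputs[k] is the full-length buffer held by rank ranks_group[k].
        Returns the data_num-sized chunk each rank keeps in `output`, i.e. the
        element-wise sum of the chunk at its scatter index across all ranks."""
        total = np.sum(np.stack(inputs), axis=0)
        return {rank: total[i * data_num:(i + 1) * data_num]
                for i, rank in enumerate(ranks_group)}

    # Mirrors the test at the end of this diff: 3 ranks, arange(12) * 0.1 each.
    inputs = [np.arange(12, dtype=np.float32) * 0.1 for _ in range(3)]
    out = simulate_reduce_scatter(inputs, ranks_group=[0, 1, 2], data_num=4)
    assert np.allclose(out[1], np.arange(4, 8, dtype=np.float32) * 0.3)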

device/cpu/mpi/mpi_adapter.h:

@@ -32,8 +32,10 @@ class MPIAdapter {
   ~MPIAdapter();
   static MPIAdapter &Instance();
   int GetRankId() const;
-  bool ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
+  bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                      const std::string &op_type = kOpTypeSum);
+  bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
+                                   const std::string &op_type = kOpTypeSum, float *output = nullptr);
   bool AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
 
  private:

pipeline init (PYBIND11_MODULE _c_expression):

@@ -26,6 +26,7 @@
 #include "pipeline/parse/python_adapter.h"
 #include "utils/summary/event_writer.h"
 #include "utils/config_manager.h"
+#include "utils/mpi/mpi_config.h"
 #include "parallel/context.h"
 #include "parallel/device_manager.h"
 #include "parallel/costmodel_context.h"
@@ -147,6 +148,11 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.")
     .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.");
 
+  (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
+    .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
+    .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.")
+    .def("set_enable_mpi", &mindspore::MpiConfig::set_enable_mpi, "Set whether to enable mpi.");
+
   (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
     .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
     .def("get_device_num", &ParallelContext::device_num, "Get device num.")
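
These bindings are what the new _mpi_config.py wrapper further down drives; reaching the raw binding directly looks like the sketch below, assuming the _c_expression extension has been built:

    from mindspore._c_expression import MpiConfig

    cfg = MpiConfig.get_instance()    # the shared singleton from MpiConfig::GetInstance()
    cfg.set_enable_mpi(True)
    assert cfg.get_enable_mpi()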

New file: utils/mpi/mpi_config.cc

@@ -0,0 +1,31 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/mpi/mpi_config.h"

namespace mindspore {

std::shared_ptr<MpiConfig> MpiConfig::instance_ = nullptr;

std::shared_ptr<MpiConfig> MpiConfig::GetInstance() {
  if (instance_ == nullptr) {
    MS_LOG(DEBUG) << "Create new mpi config instance.";
    instance_.reset(new (std::nothrow) MpiConfig());
  }
  return instance_;
}

}  // namespace mindspore

New file: utils/mpi/mpi_config.h

@@ -0,0 +1,42 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
#define MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
#include <memory>
#include "utils/log_adapter.h"

namespace mindspore {
class MpiConfig {
 public:
  ~MpiConfig() = default;
  MpiConfig(const MpiConfig &) = delete;
  MpiConfig &operator=(const MpiConfig &) = delete;

  static std::shared_ptr<MpiConfig> GetInstance();

  void set_enable_mpi(bool flag) { enable_mpi_ = flag; }
  bool enable_mpi() const { return enable_mpi_; }

 private:
  MpiConfig() : enable_mpi_(false) {}

  static std::shared_ptr<MpiConfig> instance_;
  bool enable_mpi_;
};
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_

mindspore/context.py:

@@ -25,6 +25,7 @@ from mindspore._c_expression import MSContext
 from mindspore._checkparam import args_type_check
 from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
     _reset_auto_parallel_context
+from mindspore.parallel.mpi._mpi_config import _set_mpi_config, _get_mpi_config
 
 __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
            'get_auto_parallel_context', 'reset_auto_parallel_context']
@@ -566,3 +567,40 @@ def get_context(attr_key):
     if not hasattr(_context(), attr_key):
         raise ValueError("Get context keyword %s is not recognized!" % attr_key)
     return getattr(_context(), attr_key)
+
+
+@args_type_check(enable_mpi=bool)
+def set_mpi_config(**kwargs):
+    """
+    Sets mpi config for running environment.
+
+    The mpi config should be configured before running your program. If there is no configuration,
+    the mpi module will be disabled by default.
+
+    Note:
+        Attribute name is required for setting attributes.
+
+    Args:
+        enable_mpi (bool): Whether to enable mpi. Default: False.
+
+    Raises:
+        ValueError: If input key is not an attribute in mpi config.
+
+    Examples:
+        >>> context.set_mpi_config(enable_mpi=True)
+    """
+    _set_mpi_config(**kwargs)
+
+
+def get_mpi_config(attr_key):
+    """
+    Gets mpi config attribute value according to the input key.
+
+    Args:
+        attr_key (str): The key of the attribute.
+
+    Returns:
+        Object, the value of the given attribute key.
+
+    Raises:
+        ValueError: If input key is not an attribute in mpi config.
+    """
+    return _get_mpi_config(attr_key)
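
Typical use of the new public API, mirroring the test at the bottom of this diff (CPU target assumed):

    from mindspore import context

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    context.set_mpi_config(enable_mpi=True)      # must run before MPIAdapter::Init() is triggered
    print(context.get_mpi_config("enable_mpi"))  # True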

New file: mindspore/parallel/mpi/__init__.py (license header only)

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

New file: mindspore/parallel/mpi/_mpi_config.py

@@ -0,0 +1,111 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The MPI config, used to configure the MPI environment.
"""
import threading
from mindspore._c_expression import MpiConfig
from mindspore._checkparam import args_type_check


class _MpiConfig:
    """
    _MpiConfig is the config tool for controlling MPI.

    Note:
        Creating a config by instantiating the _MpiConfig object is not recommended;
        use _mpi_config() to get the config, since _MpiConfig is a singleton.
    """
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self._mpiconfig_handle = MpiConfig.get_instance()

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance_lock.acquire()
            cls._instance = object.__new__(cls)
            cls._instance_lock.release()
        return cls._instance

    def __getattribute__(self, attr):
        value = object.__getattribute__(self, attr)
        if attr == "_mpiconfig_handle" and value is None:
            raise ValueError("mpiconfig handle is none in MpiConfig!!!")
        return value

    @property
    def enable_mpi(self):
        return self._mpiconfig_handle.get_enable_mpi()

    @enable_mpi.setter
    def enable_mpi(self, enable_mpi):
        self._mpiconfig_handle.set_enable_mpi(enable_mpi)


_k_mpi_config = None
def _mpi_config():
    """
    Get the global mpi config. If the mpi config is not created, create a new one.

    Returns:
        _MpiConfig, the global mpi config.
    """
    global _k_mpi_config
    if _k_mpi_config is None:
        _k_mpi_config = _MpiConfig()
    return _k_mpi_config


@args_type_check(enable_mpi=bool)
def _set_mpi_config(**kwargs):
    """
    Sets mpi config for running environment.

    The mpi config should be configured before running your program. If there is no configuration,
    the mpi module will be disabled by default.

    Note:
        Attribute name is required for setting attributes.

    Args:
        enable_mpi (bool): Whether to enable mpi. Default: False.

    Raises:
        ValueError: If input key is not an attribute in mpi config.

    Examples:
        >>> mpiconfig.set_mpi_config(enable_mpi=True)
    """
    for key, value in kwargs.items():
        if not hasattr(_mpi_config(), key):
            raise ValueError("Set mpi config keyword %s is not recognized!" % key)
        setattr(_mpi_config(), key, value)


def _get_mpi_config(attr_key):
    """
    Gets mpi config attribute value according to the input key.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Object, the value of the given attribute key.

    Raises:
        ValueError: If input key is not an attribute in mpi config.
    """
    if not hasattr(_mpi_config(), attr_key):
        raise ValueError("Get mpi config keyword %s is not recognized!" % attr_key)
    return getattr(_mpi_config(), attr_key)
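
The private helpers above are normally reached through context.set_mpi_config and context.get_mpi_config; called directly they behave as sketched here:

    from mindspore.parallel.mpi._mpi_config import _set_mpi_config, _get_mpi_config

    _set_mpi_config(enable_mpi=True)
    assert _get_mpi_config("enable_mpi")
    # Unrecognized keywords raise instead of being silently ignored:
    # _set_mpi_config(enable_npi=True)  ->  ValueError("Set mpi config keyword ... is not recognized!")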

test_reduce_scatter.py:

@@ -23,9 +23,10 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 import mindspore._ms_mpi as mpi
 # run comand:
-# mpirun -np 3 python test_reduce_scatter.py
+# mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_reduce_scatter.py
 
 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+context.set_mpi_config(enable_mpi=True)
 
 class Net(nn.Cell):
     def __init__(self):
@@ -46,14 +47,19 @@ class AllGatherNet(nn.Cell):
         return self.hostallgather(x)
 
 def test_net_reduce_scatter():
-    x = np.ones(12).astype(np.float32) * 0.1
+    x = np.arange(12).astype(np.float32) * 0.1
 
     reducescatter = Net()
     rankid = mpi.get_rank_id()
     print("self rankid:", rankid)
     output = reducescatter(Tensor(x, mstype.float32))
     print("output:\n", output)
-    expect_result = np.ones(4).astype(np.float32) * 0.3
+    if rankid == 0:
+        expect_result = np.arange(4).astype(np.float32) * 0.3
+    if rankid == 1:
+        expect_result = np.arange(4, 8).astype(np.float32) * 0.3
+    if rankid == 2:
+        expect_result = np.arange(8, 12).astype(np.float32) * 0.3
     diff = abs(output.asnumpy() - expect_result)
     error = np.ones(shape=expect_result.shape) * 1.0e-6
     assert np.all(diff < error)
@@ -61,7 +67,7 @@ def test_net_reduce_scatter():
     allgather = AllGatherNet()
     allgather_output = allgather(output)
     print("allgather result:\n", allgather_output)
-    expect_allgather_result = np.ones(12).astype(np.float32) * 0.3
+    expect_allgather_result = np.arange(12).astype(np.float32) * 0.3
     diff = abs(allgather_output.asnumpy() - expect_allgather_result)
     error = np.ones(shape=expect_allgather_result.shape) * 1.0e-6
     assert np.all(diff < error)
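
Why these expected values hold: all three ranks feed the same arange(12) * 0.1 buffer, so the summed buffer is arange(12) * 0.3, rank k keeps the 4-element chunk starting at 4 * k, and the host all-gather concatenates the chunks back into the full summed buffer. A quick NumPy check of that arithmetic:

    import numpy as np

    x = np.arange(12, dtype=np.float32) * 0.1
    summed = 3 * x                                   # three identical per-rank contributions
    chunks = [summed[4 * k:4 * (k + 1)] for k in range(3)]
    assert np.allclose(chunks[1], np.arange(4, 8, dtype=np.float32) * 0.3)
    assert np.allclose(np.concatenate(chunks), np.arange(12, dtype=np.float32) * 0.3)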