forked from mindspore-Ecosystem/mindspore
set stage id
This commit is contained in:
parent
bedc733e42
commit
41d925b68a
|
@ -27,15 +27,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
DeviceManagerPtr g_device_manager = nullptr;
|
DeviceManagerPtr g_device_manager = nullptr;
|
||||||
|
|
||||||
// Builds a stage from a device list, a stage number and a rank.
// The three arguments are stored as-is in 'devices_', 'number_' and 'rank_'.
Stage::Stage(const std::vector<mindspore::parallel::Device> &devices, int64_t num, int64_t rank)
    : devices_(devices), number_(num), rank_(rank) {
  // Reset the group manager to a freshly constructed one.
  gm_ = GroupManager();
}
|
|
||||||
|
|
||||||
// NOTE: '-1' indicates ERROR
|
|
||||||
int64_t Stage::global_rank(Group *g) const { return ((g == nullptr) ? rank_ : -1); }
|
|
||||||
|
|
||||||
bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend,
|
bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend,
|
||||||
const std::vector<int64_t> &stage) {
|
const std::vector<int64_t> &stage) {
|
||||||
if (device_num <= 0) {
|
if (device_num <= 0) {
|
||||||
|
@ -143,36 +134,23 @@ std::shared_ptr<Device> GetListMemberByIndex(size_t index, const std::vector<std
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// E.g. devices = [4, 5, 2, 1, 7, 8, 10], stage_map = [4, 3],
|
// E.g. devices = [0, 1, 2, 3, 4, 5, 6, 7], stage_map = [4, 4],
|
||||||
// therefore the stage_devices_ = [[4, 5, 2, 1], [7, 8, 10]].
|
// therefore the stage_devices_ = [[0, 1, 2, 3], [4, 5, 6, 7]].
|
||||||
Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank, const RankList &stage_map,
|
Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank, const RankList &stage_map,
|
||||||
const std::string &backend) {
|
const std::string &backend) {
|
||||||
auto dev_it = devices.begin();
|
|
||||||
auto stage_it = stage_map.begin();
|
|
||||||
int64_t sum = 0;
|
|
||||||
|
|
||||||
if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
|
if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
|
||||||
MS_LOG(ERROR) << "Invalid backend: " << backend;
|
MS_LOG(ERROR) << "Invalid backend: " << backend;
|
||||||
return Status::FAILED;
|
return Status::FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (; stage_it != stage_map.end(); ++stage_it) {
|
for (auto &dev : devices) {
|
||||||
sum += (*stage_it);
|
std::shared_ptr<Device> one = std::make_shared<Device>(dev);
|
||||||
}
|
|
||||||
if (LongToSize(sum) != devices.size()) {
|
|
||||||
MS_LOG(ERROR) << "The number of 'devices' in the list is not equal to the mentioned "
|
|
||||||
<< "size of 'stage_map'";
|
|
||||||
return Status::FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (; dev_it != devices.end(); ++dev_it) {
|
|
||||||
std::shared_ptr<Device> one = std::make_shared<Device>(*dev_it);
|
|
||||||
devices_.push_back(one);
|
devices_.push_back(one);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t global_index = 0;
|
size_t global_index = 0;
|
||||||
for (stage_it = stage_map.begin(); stage_it != stage_map.end(); ++stage_it) {
|
for (auto &stage : stage_map) {
|
||||||
int64_t num_device = *stage_it;
|
int64_t num_device = stage;
|
||||||
if (num_device > MAX_DEVICE_NUM) {
|
if (num_device > MAX_DEVICE_NUM) {
|
||||||
MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM;
|
MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM;
|
||||||
return Status::FAILED;
|
return Status::FAILED;
|
||||||
|
@ -189,29 +167,14 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
|
||||||
stage_devices_.push_back(curr_dev_list);
|
stage_devices_.push_back(curr_dev_list);
|
||||||
}
|
}
|
||||||
|
|
||||||
global_index = 0;
|
|
||||||
for (stage_it = stage_map.begin(); stage_it != stage_map.end(); ++stage_it) {
|
|
||||||
int64_t num_device = *stage_it;
|
|
||||||
if (num_device > MAX_DEVICE_NUM) {
|
|
||||||
MS_LOG(ERROR) << "The number of 'devices' in a stage must be less than " << MAX_DEVICE_NUM;
|
|
||||||
return Status::FAILED;
|
|
||||||
}
|
|
||||||
if (num_device <= 0) {
|
|
||||||
MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive";
|
|
||||||
return Status::FAILED;
|
|
||||||
}
|
|
||||||
std::vector<Device> curr_dev_list;
|
|
||||||
for (int64_t i = 0; i < num_device; ++i) {
|
|
||||||
curr_dev_list.push_back(*GetListMemberByIndex(global_index, devices_));
|
|
||||||
global_index++;
|
|
||||||
}
|
|
||||||
std::shared_ptr<Stage> new_stage = std::make_shared<Stage>(curr_dev_list);
|
|
||||||
stages_.push_back(new_stage);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
|
std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
|
||||||
device_ = dev;
|
device_ = dev;
|
||||||
|
|
||||||
set_global_rank(global_device_rank);
|
set_global_rank(global_device_rank);
|
||||||
|
set_stage_num(static_cast<const int64_t>(stage_map.size()));
|
||||||
|
int64_t stage_id = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
|
||||||
|
set_stage_id(stage_id);
|
||||||
|
|
||||||
backend_ = backend;
|
backend_ = backend;
|
||||||
|
|
||||||
if (backend == HCCL_BACKEND) {
|
if (backend == HCCL_BACKEND) {
|
||||||
|
@ -221,25 +184,11 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
|
||||||
} else {
|
} else {
|
||||||
gm_.set_world_group(UNDEFINED_WORLD_GROUP);
|
gm_.set_world_group(UNDEFINED_WORLD_GROUP);
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "The device num: " << devices.size() << "rank id: " << global_device_rank
|
MS_LOG(INFO) << "The device num: " << devices.size() << ", rank id: " << global_device_rank
|
||||||
<< "the backend: " << backend;
|
<< ", the backend: " << backend << ", the stage num: " << stage_num() << ", the stage id: " << stage_id;
|
||||||
return Status::SUCCESS;
|
return Status::SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Stage> DeviceManager::GetStageById(int64_t stage_id) {
|
|
||||||
std::shared_ptr<Stage> res;
|
|
||||||
if (LongToSize(stage_id) >= stages_.size()) {
|
|
||||||
MS_LOG(ERROR) << "the 'stage_id': " << stage_id << ", is out of the scope of 'stage_devices_': " << stages_.size();
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
int64_t index = 0;
|
|
||||||
for (auto &stage : stages_) {
|
|
||||||
if (index == stage_id) return stage;
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
|
RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
|
||||||
if (LongToSize(stage_id) >= stage_devices_.size())
|
if (LongToSize(stage_id) >= stage_devices_.size())
|
||||||
MS_LOG(ERROR) << "the 'stage_id': " << stage_id
|
MS_LOG(ERROR) << "the 'stage_id': " << stage_id
|
||||||
|
|
|
@ -46,28 +46,6 @@ using DeviceManagerPtr = std::shared_ptr<DeviceManager>;
|
||||||
// 'g_device_manager' is the globally unique manager to manage the devices.
|
// 'g_device_manager' is the globally unique manager to manage the devices.
|
||||||
extern DeviceManagerPtr g_device_manager;
|
extern DeviceManagerPtr g_device_manager;
|
||||||
|
|
||||||
class Stage {
|
|
||||||
// This class is used in pipeline-parallelization. Available devices are partitioned into multiple stages.
|
|
||||||
// Currently, the function of pipeline-parallelization and this class are NOT implemented.
|
|
||||||
public:
|
|
||||||
explicit Stage(std::vector<Device> devices) : devices_(std::move(devices)), number_(0), rank_(0) {
|
|
||||||
gm_ = GroupManager();
|
|
||||||
}
|
|
||||||
Stage(const std::vector<mindspore::parallel::Device> &devices, int64_t num, int64_t rank);
|
|
||||||
~Stage() = default;
|
|
||||||
|
|
||||||
int64_t GetStageNum() const { return number_; }
|
|
||||||
size_t GetDevicesNum() const { return devices_.size(); }
|
|
||||||
std::vector<Device> GetDevicesList() { return devices_; }
|
|
||||||
int64_t global_rank(Group *g) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::vector<Device> devices_;
|
|
||||||
int64_t number_;
|
|
||||||
int64_t rank_;
|
|
||||||
GroupManager gm_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// This method is used for initializing the global DeviceManager 'g_device_manager',
|
// This method is used for initializing the global DeviceManager 'g_device_manager',
|
||||||
// arguments including 'device_num' and 'global_rank'
|
// arguments including 'device_num' and 'global_rank'
|
||||||
bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend, const std::vector<int64_t> &stage);
|
bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend, const std::vector<int64_t> &stage);
|
||||||
|
@ -79,7 +57,7 @@ std::string HashName(const std::string &rank_list_name);
|
||||||
class DeviceManager {
|
class DeviceManager {
|
||||||
// This class is used to manage the abstract devices, including group-related and stage-related management.
|
// This class is used to manage the abstract devices, including group-related and stage-related management.
|
||||||
public:
|
public:
|
||||||
DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(0) { gm_ = GroupManager(); }
|
DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(1), stage_id_(0) { gm_ = GroupManager(); }
|
||||||
~DeviceManager() = default;
|
~DeviceManager() = default;
|
||||||
|
|
||||||
Status Init(const RankList &devices, int64_t local_device, const RankList &stage_map, const std::string &backend);
|
Status Init(const RankList &devices, int64_t local_device, const RankList &stage_map, const std::string &backend);
|
||||||
|
@ -94,15 +72,20 @@ class DeviceManager {
|
||||||
std::string GenerateGroupNameByRanks(RankList dev_ranks);
|
std::string GenerateGroupNameByRanks(RankList dev_ranks);
|
||||||
Group CreateGroup(const std::string &group_name, const std::vector<Device> &devices);
|
Group CreateGroup(const std::string &group_name, const std::vector<Device> &devices);
|
||||||
Group CreateGroup(const RankList &dev_ranks);
|
Group CreateGroup(const RankList &dev_ranks);
|
||||||
std::shared_ptr<Stage> GetStageById(int64_t stage_id);
|
|
||||||
|
|
||||||
size_t DeviceNum() const { return devices_.size(); }
|
size_t DeviceNum() const { return devices_.size(); }
|
||||||
|
|
||||||
int64_t GetStageNum() const { return static_cast<const int64_t>(stage_devices_.size()); }
|
int64_t stage_num() const { return stage_num_; }
|
||||||
|
void set_stage_num(int64_t num) { stage_num_ = num; }
|
||||||
|
|
||||||
|
int64_t stage_id() const { return stage_id_; }
|
||||||
|
void set_stage_id(int64_t id) { stage_id_ = id; }
|
||||||
|
|
||||||
|
std::string backend() const { return backend_; }
|
||||||
|
|
||||||
int64_t global_rank() const { return global_rank_; }
|
int64_t global_rank() const { return global_rank_; }
|
||||||
std::string backend() const { return backend_; }
|
|
||||||
void set_global_rank(int64_t global_rank) { global_rank_ = global_rank; }
|
void set_global_rank(int64_t global_rank) { global_rank_ = global_rank; }
|
||||||
|
|
||||||
void Clear();
|
void Clear();
|
||||||
std::string world_group() const { return gm_.world_group(); }
|
std::string world_group() const { return gm_.world_group(); }
|
||||||
std::string FindRankListNameByHashName(const std::string &hash_name);
|
std::string FindRankListNameByHashName(const std::string &hash_name);
|
||||||
|
@ -112,7 +95,6 @@ class DeviceManager {
|
||||||
// each stage has a list of devices
|
// each stage has a list of devices
|
||||||
std::vector<std::vector<int64_t>> stage_devices_;
|
std::vector<std::vector<int64_t>> stage_devices_;
|
||||||
std::shared_ptr<Device> device_;
|
std::shared_ptr<Device> device_;
|
||||||
std::vector<std::shared_ptr<Stage>> stages_;
|
|
||||||
GroupManager gm_;
|
GroupManager gm_;
|
||||||
std::string backend_;
|
std::string backend_;
|
||||||
|
|
||||||
|
@ -123,6 +105,7 @@ class DeviceManager {
|
||||||
int64_t local_rank_;
|
int64_t local_rank_;
|
||||||
int64_t global_rank_;
|
int64_t global_rank_;
|
||||||
int64_t stage_num_;
|
int64_t stage_num_;
|
||||||
|
int64_t stage_id_;
|
||||||
};
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -148,7 +148,7 @@ class EmbeddingLookup(Cell):
|
||||||
manual_shapes (tuple): The accompaniment array in field slice mode.
|
manual_shapes (tuple): The accompaniment array in field slice mode.
|
||||||
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
|
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
|
||||||
or None. Default: None
|
or None. Default: None
|
||||||
sparse (bool): Using sparse mode. Currently, only support sparse mode when target is CPU. Default: True.
|
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
|
||||||
Inputs:
|
Inputs:
|
||||||
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
||||||
Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
|
Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
|
||||||
|
|
|
@ -83,7 +83,7 @@ TEST_F(TestDeviceManager, test_dm_init_AND_get_device_list) {
|
||||||
ASSERT_EQ(dm_.Init(dev_list, local_dev, stage_map, "hccl"), Status::SUCCESS);
|
ASSERT_EQ(dm_.Init(dev_list, local_dev, stage_map, "hccl"), Status::SUCCESS);
|
||||||
|
|
||||||
ASSERT_EQ(dm_.DeviceNum(), 4);
|
ASSERT_EQ(dm_.DeviceNum(), 4);
|
||||||
ASSERT_EQ(dm_.GetStageNum(), (int32_t)(2));
|
ASSERT_EQ(dm_.stage_num(), (int32_t)(2));
|
||||||
|
|
||||||
RankList dev_list_0 = dm_.GetDeviceListByStageId(0);
|
RankList dev_list_0 = dm_.GetDeviceListByStageId(0);
|
||||||
RankList dev_list_1 = dm_.GetDeviceListByStageId(1);
|
RankList dev_list_1 = dm_.GetDeviceListByStageId(1);
|
||||||
|
@ -98,11 +98,6 @@ TEST_F(TestDeviceManager, test_dm_init_AND_get_device_list) {
|
||||||
ASSERT_EQ((*it), int32_t(1));
|
ASSERT_EQ((*it), int32_t(1));
|
||||||
it++;
|
it++;
|
||||||
ASSERT_EQ((*it), int32_t(0));
|
ASSERT_EQ((*it), int32_t(0));
|
||||||
|
|
||||||
std::shared_ptr<Stage> stage_0 = dm_.GetStageById(0);
|
|
||||||
ASSERT_EQ(stage_0->GetDevicesNum(), size_t(2));
|
|
||||||
std::shared_ptr<Stage> stage_1 = dm_.GetStageById(1);
|
|
||||||
ASSERT_EQ(stage_1->GetDevicesNum(), size_t(2));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestDeviceManager, test_CreateNewDeviceByRank) {
|
TEST_F(TestDeviceManager, test_CreateNewDeviceByRank) {
|
||||||
|
@ -123,5 +118,28 @@ TEST_F(TestDeviceManager, test_CreateDeviceListByRankList) {
|
||||||
ASSERT_EQ(it->rank(), int32_t(1));
|
ASSERT_EQ(it->rank(), int32_t(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialises the device manager with four devices split into two stages of two, as global rank 2,
// and checks the device count, stage count, this rank's stage id and the size of each stage's device list.
TEST_F(TestDeviceManager, test_StageID) {
  RankList dev_list = {0, 1, 2, 3};
  RankList stage_map = {2, 2};
  int32_t local_dev = 2;

  ASSERT_EQ(dm_.Init(dev_list, local_dev, stage_map, "hccl"), Status::SUCCESS);

  ASSERT_EQ(dm_.DeviceNum(), 4);
  ASSERT_EQ(dm_.stage_num(), 2);
  ASSERT_EQ(dm_.stage_id(), 1);

  // Each stage is expected to expose exactly two devices.
  RankList first_stage_devices = dm_.GetDeviceListByStageId(0);
  RankList second_stage_devices = dm_.GetDeviceListByStageId(1);
  ASSERT_EQ(first_stage_devices.size(), 2);
  ASSERT_EQ(second_stage_devices.size(), 2);
}
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue