forked from mindspore-Ecosystem/mindspore
!17269 parallel_strategy_checkpoint pclint fix master
From: @yao_yf Reviewed-by: @stsuteng,@yangzhenzhang Signed-off-by: @stsuteng
This commit is contained in:
commit
5a0fc52809
|
@ -62,14 +62,14 @@ Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *
|
|||
|
||||
size_t group_num = LongToSize(parallel_group_map.parallel_group_item_size());
|
||||
for (size_t i = 0; i < group_num; ++i) {
|
||||
straspb::ParallelGroupItem parallel_group_item = parallel_group_map.parallel_group_item(SizeToLong(i));
|
||||
straspb::ParallelGroupItem parallel_group_item = parallel_group_map.parallel_group_item(SizeToInt(i));
|
||||
std::string group_name = parallel_group_item.group_name();
|
||||
|
||||
straspb::ParallelGroupRanks parallel_group_ranks = parallel_group_item.parallel_group_ranks();
|
||||
size_t rank_num = LongToSize(parallel_group_ranks.dim_size());
|
||||
std::vector<uint32_t> ranks;
|
||||
for (size_t j = 0; j < rank_num; ++j) {
|
||||
uint32_t rank = parallel_group_ranks.dim(SizeToLong(j));
|
||||
uint32_t rank = parallel_group_ranks.dim(SizeToInt(j));
|
||||
ranks.push_back(rank);
|
||||
}
|
||||
|
||||
|
@ -96,18 +96,18 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
|
|||
input.close();
|
||||
size_t node_num = LongToSize(parallel_strategy_map.parallel_strategy_item_size());
|
||||
for (size_t i = 0; i < node_num; i++) {
|
||||
straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToLong(i));
|
||||
straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i));
|
||||
std::string node_name = parallel_strategy_item.node_name();
|
||||
straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
|
||||
auto stage = (int64_t)parallel_strategys.stage();
|
||||
size_t strategys_num = LongToSize(parallel_strategys.parallel_strategy_size());
|
||||
Strategys strategy_inputs;
|
||||
for (size_t j = 0; j < strategys_num; j++) {
|
||||
straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToLong(j));
|
||||
straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
|
||||
Dimensions dimension;
|
||||
size_t dim_num = LongToSize(parallel_strategy.dim_size());
|
||||
for (size_t k = 0; k < dim_num; k++) {
|
||||
dimension.push_back(parallel_strategy.dim(SizeToLong(k)));
|
||||
dimension.push_back(parallel_strategy.dim(SizeToInt(k)));
|
||||
}
|
||||
strategy_inputs.push_back(dimension);
|
||||
}
|
||||
|
@ -122,7 +122,7 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
|
|||
Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
|
||||
ManualShapeMap *manual_shape_map) {
|
||||
straspb::ParallelStrategyMap parallel_strategy_map;
|
||||
parallel_strategy_map.set_current_stage(LongToUlong(++current_stage_));
|
||||
parallel_strategy_map.set_current_stage(UlongToUint(LongToUlong(++current_stage_)));
|
||||
for (auto &node_stra : strategy_map) {
|
||||
straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
|
||||
MS_EXCEPTION_IF_NULL(parallel_strategy_item);
|
||||
|
@ -130,12 +130,12 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
|
|||
straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
|
||||
MS_EXCEPTION_IF_NULL(parallel_strategys);
|
||||
MS_EXCEPTION_IF_NULL(node_stra.second);
|
||||
parallel_strategys->set_stage(LongToUlong(node_stra.second->GetInputStage()));
|
||||
parallel_strategys->set_stage(UlongToUint(LongToUlong(node_stra.second->GetInputStage())));
|
||||
for (auto &dims : node_stra.second->GetInputDim()) {
|
||||
straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
|
||||
MS_EXCEPTION_IF_NULL(parallel_strategy);
|
||||
for (auto stra_dim : dims) {
|
||||
parallel_strategy->add_dim(LongToUlong(stra_dim));
|
||||
parallel_strategy->add_dim(UlongToUint(LongToUlong(stra_dim)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -148,12 +148,12 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
|
|||
straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
|
||||
MS_EXCEPTION_IF_NULL(dev_matrix);
|
||||
for (auto dev_dim : tensor_layout->device_arrangement().array()) {
|
||||
dev_matrix->add_dim(LongToUlong(dev_dim));
|
||||
dev_matrix->add_dim(UlongToUint(LongToUlong(dev_dim)));
|
||||
}
|
||||
straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
|
||||
MS_EXCEPTION_IF_NULL(tensor_map);
|
||||
for (auto map_dim : tensor_layout->tensor_map().array()) {
|
||||
tensor_map->add_dim(map_dim);
|
||||
tensor_map->add_dim(LongToInt(map_dim));
|
||||
}
|
||||
straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
|
||||
straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();
|
||||
|
@ -163,7 +163,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
|
|||
param_split_shape->add_dim(dim_pair.first);
|
||||
indices_offset->add_dim(dim_pair.second);
|
||||
}
|
||||
parallel_layouts->set_field(tensor_layout->get_field_size());
|
||||
parallel_layouts->set_field(LongToInt(tensor_layout->get_field_size()));
|
||||
parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
|
||||
parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue