Merge pull request !25469 from bichaoyang/master
i-robot 2021-10-30 08:00:45 +00:00 committed by Gitee
commit bf20022e18
10 changed files with 29 additions and 21 deletions

View File

@ -29,7 +29,6 @@
namespace mindspore {
namespace parallel {
void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::vector<std::vector<std::string>> &input_tensor_names,
@ -217,7 +216,7 @@ Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
if (axis >= SizeToLong(s.size())) {
MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range.";
}
- s[axis] = 1;
+ s[LongToSize(axis)] = 1;
strategies.push_back(s);
return strategies;
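Most of the C++ changes in this commit follow the pattern of the hunk above: a signed int64_t value is passed through LongToSize (or a size_t through SizeToLong) before it is used as a container index, so the narrowing conversion is explicit rather than implicit. A minimal sketch of what such checked helpers can look like, assuming they reject out-of-range values; MindSpore's own utilities may be implemented differently (for example, logging instead of throwing):

#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-ins for the LongToSize/SizeToLong helpers used above.
inline size_t LongToSize(int64_t v) {
  if (v < 0) {
    throw std::out_of_range("negative value cannot be used as a container index");
  }
  return static_cast<size_t>(v);
}

inline int64_t SizeToLong(size_t v) {
  if (v > static_cast<size_t>(INT64_MAX)) {
    throw std::out_of_range("value does not fit in int64_t");
  }
  return static_cast<int64_t>(v);
}

With helpers like these, s[LongToSize(axis)] states the signed-to-unsigned conversion explicitly instead of relying on the implicit one that static checkers warn about.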
@ -232,7 +231,7 @@ Strategys PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> &ops
index[i] = SizeToLong(i);
}
std::sort(index.begin(), index.end(), [&output_shape](const int64_t &a, const int64_t &b) {
- return (output_shape[a + 1] > output_shape[b + 1]);
+ return (output_shape[LongToSize(a + 1)] > output_shape[LongToSize(b + 1)]);
});
std::transform(std::begin(index), std::end(index), std::begin(index), [](int64_t x) { return x + 1; });
index.insert(index.begin(), 0);
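For context, the comparator above ranks the non-axis output dimensions by size, largest first, before the strategy is filled in; the LongToSize wrappers only make the indexing explicit. A self-contained illustration of the same ordering with a made-up shape (none of these names are MindSpore's):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Example output shape; position 0 is handled separately, as in PrepareGatherV2P.
  std::vector<int64_t> output_shape = {8, 128, 16, 1024};

  // Rank dimensions 1..n-1 by size, largest first.
  std::vector<size_t> index(output_shape.size() - 1);
  for (size_t i = 0; i < index.size(); ++i) index[i] = i;
  std::sort(index.begin(), index.end(),
            [&output_shape](size_t a, size_t b) { return output_shape[a + 1] > output_shape[b + 1]; });

  for (size_t i : index) {
    std::cout << "dim " << i + 1 << " size " << output_shape[i + 1] << "\n";  // prints dims 3, 1, 2
  }
  return 0;
}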
@ -294,7 +293,7 @@ Dimensions PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<Oper
auto output_shape = ops[incoming_op_index]->outputs_tensor_info()[0].shape();
Dimensions index(output_shape.size() - 1, 0);
for (size_t i = 0; i < index.size(); i++) {
- index[i] = LongToSize(i);
+ index[i] = SizeToLong(i);
}
std::sort(index.begin(), index.end(),
[&output_shape](const size_t &a, const size_t &b) { return (output_shape[a + 1] > output_shape[b + 1]); });
@ -797,7 +796,7 @@ Dimensions ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<Ope
}
for (size_t i = 0; i < (size_t)stra_dim_list.size(); i++) {
- s_Squeeze.push_back(s[stra_dim_list[i]]);
+ s_Squeeze.push_back(s[LongToSize(stra_dim_list[i])]);
}
return s_Squeeze;
}
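ModifyStrategyIfSqueezeIncoming keeps only the strategy entries whose dimensions survive the Squeeze, and the new LongToSize call makes the index conversion explicit. A reduced sketch of that projection with illustrative names (ProjectStrategy and kept_dims are not MindSpore identifiers):

#include <cstdint>
#include <vector>

// Keep only the strategy entries for dimensions that survive the Squeeze.
std::vector<int64_t> ProjectStrategy(const std::vector<int64_t> &s, const std::vector<int64_t> &kept_dims) {
  std::vector<int64_t> s_squeeze;
  for (int64_t d : kept_dims) {
    s_squeeze.push_back(s[static_cast<size_t>(d)]);
  }
  return s_squeeze;
}

// Example: s = {2, 1, 4, 1} with kept_dims = {0, 2} yields {2, 4}.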
@ -1076,7 +1075,7 @@ Dimensions ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
// Check whether the operator can be divided by the current strategy.
Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
- Dimensions basic_stra) {
+ const Dimensions basic_stra) {
Dimensions s_empty = {};
Strategys stra;
@ -1167,7 +1166,7 @@ Dimensions ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<Ope
size_t cut = 1;
for (size_t i = 0; i < s_Squeeze.size(); i++) {
- cut *= s_Squeeze[i];
+ cut *= LongToSize(s_Squeeze[i]);
}
if (cut != g_device_manager->DeviceNum()) {
s_Squeeze.clear();
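The loop above accumulates the product of all cuts in the squeezed strategy and discards the candidate when that product does not match the device count; the LongToSize call keeps the multiplication in unsigned arithmetic. A small illustration of that invariant check, under the assumption that strategy values are positive (FitsDeviceNum is an illustrative name):

#include <cstdint>
#include <vector>

// True when the per-dimension cuts multiply out to exactly device_num.
bool FitsDeviceNum(const std::vector<int64_t> &strategy, size_t device_num) {
  size_t cut = 1;
  for (int64_t c : strategy) {
    cut *= static_cast<size_t>(c);  // values assumed positive, as validated upstream
  }
  return cut == device_num;
}

// Example: with 8 devices, {2, 4, 1} passes and {2, 2, 1} is cleared.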

View File

@ -264,7 +264,7 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
for (size_t j = node_in->size(); j > 0; j--) {
bool IsEliminated = (index_list->at(node_in->at(j - 1)) == SIZE_MAX);
if (IsEliminated) {
- node_in->erase(node_in->begin() + SizeToLong(j) - 1);
+ (void)node_in->erase(node_in->begin() + SizeToLong(j) - 1);
} else {
node_in->at(j - 1) = index_list->at(node_in->at(j - 1));
}
@ -273,7 +273,7 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
for (size_t j = node_out->size(); j > 0; j--) {
bool IsEliminated = (index_list->at(node_out->at(j - 1)) == SIZE_MAX);
if (IsEliminated) {
- node_out->erase(node_out->begin() + SizeToLong(j) - 1);
+ (void)node_out->erase(node_out->begin() + SizeToLong(j) - 1);
} else {
node_out->at(j - 1) = index_list->at(node_out->at(j - 1));
}
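Both hunks above walk an adjacency list from the back so that erasing an eliminated entry never shifts the indices still to be visited, and the added (void) casts simply discard the iterator returned by vector::erase, which some static checkers flag when ignored. A compact sketch of the same reverse-erase-and-remap pattern (RemapEdges is an illustrative name):

#include <cstddef>
#include <cstdint>
#include <vector>

// Erase entries marked SIZE_MAX in index_list and remap the survivors.
// Iterating backwards keeps the indices of unvisited elements stable after erase.
void RemapEdges(std::vector<size_t> *edges, const std::vector<size_t> &index_list) {
  for (size_t j = edges->size(); j > 0; j--) {
    if (index_list[edges->at(j - 1)] == SIZE_MAX) {
      (void)edges->erase(edges->begin() + static_cast<std::ptrdiff_t>(j) - 1);
    } else {
      edges->at(j - 1) = index_list[edges->at(j - 1)];
    }
  }
}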

View File

@ -414,6 +414,5 @@ void SetUserAttrs(const std::unordered_map<std::string, ValuePtr> &origin_prim_a
}
}
}
} // namespace parallel
} // namespace mindspore

View File

@ -465,15 +465,15 @@ AnfNodePtr GetPreNode(const AnfNodePtr &node) {
while (!node_queue.empty()) {
auto cur_node = (*node_queue.begin())->cast<CNodePtr>();
if (!cur_node) {
- node_queue.erase(node_queue.begin());
+ (void)node_queue.erase(node_queue.begin());
continue;
}
- node_queue.erase(node_queue.begin());
+ (void)node_queue.erase(node_queue.begin());
if (!IsInEndNodeBlackList(cur_node) && cur_node->HasPrimalAttr(NEED_GRAD)) {
MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString();
return cur_node;
}
- node_queue.insert(node_queue.end(), cur_node->inputs().begin() + 1, cur_node->inputs().end());
+ (void)node_queue.insert(node_queue.end(), cur_node->inputs().begin() + 1, cur_node->inputs().end());
}
MS_LOG(EXCEPTION) << "Get Pipeline End node failed.";
}
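GetPreNode above is a breadth-first walk over a work list: pop the front entry, skip anything that is not a CNode, return the first node that is outside the end-node blacklist and carries the NEED_GRAD primal attribute, and otherwise enqueue the node's inputs; the (void) casts only make the discarded return values explicit. A generic sketch of that traversal shape, with a toy node type standing in for AnfNodePtr/CNodePtr:

#include <deque>
#include <memory>
#include <vector>

struct Node {
  bool need_grad = false;                     // stands in for HasPrimalAttr(NEED_GRAD)
  bool blacklisted = false;                   // stands in for IsInEndNodeBlackList
  std::vector<std::shared_ptr<Node>> inputs;  // stands in for cnode->inputs()
};

// Return the first reachable node that requires grad, or nullptr if none exists.
std::shared_ptr<Node> FindEndNode(const std::shared_ptr<Node> &start) {
  std::deque<std::shared_ptr<Node>> queue{start};
  while (!queue.empty()) {
    auto cur = queue.front();
    queue.pop_front();
    if (!cur) continue;
    if (!cur->blacklisted && cur->need_grad) return cur;
    queue.insert(queue.end(), cur->inputs.begin(), cur->inputs.end());
  }
  return nullptr;
}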

View File

@ -886,7 +886,6 @@ void PipelineTransformer::CutBorderForNode(const FuncGraphPtr &graph, const AnfN
}
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
- OperatorAttrs depend_attrs;
std::vector<AnfNodePtr> send_ops;
std::vector<AnfNodePtr> receive_ops;
auto ret = graph->get_return();

View File

@ -2716,7 +2716,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGrap
}
std::vector<std::pair<int64_t, int64_t>> manual_shape;
for (int64_t i = 0; i < UlongToLong(param_split_shapes.size()); ++i) {
- manual_shape.push_back({param_split_shapes[LongToSize(i)], index_offsets[LongToSize(i)]});
+ manual_shape.emplace_back(std::make_pair(param_split_shapes[LongToSize(i)], index_offsets[LongToSize(i)]));
}
manual_shape_map[param_name] = manual_shape;
}
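The change above swaps a brace-initialized push_back for emplace_back with an explicit std::make_pair; both spellings append the same std::pair<int64_t, int64_t>, the new form presumably just satisfies whichever lint rule motivated this commit. A tiny stand-alone comparison:

#include <cstdint>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<int64_t, int64_t>> manual_shape;
  manual_shape.push_back({4, 0});                    // braced temporary pair
  manual_shape.emplace_back(std::make_pair(8, 4));   // pair<int, int> converted on insertion
  manual_shape.emplace_back(16, 12);                 // constructed in place from the two values
  return 0;
}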

View File

@ -98,12 +98,14 @@ class GlobalComm:
INITED = False
CHECK_ENVS = True
class _ExistingGroup:
"""
The communication groups which exist in the progress.
"""
ITEMS = {}
def is_hccl_available():
"""
Check HCCL api is available.
@ -113,6 +115,7 @@ def is_hccl_available():
"""
return _HCCL_AVAILABLE
def is_mpi_available():
"""
Check HCCL & MPI api is available.

View File

@ -36,6 +36,7 @@ def _get_group(group):
return GlobalComm.WORLD_COMM_GROUP
return group
def _check_task_sink_envs():
"""
Check whether task_sink environment variables have been exported or not.
@ -50,6 +51,8 @@ def _check_task_sink_envs():
return False
except ValueError:
return True
finally:
pass
return True
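The helper above falls back to returning True when the environment value cannot be parsed as an integer, and the newly added finally: pass is a behaviour-preserving no-op that only makes the try block's structure explicit. For illustration, the same parse-or-default pattern written in C++ to match the other examples on this page (the variable name in the usage note is a placeholder, not the one MindSpore actually reads):

#include <cstdlib>
#include <optional>

// Parse an integer-valued environment variable; std::nullopt means unset or malformed.
std::optional<long> GetIntEnv(const char *name) {
  const char *raw = std::getenv(name);
  if (raw == nullptr || *raw == '\0') return std::nullopt;
  char *end = nullptr;
  long value = std::strtol(raw, &end, 10);
  if (end == raw || *end != '\0') return std::nullopt;  // not a clean integer
  return value;
}

// Example: GetIntEnv("SOME_TASK_SINK_FLAG") yields std::nullopt when the variable is unset or not an integer.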
@ -70,12 +73,15 @@ def _check_parallel_envs():
int(rank_id_str)
except ValueError:
print("RANK_ID should be number")
finally:
pass
rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
rank_table_file_str_old = os.getenv("RANK_TABLE_FILE")
if not rank_table_file_str and not rank_table_file_str_old:
raise RuntimeError("Get hccl rank_table_file failed, "
"please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.")
def init(backend_name=None):
"""
Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service.
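The _check_parallel_envs hunk above fails fast when neither MINDSPORE_HCCL_CONFIG_PATH nor RANK_TABLE_FILE is exported, a check the HCCL initialization path relies on before the communication service comes up. The same fail-fast lookup expressed in C++ for illustration (only the two variable names and the error text come from the code above; the function itself is made up):

#include <cstdlib>
#include <stdexcept>
#include <string>

// Resolve the HCCL rank table path from the environment, or fail fast.
std::string GetRankTablePath() {
  const char *new_var = std::getenv("MINDSPORE_HCCL_CONFIG_PATH");
  const char *old_var = std::getenv("RANK_TABLE_FILE");
  const bool has_new = new_var != nullptr && *new_var != '\0';
  const bool has_old = old_var != nullptr && *old_var != '\0';
  if (!has_new && !has_old) {
    throw std::runtime_error(
        "Get hccl rank_table_file failed, please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.");
  }
  return has_new ? new_var : old_var;
}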

View File

@ -55,6 +55,7 @@ class MoEConfig:
default_moe_config = MoEConfig()
@constexpr
def calculate_expert_capacity(k, tokens_per_device, capacity_factor, expert_dim):
return math.ceil(k * tokens_per_device * capacity_factor / expert_dim)
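calculate_expert_capacity rounds k * tokens_per_device * capacity_factor / expert_dim up to the next integer, so each expert is sized to hold its share of routed tokens with some headroom when the capacity factor exceeds 1. A worked example of the same arithmetic with made-up values, written in C++ only to match the other examples on this page:

#include <cmath>
#include <iostream>

int main() {
  // Hypothetical MoE routing setup: top-2 routing, 1024 tokens per device,
  // a capacity factor of 1.1, and 16 experts.
  const double k = 2, tokens_per_device = 1024, capacity_factor = 1.1, expert_dim = 16;
  const auto capacity = static_cast<long long>(std::ceil(k * tokens_per_device * capacity_factor / expert_dim));
  std::cout << capacity << "\n";  // 2 * 1024 * 1.1 / 16 = 140.8, rounded up to 141
  return 0;
}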

View File

@ -1163,8 +1163,8 @@ def build_searched_strategy(strategy_filename):
"""
Build strategy of every parameter in network. Used in the case of distributed inference.
For details of merge_sliced_parameter, please check:
- `Enabling Graph-Accounting Convergence
- <https://www.mindspore.cn/docs/programming_guide/en/master/save_load_model_hybrid_parallel.html>`_.
+ `Enabling Graph-Accounting Convergence <https://www.mindspore.cn/docs/programming_guide
+ /en/master/save_load_model_hybrid_parallel.html>`_.
Args:
strategy_filename (str): Name of strategy file.
@ -1211,8 +1211,8 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
"""
Merge parameter slices into one parameter. Used in the case of distributed inference.
For details of merge_sliced_parameter, please check:
- `Enabling Graph-Accounting Convergence
- <https://www.mindspore.cn/docs/programming_guide/en/master/save_load_model_hybrid_parallel.html>`_.
+ `Enabling Graph-Accounting Convergence <https://www.mindspore.cn/docs/programming_guide
+ /en/master/save_load_model_hybrid_parallel.html>`_.
Args:
sliced_parameters (list[Parameter]): Parameter slices in order of rank_id.
@ -1299,8 +1299,8 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
"""
Load checkpoint into net for distributed prediction. Used in the case of distributed inference.
For details of distributed inference, please check:
- `Enabling Graph-Accounting Convergence
- <https://www.mindspore.cn/docs/programming_guide/en/master/distributed_inference.html>`_.
+ `Enabling Graph-Accounting Convergence <https://www.mindspore.cn/docs/programming_guide
+ /en/master/distributed_inference.html>`_.
Args:
network (Cell): Network for distributed prediction.
@ -1438,6 +1438,7 @@ def async_ckpt_thread_status():
def _check_predict_strategy(predict_strategy):
"""Check predict strategy."""
def _check_int_list(arg):
if not isinstance(arg, list):
return False