auto parallel context add notes and func mv

yao_yf 2020-09-17 16:36:01 +08:00
parent 5a76bd717d
commit b70204c080
5 changed files with 35 additions and 20 deletions

View File

@@ -1542,17 +1542,8 @@ size_t CostGraph::GetNumEdges() const {
   }
   return sum;
 }
-Status CostGraph::InitSelectedStrategy() {
-  for (auto &op : ops_) {
-    MS_EXCEPTION_IF_NULL(op);
-    if (op->name().find(RESHAPEINFO) != std::string::npos) {
-      continue;
-    }
-    auto result = op->InitSelectedStrategy(op->selected_strategy());
-    if (result != SUCCESS) {
-      return result;
-    }
-  }
+Status CostGraph::InitReshapeStrategy() {
   // reshape init should be apply after the init of it's previous node and next node.
   for (size_t i = 0; i < ops_.size(); ++i) {
     if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
@@ -1606,6 +1597,21 @@ Status CostGraph::InitSelectedStrategy() {
   return SUCCESS;
 }
+Status CostGraph::InitSelectedStrategy() {
+  for (auto &op : ops_) {
+    MS_EXCEPTION_IF_NULL(op);
+    if (op->name().find(RESHAPEINFO) != std::string::npos) {
+      continue;
+    }
+    auto result = op->InitSelectedStrategy(op->selected_strategy());
+    if (result != SUCCESS) {
+      return result;
+    }
+  }
+  auto result = InitReshapeStrategy();
+  return result;
+}
 Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
   for (auto &op : ops_) {
     MS_EXCEPTION_IF_NULL(op);

View File

@@ -186,6 +186,7 @@ class CostGraph {
   std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
   size_t GetNumEdges() const;
+  Status InitReshapeStrategy();
   Status InitSelectedStrategy();
   OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
   // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only

View File

@@ -2275,7 +2275,6 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
 }
 void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
-  MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
   StrategyMap stra_map;
   TensorInfoMap tensor_info_map;
   ManualShapeMap manual_shape_map;
@@ -2298,10 +2297,8 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
       continue;
     }
     std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
-    StrategyPtr strategyPtr = operator_info->strategy();
-    MS_EXCEPTION_IF_NULL(node->scope());
     std::string stratey_key_name = prim->name() + "_" + param_name;
-    stra_map[stratey_key_name] = strategyPtr;
+    stra_map[stratey_key_name] = operator_info->strategy();
     for (auto param_name_pair : param_names) {
       if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
         continue;

View File

@@ -395,9 +395,10 @@ def set_auto_parallel_context(**kwargs):
                     should be set with True. Default: False.
         enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
                     data parallel training in the benefit of time and memory saving. For now,
-                    `Lamb` and `AdamWeightDecay` are supported in data parallel mode.
+                    `Lamb` and `AdamWeightDecay` are supported in data parallel mode. No default; if it is not set,
+                    the feature is turned off.
         all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
-                    and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP.
+                    and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No default; if it is not set, the fusion is turned off.
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -408,9 +409,13 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(gradients_mean=True)
         >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
+        >>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
         >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
         >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
+        >>> context.set_auto_parallel_context(full_batch=True)
+        >>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
+        >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
     """
     _set_auto_parallel_context(**kwargs)
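
Note: the doctest lines added above can also be combined into a single call. A minimal sketch (illustration only, not part of this commit), assuming a working MindSpore install and an initialized device/communication setup:

    from mindspore import context

    context.set_auto_parallel_context(
        parallel_mode="auto_parallel",                     # let the framework search parallel strategies
        auto_parallel_search_mode="dynamic_programming",   # strategy-search algorithm
        full_batch=True,                                   # every device loads the full batch
        enable_parallel_optimizer=False,                   # optimizer-weight sharding (Lamb/AdamWeightDecay only)
        all_reduce_fusion_config=[8, 160],                 # fuse allreduce by parameter indices
        strategy_ckpt_save_file="./strategy_stage1.ckpt",  # dump the chosen strategies here
    )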
@@ -439,10 +444,12 @@ def reset_auto_parallel_context():
     - global_rank: 0.
     - gradients_mean: False.
     - gradient_fp32_sync: True.
-    - parallel_mode: "stand_alone".
+    - parallel_mode: 'stand_alone'.
+    - auto_parallel_search_mode: 'dynamic_programming'.
     - parameter_broadcast: False.
-    - strategy_ckpt_load_file: "".
-    - strategy_ckpt_save_file: "".
+    - strategy_ckpt_load_file: ''.
+    - strategy_ckpt_save_file: ''.
+    - full_batch: False.
     - enable_parallel_optimizer: False.
     """
     _reset_auto_parallel_context()
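
Note: a small usage sketch (illustration only) of reading a few of the defaults listed above back after a reset; it assumes these attribute keys are accepted by context.get_auto_parallel_context:

    from mindspore import context

    context.reset_auto_parallel_context()
    print(context.get_auto_parallel_context("parallel_mode"))              # 'stand_alone'
    print(context.get_auto_parallel_context("auto_parallel_search_mode"))  # 'dynamic_programming'
    print(context.get_auto_parallel_context("full_batch"))                 # False
    print(context.get_auto_parallel_context("enable_parallel_optimizer"))  # False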

View File

@@ -245,6 +245,10 @@ class _AutoParallelContext:
             strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint.
         """
         self.check_context_handle()
+        import os
+        dir_path = os.path.dirname(strategy_ckpt_save_file)
+        if dir_path and not os.path.exists(dir_path):
+            os.makedirs(dir_path)
         self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
     def get_strategy_ckpt_save_file(self):
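
Note: the added lines make set_strategy_ckpt_save_file create the parent directory of the checkpoint path when it is missing. The same guard, pulled out as a standalone sketch (ensure_parent_dir is a hypothetical helper used only for illustration, not a MindSpore API):

    import os

    def ensure_parent_dir(path):
        """Create the parent directory of `path` if it does not already exist."""
        dir_path = os.path.dirname(path)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)

    ensure_parent_dir("./ckpt/strategy_stage1.ckpt")  # creates ./ckpt if missing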