forked from OSSInnovation/mindspore
auto parallel context add notes and func mv
This commit is contained in:
parent 5a76bd717d
commit b70204c080
@@ -1542,17 +1542,8 @@ size_t CostGraph::GetNumEdges() const {
   }
   return sum;
 }
-Status CostGraph::InitSelectedStrategy() {
-  for (auto &op : ops_) {
-    MS_EXCEPTION_IF_NULL(op);
-    if (op->name().find(RESHAPEINFO) != std::string::npos) {
-      continue;
-    }
-    auto result = op->InitSelectedStrategy(op->selected_strategy());
-    if (result != SUCCESS) {
-      return result;
-    }
-  }
+
+Status CostGraph::InitReshapeStrategy() {
   // reshape init should be applied after the init of its previous node and next node.
   for (size_t i = 0; i < ops_.size(); ++i) {
     if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
@@ -1606,6 +1597,21 @@ Status CostGraph::InitSelectedStrategy() {
   return SUCCESS;
 }
+
+Status CostGraph::InitSelectedStrategy() {
+  for (auto &op : ops_) {
+    MS_EXCEPTION_IF_NULL(op);
+    if (op->name().find(RESHAPEINFO) != std::string::npos) {
+      continue;
+    }
+    auto result = op->InitSelectedStrategy(op->selected_strategy());
+    if (result != SUCCESS) {
+      return result;
+    }
+  }
+  auto result = InitReshapeStrategy();
+  return result;
+}
 
 Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
   for (auto &op : ops_) {
     MS_EXCEPTION_IF_NULL(op);
@@ -186,6 +186,7 @@ class CostGraph {
 
   std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
   size_t GetNumEdges() const;
+  Status InitReshapeStrategy();
   Status InitSelectedStrategy();
   OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
   // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated only
@@ -2275,7 +2275,6 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
 }
 
 void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
-  MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
   StrategyMap stra_map;
   TensorInfoMap tensor_info_map;
   ManualShapeMap manual_shape_map;
@@ -2298,10 +2297,8 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
       continue;
     }
     std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
-    StrategyPtr strategyPtr = operator_info->strategy();
-    MS_EXCEPTION_IF_NULL(node->scope());
     std::string stratey_key_name = prim->name() + "_" + param_name;
-    stra_map[stratey_key_name] = strategyPtr;
+    stra_map[stratey_key_name] = operator_info->strategy();
     for (auto param_name_pair : param_names) {
       if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
         continue;
@@ -395,9 +395,10 @@ def set_auto_parallel_context(**kwargs):
                    should be set with True. Default: False.
         enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
                    data parallel training to save time and memory. For now,
-                   `Lamb` and `AdamWeightDecay` are supported in data parallel mode.
+                   `Lamb` and `AdamWeightDecay` are supported in data parallel mode. No default; if it is not set,
+                   this feature is turned off.
         all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only supports ReduceOp.SUM
-                   and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP.
+                   and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No default; if it is not set, the fusion is turned off.
 
     Raises:
         ValueError: If input key is not an attribute in auto parallel context.
@@ -408,9 +409,13 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(gradients_mean=True)
         >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
+        >>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
         >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
         >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
+        >>> context.set_auto_parallel_context(full_batch=True)
+        >>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
+        >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
     """
     _set_auto_parallel_context(**kwargs)
 
@@ -439,10 +444,12 @@ def reset_auto_parallel_context():
     - global_rank: 0.
     - gradients_mean: False.
     - gradient_fp32_sync: True.
-    - parallel_mode: "stand_alone".
+    - parallel_mode: 'stand_alone'.
+    - auto_parallel_search_mode: 'dynamic_programming'.
     - parameter_broadcast: False.
-    - strategy_ckpt_load_file: "".
-    - strategy_ckpt_save_file: "".
+    - strategy_ckpt_load_file: ''.
+    - strategy_ckpt_save_file: ''.
+    - full_batch: False.
     - enable_parallel_optimizer: False.
     """
     _reset_auto_parallel_context()
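Taken together, the docstring examples and the reset defaults above map onto a short usage sketch. This is illustrative only and not part of the commit: the option values are arbitrary, and it assumes a MindSpore build that already exposes all of the listed keys through mindspore.context.

    from mindspore import context

    # Values are illustrative; the keys are the ones documented in the hunks above.
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel",
        auto_parallel_search_mode="dynamic_programming",
        full_batch=True,
        enable_parallel_optimizer=False,
        all_reduce_fusion_config=[8, 160],  # fuse allreduce at parameter indices 8 and 160
    )

    # Restore every key to the defaults listed for reset_auto_parallel_context().
    context.reset_auto_parallel_context()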
@@ -245,6 +245,10 @@ class _AutoParallelContext:
             strategy_ckpt_save_file (str): Path to save parallel strategy checkpoint.
         """
         self.check_context_handle()
+        import os
+        dir_path = os.path.dirname(strategy_ckpt_save_file)
+        if dir_path and not os.path.exists(dir_path):
+            os.makedirs(dir_path)
         self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
 
     def get_strategy_ckpt_save_file(self):
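A minimal sketch of what the added directory-creation logic buys the caller, assuming it runs against a build that contains this change and that the public context facade dispatches strategy_ckpt_save_file to the setter above (the path below is made up for the example):

    import os
    import tempfile
    from mindspore import context

    # A scratch path whose parent directory does not exist yet (hypothetical location).
    ckpt_path = os.path.join(tempfile.mkdtemp(), "parallel", "strategy_stage1.ckpt")

    # Before this change the missing directory would only surface later, when the
    # strategy checkpoint was written; now the setter creates it up front.
    context.set_auto_parallel_context(strategy_ckpt_save_file=ckpt_path)
    print(os.path.isdir(os.path.dirname(ckpt_path)))  # expected: True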