diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
index 5d57718110b..c906d6f5295 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
@@ -1542,17 +1542,8 @@ size_t CostGraph::GetNumEdges() const {
   }
   return sum;
 }
-Status CostGraph::InitSelectedStrategy() {
-  for (auto &op : ops_) {
-    MS_EXCEPTION_IF_NULL(op);
-    if (op->name().find(RESHAPEINFO) != std::string::npos) {
-      continue;
-    }
-    auto result = op->InitSelectedStrategy(op->selected_strategy());
-    if (result != SUCCESS) {
-      return result;
-    }
-  }
+
+Status CostGraph::InitReshapeStrategy() {
   // reshape init should be apply after the init of it's previous node and next node.
   for (size_t i = 0; i < ops_.size(); ++i) {
     if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
@@ -1606,6 +1597,21 @@ Status CostGraph::InitSelectedStrategy() {
   return SUCCESS;
 }
 
+Status CostGraph::InitSelectedStrategy() {
+  for (auto &op : ops_) {
+    MS_EXCEPTION_IF_NULL(op);
+    if (op->name().find(RESHAPEINFO) != std::string::npos) {
+      continue;
+    }
+    auto result = op->InitSelectedStrategy(op->selected_strategy());
+    if (result != SUCCESS) {
+      return result;
+    }
+  }
+  auto result = InitReshapeStrategy();
+  return result;
+}
+
 Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
   for (auto &op : ops_) {
     MS_EXCEPTION_IF_NULL(op);
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
index 99d45dfb0a4..b5bc27190c2 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
@@ -186,6 +186,7 @@ class CostGraph {
   std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
   size_t GetNumEdges() const;
+  Status InitReshapeStrategy();
   Status InitSelectedStrategy();
   OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
   // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index dec0059d4a6..d6110ef4d97 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -2275,7 +2275,6 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
 }
 
 void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
-  MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
   StrategyMap stra_map;
   TensorInfoMap tensor_info_map;
   ManualShapeMap manual_shape_map;
@@ -2298,10 +2297,8 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
       continue;
     }
     std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
-    StrategyPtr strategyPtr = operator_info->strategy();
-    MS_EXCEPTION_IF_NULL(node->scope());
     std::string stratey_key_name = prim->name() + "_" + param_name;
-    stra_map[stratey_key_name] = strategyPtr;
+    stra_map[stratey_key_name] = operator_info->strategy();
     for (auto param_name_pair : param_names) {
       if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
         continue;
diff --git a/mindspore/context.py b/mindspore/context.py
index 5889b27b1d7..dcf706570f0 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -395,9 +395,10 @@ def set_auto_parallel_context(**kwargs):
                       should be set with True. Default: False.
         enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
                       data parallel training in the benefit of time and memory saving. For now,
-                      `Lamb` and `AdamWeightDecay` are supported in data parallel mode.
+                      `Lamb` and `AdamWeightDecay` are supported in data parallel mode. No default; if it is not set,
+                      the feature is disabled.
         all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
-                      and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP.
+                      and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No default; if it is not set, the fusion is disabled.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -408,9 +409,13 @@
         >>> context.set_auto_parallel_context(gradients_mean=True)
         >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
+        >>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
         >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
         >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
+        >>> context.set_auto_parallel_context(full_batch=True)
+        >>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
+        >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
     """
     _set_auto_parallel_context(**kwargs)
 
@@ -439,10 +444,12 @@
     - global_rank: 0.
     - gradients_mean: False.
     - gradient_fp32_sync: True.
-    - parallel_mode: "stand_alone".
+    - parallel_mode: 'stand_alone'.
+    - auto_parallel_search_mode: 'dynamic_programming'.
     - parameter_broadcast: False.
-    - strategy_ckpt_load_file: "".
-    - strategy_ckpt_save_file: "".
+    - strategy_ckpt_load_file: ''.
+    - strategy_ckpt_save_file: ''.
+    - full_batch: False.
     - enable_parallel_optimizer: False.
     """
     _reset_auto_parallel_context()
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py
index aed133ee26f..9d1dfd09b3d 100644
--- a/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/parallel/_auto_parallel_context.py
@@ -245,6 +245,10 @@ class _AutoParallelContext:
             strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint.
         """
         self.check_context_handle()
+        import os
+        dir_path = os.path.dirname(strategy_ckpt_save_file)
+        if dir_path and not os.path.exists(dir_path):
+            os.makedirs(dir_path)
         self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)

    def get_strategy_ckpt_save_file(self):
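
Note on the _auto_parallel_context.py change: the new guard only creates the missing parent directory of strategy_ckpt_save_file before the path is handed to the C++ context handle. The stand-alone sketch below reproduces that logic so it can be checked in isolation; the helper name ensure_parent_dir, the temporary base directory, and the example path are illustrative only and are not part of the patch.

import os
import tempfile

def ensure_parent_dir(ckpt_path):
    """Create the parent directory of ckpt_path if it is missing.

    Mirrors the guard added to _AutoParallelContext.set_strategy_ckpt_save_file:
    a bare file name (empty dirname) is left alone, an existing directory is
    reused, and only a missing directory triggers os.makedirs.
    """
    dir_path = os.path.dirname(ckpt_path)
    if dir_path and not os.path.exists(dir_path):
        os.makedirs(dir_path)
    return ckpt_path

if __name__ == "__main__":
    base = tempfile.mkdtemp()
    # Nested directory that does not exist yet, as a user might pass to
    # context.set_auto_parallel_context(strategy_ckpt_save_file=...).
    path = os.path.join(base, "ckpt", "stage1", "strategy_stage1.ckpt")
    ensure_parent_dir(path)
    assert os.path.isdir(os.path.dirname(path))
    print("parent directory ready for:", path)

An equivalent alternative would be os.makedirs(dir_path, exist_ok=True); the patch keeps the explicit os.path.exists check instead.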