diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h index e6bc469daf..42c8fee9e4 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h @@ -46,6 +46,8 @@ class GatherV2PInfo : public OperatorInfo { Status SetCostUnderStrategy(const StrategyPtr &strategy) override; ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; std::shared_ptr GenerateBatchStrategies() override; + const std::vector ¶m_split_shapes() const { return param_split_shapes_; } + const std::vector &index_offsets() const { return index_offsets_; } protected: Status CheckStrategy(const StrategyPtr &strategy) override; diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 2397499162..33d2bbc609 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -334,7 +334,11 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & operator_info->set_outputs_dtype(cnode->Type()); operator_info->set_cnode(cnode); // key of strategy map - std::string strategy_key_name = NodeParameterName(cnode); + std::string strategy_key_name = ""; + auto param_names = NodeParameterName(cnode); + if (!param_names.empty()) { + strategy_key_name = param_names[0].first; + } bool load_strategy_from_ckpt = StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); // If no strategy has been configured for this operator, then candidate strategies are generated for diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 3372990130..f8ec0d3288 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -1480,7 +1480,11 @@ void ExtractInformation(const std::vector &all_nodes) { } // load strategy checkpoint // key of strategy map - std::string strategy_key_name = NodeParameterName(cnode); + std::string strategy_key_name = ""; + auto param_names = NodeParameterName(cnode); + if (!param_names.empty()) { + strategy_key_name = param_names[0].first; + } bool load_strategy_from_ckpt = StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { @@ -2118,23 +2122,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector> NodeParameterName(const CNodePtr &node) { std::vector node_inputs{node->inputs()}; - for (auto input : node_inputs) { + std::vector> param_names; + for (int i = 0; i < UintToInt(node_inputs.size()); ++i) { + auto input = node_inputs[i]; if (input->isa()) { auto input_parameter = input->cast(); if (input_parameter->has_default()) { - input_parameter->name(); + if (ParameterRequireGrad(input_parameter)) { + param_names.push_back({input_parameter->name(), i}); + } } } } - return ""; + return param_names; } void CheckpointStrategy(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_LOG(DEBUG) << "Save strategy to checkpoint begin"; StrategyMap stra_map; + TensorInfoMap tensor_info_map; + ManualShapeMap manual_shape_map; auto ret = func_graph->get_return(); auto all_nodes = DeepScopedGraphSearch(ret); for (auto &node : all_nodes) { @@ -2143,10 +2153,11 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { continue; } - std::string param_name = NodeParameterName(cnode); - if (param_name.empty()) { + auto param_names = NodeParameterName(cnode); + if (param_names.empty()) { continue; } + string param_name = param_names[0].first; PrimitivePtr prim = GetValueNode(cnode->input(0)); MS_EXCEPTION_IF_NULL(prim); OperatorInfoPtr operator_info = cnode->user_data(); @@ -2154,12 +2165,33 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { continue; } + std::vector input_tensor_info = operator_info->inputs_tensor_info(); StrategyPtr strategyPtr = operator_info->strategy(); MS_EXCEPTION_IF_NULL(node->scope()); stra_map[param_name] = strategyPtr; + for (auto param_name_pair : param_names) { + if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) { + continue; + } + tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1]; + } + if (operator_info->name().find(EMBEDDING_LOOKUP) != std::string::npos || + operator_info->name().find(GATHERV2) != std::string::npos) { + auto gatherv2_info = std::dynamic_pointer_cast(operator_info); + auto param_split_shapes = gatherv2_info->param_split_shapes(); + auto index_offsets = gatherv2_info->index_offsets(); + if (param_split_shapes.size() != index_offsets.size()) { + MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets lenght should be same."; + } + std::vector> manual_shape; + for (int i = 0; i < UintToInt(param_split_shapes.size()); ++i) { + manual_shape.push_back({param_split_shapes[i], index_offsets[i]}); + } + manual_shape_map[param_name] = manual_shape; + } } } - if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { + if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) { MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; } } diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h index 4f20042d00..4e142e4aee 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -135,7 +135,7 @@ void ReshapeInit(const std::vector &all_nodes); void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, const FuncGraphManagerPtr &manager); -std::string NodeParameterName(const CNodePtr &node); +std::vector> NodeParameterName(const CNodePtr &node); void CheckpointStrategy(const FuncGraphPtr &func_graph); diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc index 0a8852464f..9ab55e6915 100644 --- a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc +++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc @@ -84,7 +84,8 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { return SUCCESS; } -Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { +Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, + ManualShapeMap *manual_shape_map) { straspb::ParallelStrategyMap parallel_strategy_map; parallel_strategy_map.set_current_stage(IntToUint(++current_stage_)); for (auto &node_stra : strategy_map) { @@ -103,6 +104,33 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { } } } + for (auto &node_tensor_info : tensor_info_map) { + TensorInfo tensor_info = node_tensor_info.second; + TensorLayout tensor_layout = tensor_info.tensor_layout(); + straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item(); + MS_EXCEPTION_IF_NULL(parallel_layout_item); + parallel_layout_item->set_param_name(node_tensor_info.first); + straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts(); + straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix(); + MS_EXCEPTION_IF_NULL(dev_matrix); + for (auto dim : tensor_layout.device_arrangement().array()) { + dev_matrix->add_dim(IntToUint(dim)); + } + straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map(); + MS_EXCEPTION_IF_NULL(tensor_map); + for (auto dim : tensor_layout.tensor_map().array()) { + tensor_map->add_dim(dim); + } + straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape(); + straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset(); + MS_EXCEPTION_IF_NULL(manual_shape_map); + auto manual_shape = (*manual_shape_map)[node_tensor_info.first]; + for (auto dim_pair : manual_shape) { + param_split_shape->add_dim(dim_pair.first); + indices_offset->add_dim(dim_pair.second); + } + } + std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary); if (!parallel_strategy_map.SerializeToOstream(&output)) { MS_LOG(ERROR) << "Save strategy file failed"; diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h index e6e2719533..1b1a4e17c0 100644 --- a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h +++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h @@ -19,13 +19,19 @@ #include #include +#include +#include #include "frontend/parallel/ops_info/ops_utils.h" #include "frontend/parallel/strategy.h" #include "frontend/parallel/context.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" +#include "frontend/parallel/tensor_layout/tensor_info.h" namespace mindspore { namespace parallel { using StrategyMap = std::unordered_map; +using TensorInfoMap = std::unordered_map; +using ManualShapeMap = std::unordered_map>>; class StrategyCheckpoint { public: StrategyCheckpoint() { @@ -38,7 +44,7 @@ class StrategyCheckpoint { ~StrategyCheckpoint() = default; Status Load(StrategyMap *strategy_map); - Status Save(const StrategyMap &strategy_map); + Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map); static StrategyCheckpoint &GetInstance(); bool LoadCheckPointOn() const { return load_checkpoint_on_; } diff --git a/mindspore/ccsrc/utils/node_strategy.proto b/mindspore/ccsrc/utils/node_strategy.proto index 8ec25f21a6..dc9d65407d 100644 --- a/mindspore/ccsrc/utils/node_strategy.proto +++ b/mindspore/ccsrc/utils/node_strategy.proto @@ -32,7 +32,36 @@ message ParallelStrategyItem { required ParallelStrategys parallel_strategys = 2; } +message DevMatrix { + repeated uint32 dim = 1; +} + +message TensorMap { + repeated int32 dim = 1; +} + +message ParamSplitShape { + repeated int64 dim = 1; +} + +message IndicesOffset { + repeated int64 dim = 1; +} + +message ParallelLayouts { + repeated DevMatrix dev_matrix = 1; + repeated TensorMap tensor_map = 2; + repeated ParamSplitShape param_split_shape = 3; + repeated IndicesOffset indices_offset = 4; +} + +message ParallelLayoutItem { + required string param_name = 1; + required ParallelLayouts parallel_layouts = 2; +} + message ParallelStrategyMap { required uint32 current_stage = 1; repeated ParallelStrategyItem parallel_strategy_item = 2; + repeated ParallelLayoutItem parallel_layout_item = 3; } \ No newline at end of file diff --git a/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc b/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc index f6f2f45092..6ae883cfbd 100644 --- a/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc +++ b/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc @@ -29,6 +29,7 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return f Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; } -Status StrategyCheckpoint::Save(const StrategyMap& strategy_map) { return SUCCESS; } +Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, + ManualShapeMap *manual_shape_map) { return SUCCESS; } } // namespace parallel } // namespace mindspore