forked from mindspore-Ecosystem/mindspore
fix HandleGroupInfo bug
This commit is contained in:
parent
7bf397ea1c
commit
de0f0bf214
|
@ -0,0 +1,43 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "frontend/parallel/pass/handle_group_info.h"
|
||||||
|
#include "frontend/parallel/device_manager.h"
|
||||||
|
#include "include/common/utils/parallel_context.h"
|
||||||
|
#include "frontend/parallel/step_parallel_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace parallel {
|
||||||
|
void HandleGroupInfo(const FuncGraphPtr &root) {
|
||||||
|
if (g_device_manager == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto group_info = g_device_manager->group_info();
|
||||||
|
auto group_info_save_path = common::GetEnv("GROUP_INFO_FILE");
|
||||||
|
if (!group_info_save_path.empty()) {
|
||||||
|
ParallelContext::GetInstance()->set_group_ckpt_save_file(group_info_save_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (StrategyCheckpoint::GetInstance().group_info_save_on()) {
|
||||||
|
auto &strategy_ckt = StrategyCheckpoint::GetInstance();
|
||||||
|
RankList comm_group = strategy_ckt.common_mirror_group();
|
||||||
|
if (strategy_ckt.SaveGroupInfo(group_info, comm_group) != SUCCESS) {
|
||||||
|
MS_LOG(EXCEPTION) << "Save group info failed";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace parallel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,28 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PASS_HANDLE_GROUP_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PASS_HANDLE_GROUP_INFO_H_

#include "ir/anf.h"

namespace mindspore {
namespace parallel {
// Handle hccl group info: cache the GROUP_INFO_FILE save path and, when
// group-info saving is enabled, persist the group info via StrategyCheckpoint.
// Parameter renamed to `root` to match the definition in handle_group_info.cc.
void HandleGroupInfo(const FuncGraphPtr &root);
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PASS_HANDLE_GROUP_INFO_H_
|
|
@ -14,20 +14,21 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "frontend/optimizer/micro_interleaved_order_control.h"
|
#include "frontend/parallel/pass/micro_interleaved_order_control.h"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "mindspore/core/ops/core_ops.h"
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
#include "frontend/parallel/step_parallel.h"
|
#include "frontend/parallel/step_parallel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace parallel {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr auto kGradientsFlag = "Gradients";
|
constexpr auto kGradientsFlag = "Gradients";
|
||||||
const size_t interleaved_size = 2;
|
const size_t interleaved_size = 2;
|
||||||
|
@ -302,5 +303,5 @@ void MicroInterleavedOrderControl(const FuncGraphPtr &graph) {
|
||||||
}
|
}
|
||||||
MicroInterleavedOrderControlPipeline(manager, origin_nodes_topological);
|
MicroInterleavedOrderControlPipeline(manager, origin_nodes_topological);
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
|
@ -14,15 +14,15 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_MICRO_INTERLEAVED_ORDER_CONTROL_H_
|
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PASS_MICRO_INTERLEAVED_ORDER_CONTROL_H_
|
||||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_MICRO_INTERLEAVED_ORDER_CONTROL_H_
|
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PASS_MICRO_INTERLEAVED_ORDER_CONTROL_H_
|
||||||
|
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace parallel {
|
||||||
// Micro interleaved nodes order control.
|
// Micro interleaved nodes order control.
|
||||||
void MicroInterleavedOrderControl(const FuncGraphPtr &graph);
|
void MicroInterleavedOrderControl(const FuncGraphPtr &graph);
|
||||||
} // namespace opt
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_MICRO_INTERLEAVED_ORDER_CONTROL_H_
|
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PASS_MICRO_INTERLEAVED_ORDER_CONTROL_H_
|
|
@ -14,7 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "frontend/optimizer/overlap_opt_shard_in_pipeline.h"
|
#include "frontend/parallel/pass/overlap_opt_shard_in_pipeline.h"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <list>
|
#include <list>
|
||||||
|
@ -29,7 +29,7 @@
|
||||||
#include "include/common/utils/comm_manager.h"
|
#include "include/common/utils/comm_manager.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace parallel {
|
||||||
namespace {
|
namespace {
|
||||||
inline bool is_allgather_comm_ops(const AnfNodePtr &node) {
|
inline bool is_allgather_comm_ops(const AnfNodePtr &node) {
|
||||||
static const std::vector<PrimitivePtr> kAllGatherOpsPrim = {prim::kPrimMicroStepAllGather,
|
static const std::vector<PrimitivePtr> kAllGatherOpsPrim = {prim::kPrimMicroStepAllGather,
|
||||||
|
@ -131,5 +131,5 @@ void OverlapOptShardInPipeline(const FuncGraphPtr &graph) {
|
||||||
manager->SetEdge(recv_user, recv_user_index, depend_node);
|
manager->SetEdge(recv_user, recv_user_index, depend_node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
|
@ -14,15 +14,15 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OVERLAP_OPT_SHARD_IN_PIPELINE_H_
|
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PASS_OVERLAP_OPT_SHARD_IN_PIPELINE_H_
|
||||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OVERLAP_OPT_SHARD_IN_PIPELINE_H_
|
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PASS_OVERLAP_OPT_SHARD_IN_PIPELINE_H_
|
||||||
|
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace parallel {
|
||||||
// Automatically insert duplicated recomputed nodes.
|
// Automatically insert duplicated recomputed nodes.
|
||||||
void OverlapOptShardInPipeline(const FuncGraphPtr &graph);
|
void OverlapOptShardInPipeline(const FuncGraphPtr &graph);
|
||||||
} // namespace opt
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OVERLAP_OPT_SHARD_IN_PIPELINE_H_
|
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PASS_OVERLAP_OPT_SHARD_IN_PIPELINE_H_
|
|
@ -2468,21 +2468,6 @@ static void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphMan
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void HandleGroupInfo(const FuncGraphPtr &root) {
|
|
||||||
auto group_info = g_device_manager->group_info();
|
|
||||||
auto group_info_save_path = common::GetEnv("GROUP_INFO_FILE");
|
|
||||||
if (!group_info_save_path.empty()) {
|
|
||||||
ParallelContext::GetInstance()->set_group_ckpt_save_file(group_info_save_path);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (StrategyCheckpoint::GetInstance().group_info_save_on()) {
|
|
||||||
RankList comm_group = FindCommonMirrorGroup(root);
|
|
||||||
if (StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info, comm_group) != SUCCESS) {
|
|
||||||
MS_LOG(EXCEPTION) << "Save group info failed";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void HandleDataParallel() {
|
static void HandleDataParallel() {
|
||||||
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
|
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
|
||||||
if (parallel_mode == kDataParallel) {
|
if (parallel_mode == kDataParallel) {
|
||||||
|
@ -2792,8 +2777,8 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
||||||
|
|
||||||
PipelinePostProcess(root, all_nodes);
|
PipelinePostProcess(root, all_nodes);
|
||||||
|
|
||||||
HandleGroupInfo(root);
|
auto comm_group = FindCommonMirrorGroup(root);
|
||||||
|
StrategyCheckpoint::GetInstance().set_common_mirror_group(comm_group);
|
||||||
// handle full split parameters in grad accumulation, do not contain optimizer-sharding's parameter
|
// handle full split parameters in grad accumulation, do not contain optimizer-sharding's parameter
|
||||||
HandleFullySplitParameters(root);
|
HandleFullySplitParameters(root);
|
||||||
|
|
||||||
|
|
|
@ -50,6 +50,9 @@ class StrategyCheckpoint {
|
||||||
bool LoadCheckPointOn() const { return load_checkpoint_on_; }
|
bool LoadCheckPointOn() const { return load_checkpoint_on_; }
|
||||||
bool SaveCheckPointOn() const { return save_checkpoint_on_; }
|
bool SaveCheckPointOn() const { return save_checkpoint_on_; }
|
||||||
|
|
||||||
|
void set_common_mirror_group(const RankList &comm_group) { common_mirror_group_ = comm_group; }
|
||||||
|
RankList common_mirror_group() const { return common_mirror_group_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string load_file_;
|
std::string load_file_;
|
||||||
std::string save_file_;
|
std::string save_file_;
|
||||||
|
@ -63,6 +66,7 @@ class StrategyCheckpoint {
|
||||||
bool load_format_json_ = true;
|
bool load_format_json_ = true;
|
||||||
bool save_format_json_ = true;
|
bool save_format_json_ = true;
|
||||||
StrategyCheckpointInfo strategy_checkpoint_info_;
|
StrategyCheckpointInfo strategy_checkpoint_info_;
|
||||||
|
RankList common_mirror_group_;
|
||||||
};
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -43,14 +43,15 @@
|
||||||
#include "frontend/parallel/pynative_shard/pynative_shard.h"
|
#include "frontend/parallel/pynative_shard/pynative_shard.h"
|
||||||
#include "frontend/parallel/pass/label_micro_interleaved_index.h"
|
#include "frontend/parallel/pass/label_micro_interleaved_index.h"
|
||||||
#include "frontend/parallel/pass/reorder_send_recv_between_fp_bp.h"
|
#include "frontend/parallel/pass/reorder_send_recv_between_fp_bp.h"
|
||||||
|
#include "frontend/parallel/pass/micro_interleaved_order_control.h"
|
||||||
|
#include "frontend/parallel/pass/overlap_opt_shard_in_pipeline.h"
|
||||||
|
#include "frontend/parallel/pass/handle_group_info.h"
|
||||||
#include "frontend/optimizer/recompute.h"
|
#include "frontend/optimizer/recompute.h"
|
||||||
#include "frontend/optimizer/slice_activation_in_recompute.h"
|
#include "frontend/optimizer/slice_activation_in_recompute.h"
|
||||||
#include "frontend/optimizer/micro_interleaved_order_control.h"
|
|
||||||
#include "frontend/optimizer/comm_op_attrs.h"
|
#include "frontend/optimizer/comm_op_attrs.h"
|
||||||
#include "frontend/optimizer/environ_conversion.h"
|
#include "frontend/optimizer/environ_conversion.h"
|
||||||
#include "frontend/optimizer/comm_op_reuse_tag.h"
|
#include "frontend/optimizer/comm_op_reuse_tag.h"
|
||||||
#include "frontend/optimizer/py_interpret_to_execute.h"
|
#include "frontend/optimizer/py_interpret_to_execute.h"
|
||||||
#include "frontend/optimizer/overlap_opt_shard_in_pipeline.h"
|
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "pipeline/jit/pipeline_split.h"
|
#include "pipeline/jit/pipeline_split.h"
|
||||||
#include "pipeline/pynative/pynative_execute.h"
|
#include "pipeline/pynative/pynative_execute.h"
|
||||||
|
@ -639,7 +640,7 @@ bool ReorderSendRecvBetweenFpBpPass(const ResourcePtr &resource) {
|
||||||
|
|
||||||
bool MicroInterLeavedOrderControlPass(const ResourcePtr &resource) {
|
bool MicroInterLeavedOrderControlPass(const ResourcePtr &resource) {
|
||||||
MS_EXCEPTION_IF_NULL(resource);
|
MS_EXCEPTION_IF_NULL(resource);
|
||||||
opt::MicroInterleavedOrderControl(resource->func_graph());
|
parallel::MicroInterleavedOrderControl(resource->func_graph());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -657,7 +658,13 @@ bool AddCommOpReusePass(const ResourcePtr &resource) {
|
||||||
|
|
||||||
bool OverlapOptShardInPipelinePass(const ResourcePtr &resource) {
|
bool OverlapOptShardInPipelinePass(const ResourcePtr &resource) {
|
||||||
MS_EXCEPTION_IF_NULL(resource);
|
MS_EXCEPTION_IF_NULL(resource);
|
||||||
opt::OverlapOptShardInPipeline(resource->func_graph());
|
parallel::OverlapOptShardInPipeline(resource->func_graph());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HandleGroupInfoPass(const ResourcePtr &resource) {
|
||||||
|
MS_EXCEPTION_IF_NULL(resource);
|
||||||
|
parallel::HandleGroupInfo(resource->func_graph());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -864,28 +871,28 @@ bool AddEmbeddingCachePass(const ResourcePtr &resource) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<PassItem> kVmPasses = {
|
std::vector<PassItem> kVmPasses = {{"py_interpret_to_execute", PyInterpretToExecutePass},
|
||||||
{"py_interpret_to_execute", PyInterpretToExecutePass},
|
{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||||
{"simplify_data_structures", SimplifyDataStructuresPass},
|
{"opt_a", OptPassAGroup},
|
||||||
{"opt_a", OptPassAGroup},
|
{"clean_after_opta", CleanAfterOptAPass},
|
||||||
{"clean_after_opta", CleanAfterOptAPass},
|
{"opt_b", OptPassBGroup},
|
||||||
{"opt_b", OptPassBGroup},
|
{"cconv", CconvPass},
|
||||||
{"cconv", CconvPass},
|
{"opt_after_cconv", OptPassAfterCconvGroup},
|
||||||
{"opt_after_cconv", OptPassAfterCconvGroup},
|
{"remove_dup_value", RemoveValueNodeDuplicationsPass},
|
||||||
{"remove_dup_value", RemoveValueNodeDuplicationsPass},
|
{"tuple_transform", OptPassTransformGraphGroup},
|
||||||
{"tuple_transform", OptPassTransformGraphGroup},
|
{"add_cache_embedding", AddCacheEmbeddingPass},
|
||||||
{"add_cache_embedding", AddCacheEmbeddingPass},
|
{"add_recomputation", AddRecomputationPass},
|
||||||
{"add_recomputation", AddRecomputationPass},
|
{"cse_after_recomputation", OptAfterRecomputeGroup},
|
||||||
{"cse_after_recomputation", OptAfterRecomputeGroup},
|
{"environ_conv", EnvironConversionPass},
|
||||||
{"environ_conv", EnvironConversionPass},
|
{"label_micro_interleaved_index", LabelMicroInterleavedIndexPass},
|
||||||
{"label_micro_interleaved_index", LabelMicroInterleavedIndexPass},
|
{"slice_recompute_activation", SliceRecomputeActivationPass},
|
||||||
{"slice_recompute_activation", SliceRecomputeActivationPass},
|
{"micro_interleaved_order_control", MicroInterLeavedOrderControlPass},
|
||||||
{"micro_interleaved_order_control", MicroInterLeavedOrderControlPass},
|
{"reorder_send_recv_between_fp_bp", ReorderSendRecvBetweenFpBpPass},
|
||||||
{"reorder_send_recv_between_fp_bp", ReorderSendRecvBetweenFpBpPass},
|
{"comm_op_add_attrs", CommOpAddAttrs},
|
||||||
{"comm_op_add_attrs", CommOpAddAttrs},
|
{"add_comm_op_reuse_tag", AddCommOpReusePass},
|
||||||
{"add_comm_op_reuse_tag", AddCommOpReusePass},
|
{"overlap_opt_shard_in_pipeline", OverlapOptShardInPipelinePass},
|
||||||
{"overlap_opt_shard_in_pipeline", OverlapOptShardInPipelinePass},
|
// The pass cache hccl group, so the hccl group should be created before the pass
|
||||||
};
|
{"handle_group_info", HandleGroupInfoPass}};
|
||||||
|
|
||||||
std::vector<PassItem> kGePasses = {{"py_interpret_to_execute", PyInterpretToExecutePass},
|
std::vector<PassItem> kGePasses = {{"py_interpret_to_execute", PyInterpretToExecutePass},
|
||||||
{"simplify_data_structures", SimplifyDataStructuresPass},
|
{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||||
|
@ -894,7 +901,9 @@ std::vector<PassItem> kGePasses = {{"py_interpret_to_execute", PyInterpretToExec
|
||||||
{"opt_b", OptPassBGroup},
|
{"opt_b", OptPassBGroup},
|
||||||
{"opt_control", ControlGroup},
|
{"opt_control", ControlGroup},
|
||||||
{"opt_prepare", PrepareGroup},
|
{"opt_prepare", PrepareGroup},
|
||||||
{"cconv", CconvPass}};
|
{"cconv", CconvPass},
|
||||||
|
// The pass cache hccl group, so the hccl group should be created before the pass
|
||||||
|
{"handle_group_info", HandleGroupInfoPass}};
|
||||||
|
|
||||||
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
|
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
|
||||||
{"opt_b", OptPassBGroup},
|
{"opt_b", OptPassBGroup},
|
||||||
|
|
Loading…
Reference in New Issue