fix fault recovery in optimizer shard

yao_yf 2021-11-17 10:28:04 +08:00
parent e6ee9b2f19
commit 01dc4bbdf9
8 changed files with 240 additions and 46 deletions

View File

@@ -259,6 +259,7 @@ RankList DeviceManager::FindRankListByHashName(const std::string &hash_name) {
   }
   RankList rank_list;
   std::string rank_str = "";
+  rank_list_name = rank_list_name + "-";
   for (size_t i = 0; i < rank_list_name.size(); i++) {
     if (rank_list_name[i] == '-') {
       int64_t rank_id = std::stoi(rank_str);
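The single added line appends a trailing '-' to the dash-separated rank list name, presumably so the parsing loop also flushes the final rank id instead of dropping everything after the last dash. A minimal Python sketch of that parsing idea (not the C++ implementation, whose full loop body is not shown here):

# Sketch: accumulate digits into rank_str and flush on every '-'.
def parse_rank_list(rank_list_name):
    rank_list_name = rank_list_name + "-"  # the fix: guarantees the last id is flushed too
    rank_list, rank_str = [], ""
    for ch in rank_list_name:
        if ch == "-":
            rank_list.append(int(rank_str))
            rank_str = ""
        else:
            rank_str += ch
    return rank_list

# parse_rank_list("0-4-8-12") -> [0, 4, 8, 12]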

View File

@@ -374,7 +374,7 @@ void HandleNoUsedParameter(const FuncGraphPtr &root) {
   }
 }
-bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
+bool IsFullySplitParameter(const ParameterPtr &param_ptr, size_t allow_repeat_num) {
   auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
   if (tensor_layout == nullptr) {
     return false;
@@ -391,7 +391,7 @@ bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
     return false;
   }
-  if (group_devices.size() == 1) {
+  if (group_devices.size() <= allow_repeat_num) {
     MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
     return true;
   }

View File

@@ -45,7 +45,7 @@ void HandleFullySplitParameters(const FuncGraphPtr &root);
 void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root);
 void HandleAdaFactorOpt(const FuncGraphPtr &root);
 bool ParameterIsCloned(const AnfNodePtr &parameter_node);
-bool IsFullySplitParameter(const ParameterPtr &param_ptr);
+bool IsFullySplitParameter(const ParameterPtr &param_ptr, size_t allow_repeat_num = 1);
 }  // namespace parallel
 }  // namespace mindspore

View File

@@ -3123,48 +3123,13 @@ bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
           current_stage == split_stage_num - 1);
 }
-RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
-  auto parameters = root->parameters();
-  for (auto &parameter : parameters) {
-    auto param_ptr = parameter->cast<ParameterPtr>();
-    MS_EXCEPTION_IF_NULL(param_ptr);
-    if (IsFullySplitParameter(param_ptr)) {
-      MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
-                      << " is fully shard, thus cannot find common data parallel group for this rank";
-      return {};
-    }
-  }
-  AnfNodePtr ret = root->get_return();
-  MS_EXCEPTION_IF_NULL(ret);
-  std::vector<int64_t> common_group_list;
-  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
-  bool is_first_group = true;
-  for (auto &node : all_nodes) {
-    if (!IsPrimitiveCNode(node, prim::kPrimMirror)) {
-      continue;
-    }
-    auto prim = GetCNodePrimitive(node);
-    if (!prim->HasAttr(GROUP)) {
-      MS_LOG(EXCEPTION) << "The mirror operator dose not have group attr : " << node->DebugString();
-    }
-    std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
-    std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
-    if (is_first_group) {
-      common_group_list = group_list;
-      is_first_group = false;
-    } else {
-      std::vector<int64_t> new_comm_group_list;
-      std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(), group_list.end(),
-                            std::back_inserter(new_comm_group_list));
-      common_group_list = new_comm_group_list;
-    }
-  }
-  MS_LOG(INFO) << "The common mirror group is:" << common_group_list;
-  return common_group_list;
-}
 static void HandleGroupInfo(const FuncGraphPtr &root) {
   auto group_info = g_device_manager->group_info();
+  auto group_info_save_path = common::GetEnv("GROUP_INFO_FILE");
+  if (!group_info_save_path.empty()) {
+    ParallelContext::GetInstance()->set_group_ckpt_save_file(group_info_save_path);
+  }
   if (StrategyCheckpoint::GetInstance().group_info_save_on()) {
     RankList comm_group = FindCommonMirrorGroup(root);
     if (StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info, comm_group) != SUCCESS) {
@@ -3173,6 +3138,25 @@ static void HandleGroupInfo(const FuncGraphPtr &root) {
   }
 }
+static void HandleDataParallel() {
+  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
+  if (parallel_mode == DATA_PARALLEL) {
+    auto group_info_save_path = common::GetEnv("GROUP_INFO_FILE");
+    if (!group_info_save_path.empty()) {
+      std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info;
+      int64_t device_num = GetCommInfo().device_num;
+      RankList comm_group;
+      for (size_t i = 0; i < size_t(device_num); ++i) {
+        comm_group.push_back(i);
+      }
+      ParallelContext::GetInstance()->set_group_ckpt_save_file(group_info_save_path);
+      if (StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info, comm_group) != SUCCESS) {
+        MS_LOG(EXCEPTION) << "Save group info failed";
+      }
+    }
+  }
+}
 static void PipelinePostProcess(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
   auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
   if (pipeline_stages > 1) {
@@ -3191,6 +3175,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
   MS_EXCEPTION_IF_NULL(optimizer);
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
+  HandleDataParallel();
   pipeline::ResourceBasePtr res = optimizer->resource();
   MS_EXCEPTION_IF_NULL(res);
   FuncGraphManagerPtr manager = res->manager();
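The new HandleDataParallel path covers the pure data-parallel mode: when GROUP_INFO_FILE is exported, it saves an empty group_info together with a communication group containing every rank, since each rank holds a full model replica. A rough Python sketch of the group it produces (not the C++ implementation):

import os

def data_parallel_comm_group(device_num):
    # under pure data parallelism every rank keeps a full replica, so the common
    # group written to the group info file is simply all ranks
    if not os.getenv("GROUP_INFO_FILE"):
        return None  # nothing is saved when the env variable is not exported
    return list(range(device_num))

# data_parallel_comm_group(32) -> [0, 1, ..., 31], matching test_data_parallel_group below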

View File

@@ -35,6 +35,7 @@
 #include "frontend/parallel/graph_util/graph_info.h"
 #include "frontend/parallel/graph_util/node_info.h"
 #include "frontend/parallel/node_check.h"
+#include "frontend/parallel/parameter_manager.h"
 #include "ir/param_info.h"
 #include "ir/tensor.h"
 #include "utils/trace_base.h"
@@ -149,6 +150,61 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
   return shapes;
 }
+RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
+  auto parameters = root->parameters();
+  for (auto &parameter : parameters) {
+    auto param_ptr = parameter->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param_ptr);
+    if (!(param_ptr->has_default() && ParameterRequireGrad(param_ptr))) {
+      continue;
+    }
+    size_t allow_repeat_num = 1;
+    if (ParallelContext::GetInstance()->enable_parallel_optimizer() &&
+        (!param_ptr->param_info() || !param_ptr->param_info()->parallel_optimizer())) {
+      if (ParallelContext::GetInstance()->optimizer_weight_shard_size() == -1) {
+        MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
+                        << " is fully shard by optimizer parallel,"
+                           " thus cannot find common data parallel group for this rank";
+        return {g_device_manager->global_rank()};
+      }
+      allow_repeat_num = size_t(ParallelContext::GetInstance()->optimizer_weight_shard_size());
+    }
+    if (IsFullySplitParameter(param_ptr, allow_repeat_num)) {
+      MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
+                      << " is fully shard, thus cannot find common data parallel group for this rank";
+      return {g_device_manager->global_rank()};
+    }
+  }
+  AnfNodePtr ret = root->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<int64_t> common_group_list;
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  bool is_first_group = true;
+  for (auto &node : all_nodes) {
+    if (!IsPrimitiveCNode(node, prim::kPrimMirror) && !IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) &&
+        !IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep)) {
+      continue;
+    }
+    auto prim = GetCNodePrimitive(node);
+    if (!prim->HasAttr(GROUP)) {
+      MS_LOG(EXCEPTION) << "The mirror operator dose not have group attr : " << node->DebugString();
+    }
+    std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
+    std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
+    if (is_first_group) {
+      common_group_list = group_list;
+      is_first_group = false;
+    } else {
+      std::vector<int64_t> new_comm_group_list;
+      std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(), group_list.end(),
+                            std::back_inserter(new_comm_group_list));
+      common_group_list = new_comm_group_list;
+    }
+  }
+  MS_LOG(INFO) << "The common mirror group is:" << common_group_list;
+  return common_group_list;
+}
 std::string CreateInstanceName(const CNodePtr &node, size_t index) {
   MS_EXCEPTION_IF_NULL(node);
   if (!IsValueNode<Primitive>(node->input(0))) {
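The relocated FindCommonMirrorGroup now skips parameters that have no default value or do not require gradients, returns only the local rank when a parameter is fully sharded (including full optimizer sharding, i.e. optimizer_weight_shard_size == -1), and intersects the rank lists of all mirror operators, including the mini-step and micro-step variants. A rough Python sketch of the intersection step, with hypothetical names:

def find_common_mirror_group(mirror_group_lists, fully_sharded=False, global_rank=0):
    # a fully sharded parameter means only the local rank's checkpoint is usable for recovery
    if fully_sharded:
        return [global_rank]
    common = None
    for group in mirror_group_lists:
        # mimics std::set_intersection over sorted rank lists
        common = sorted(group) if common is None else sorted(set(common) & set(group))
    return common or []

# find_common_mirror_group([[0, 4, 8, 12, 16, 20, 24, 28], [0, 8, 16, 24]])
# -> [0, 8, 16, 24]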

View File

@@ -28,6 +28,7 @@ namespace parallel {
 bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
 bool IsParallelCareNode(const CNodePtr &cnode);
 Shapes GetNodeShape(const AnfNodePtr &node);
+RankList FindCommonMirrorGroup(const FuncGraphPtr &root);
 std::string CreateInstanceName(const CNodePtr &node, size_t index);
 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
 std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,

View File

@@ -1160,10 +1160,12 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
     return merged_tensor
-def ckpt_restore_group_info(group_info_file_name):
+def restore_group_info_list(group_info_file_name):
     """
     Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
-    that saves the group_info_file_name
+    that saves the group_info_file_name.
+    To save the group info file, please export GROUP_INFO_FILE environment variables like
+    "export GROUP_INFO_FILE=/data/group_info.pb".
     Args:
         group_info_file_name (str): Name of group information file.
@@ -1175,7 +1177,7 @@ def ckpt_restore_group_info(group_info_file_name):
         TypeError: group_info_file_name is not str.
     Examples:
-        >>> restore_list = ckpt_restore_group_info("./group_info.ckpt")
+        >>> restore_list = restore_group_info_list("./group_info.pb")
     """
     if not isinstance(group_info_file_name, str):
        raise TypeError(f"The group_info_file_name should be str, but got {type(group_info_file_name)}.")

View File

@@ -0,0 +1,149 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test group info """
import os
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore import context
from mindspore.train.serialization import restore_group_info_list


class Net3(nn.Cell):
    """Net definition"""
    def __init__(self, strategy1, strategy2, strategy3):
        super(Net3, self).__init__()
        self.fc1 = P.MatMul().shard(strategy1)
        self.fc2 = P.MatMul().shard(strategy2)
        self.fc3 = P.MatMul().shard(strategy3)
        self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
        self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2", parallel_optimizer=False)
        self.p3 = Parameter(Tensor(np.ones([16, 16]).astype(np.float32)), name="weight3")

    def construct(self, x, y):
        x = self.fc1(x, self.p1)
        x = self.fc2(x, self.p2)
        z = x - y
        z = self.fc3(z, self.p3)
        return z


def auto_parallel_compile_net(strategy1=None, strategy2=None, strategy3=None):
    context.set_context(mode=context.GRAPH_MODE)
    inputs = Tensor(np.ones([32, 48]).astype(np.float32))
    label = Tensor(np.zeros([32, 16]).astype(np.float32))
    net = Net3(strategy1, strategy2, strategy3)
    auto_parallel = context.get_auto_parallel_context("parallel_mode") in ["semi_auto_parallel", "auto_parallel"]
    if auto_parallel:
        net = _VirtualDatasetCell(net)
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4)
    train_network.set_auto_parallel()
    train_network.set_train()
    _cell_graph_executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=auto_parallel)


def test_mirror_group():
    """
    Feature: save and load mirror group
    Description: semi-auto, disable parallel optimizer.
    Expectation: group info list match expectation value.
    """
    os.environ['GROUP_INFO_FILE'] = "./test_mirror_group.pb"
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=32, enable_parallel_optimizer=False)
    auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
    group_info_list = restore_group_info_list("./test_mirror_group.pb")
    assert group_info_list == [0, 4, 8, 12, 16, 20, 24, 28]
    context.reset_auto_parallel_context()
    del os.environ['GROUP_INFO_FILE']


def test_data_parallel_group():
    """
    Feature: save and load mirror group
    Description: data-parallel, disable parallel optimizer.
    Expectation: group info list match expectation value.
    """
    os.environ['GROUP_INFO_FILE'] = "./test_data_parallel_group.pb"
    context.set_auto_parallel_context(parallel_mode="data_parallel",
                                      device_num=32, enable_parallel_optimizer=False)
    auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
    group_info_list = restore_group_info_list("./test_data_parallel_group.pb")
    assert group_info_list == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                               24, 25, 26, 27, 28, 29, 30, 31]
    context.reset_auto_parallel_context()
    del os.environ['GROUP_INFO_FILE']


def test_mirror_group_parallel_optimizer():
    """
    Feature: save and load mirror group
    Description: semi-auto, enable parallel optimizer.
    Expectation: group info list match expectation value.
    """
    os.environ['GROUP_INFO_FILE'] = "./test_mirror_group_parallel_optimizer.pb"
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=32, enable_parallel_optimizer=True)
    auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
    group_info_list = restore_group_info_list("./test_mirror_group_parallel_optimizer.pb")
    assert group_info_list == [0]
    context.reset_auto_parallel_context()
    del os.environ['GROUP_INFO_FILE']


def test_mirror_group_parallel_optimizer_not_full_shard():
    """
    Feature: save and load mirror group
    Description: semi-auto, enable parallel optimizer but not fully shard.
    Expectation: group info list match expectation value.
    """
    os.environ['GROUP_INFO_FILE'] = "./test_mirror_group_parallel_optimizer_not_full_shard.pb"
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=32, enable_parallel_optimizer=True, optimizer_weight_shard_size=2)
    auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
    group_info_list = restore_group_info_list("./test_mirror_group_parallel_optimizer_not_full_shard.pb")
    assert group_info_list == [0, 8, 16, 24]
    context.reset_auto_parallel_context()
    del os.environ['GROUP_INFO_FILE']


def test_pipeline_split_stage0_mirror_group():
    """
    Feature: save and load mirror group
    Description: semi-auto, pipeline parallel.
    Expectation: group info list match expectation value.
    """
    import mindspore as ms
    from mindspore import Model
    from .test_pipeline_split import PipelineCell, PipelineSplit, DatasetLenet
    os.environ['GROUP_INFO_FILE'] = "./test_pipeline_split_stage0_mirror_group.pb"
    context.set_auto_parallel_context(device_num=64, global_rank=0, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 8))
    strategy2 = ((4, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    params = net.network.cell.block[0].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    group_info_list = restore_group_info_list("./test_pipeline_split_stage0_mirror_group.pb")
    assert group_info_list == [0, 8, 16, 24]