fix fault recovery in optimizer shard

yao_yf 2021-11-17 10:28:04 +08:00
parent e6ee9b2f19
commit 01dc4bbdf9
8 changed files with 240 additions and 46 deletions

View File

@@ -259,6 +259,7 @@ RankList DeviceManager::FindRankListByHashName(const std::string &hash_name) {
}
RankList rank_list;
std::string rank_str = "";
rank_list_name = rank_list_name + "-";
for (size_t i = 0; i < rank_list_name.size(); i++) {
if (rank_list_name[i] == '-') {
int64_t rank_id = std::stoi(rank_str);
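The added line appends a trailing '-' to the rank string recovered from the hashed group name, so the final rank id is also flushed by the delimiter branch above; without it the last entry would be dropped. A minimal Python sketch of the same parsing idea (the example string is illustrative, not taken from the commit):

def parse_rank_list(rank_list_name):
    # Parse a '-'-separated rank string such as "0-8-16-24" into [0, 8, 16, 24].
    rank_list, rank_str = [], ""
    # The trailing delimiter guarantees the last rank id is converted as well.
    rank_list_name = rank_list_name + "-"
    for ch in rank_list_name:
        if ch == "-":
            rank_list.append(int(rank_str))
            rank_str = ""
        else:
            rank_str += ch
    return rank_list

assert parse_rank_list("0-8-16-24") == [0, 8, 16, 24]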

View File

@@ -374,7 +374,7 @@ void HandleNoUsedParameter(const FuncGraphPtr &root) {
}
}
bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
bool IsFullySplitParameter(const ParameterPtr &param_ptr, size_t allow_repeat_num) {
auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
if (tensor_layout == nullptr) {
return false;
@@ -391,7 +391,7 @@ bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
return false;
}
if (group_devices.size() == 1) {
if (group_devices.size() <= allow_repeat_num) {
MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
return true;
}
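Relaxing the check from requiring exactly one device to allowing up to allow_repeat_num devices lets a parameter whose mirror group still holds a few replicas count as fully split, which is what optimizer weight sharding with a fixed shard size produces. A hedged Python sketch of the relaxed predicate (helper name and rank lists are illustrative):

def is_fully_split(group_devices, allow_repeat_num=1):
    # A parameter is treated as fully split when at most allow_repeat_num
    # replicas of it remain across the mirror group.
    return len(group_devices) <= allow_repeat_num

assert is_fully_split([3])              # previous behaviour: a single device
assert is_fully_split([3, 19], 2)       # optimizer_weight_shard_size == 2 keeps two replicas
assert not is_fully_split([3, 19])      # still replicated under the default of 1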

View File

@@ -45,7 +45,7 @@ void HandleFullySplitParameters(const FuncGraphPtr &root);
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root);
void HandleAdaFactorOpt(const FuncGraphPtr &root);
bool ParameterIsCloned(const AnfNodePtr &parameter_node);
bool IsFullySplitParameter(const ParameterPtr &param_ptr);
bool IsFullySplitParameter(const ParameterPtr &param_ptr, size_t allow_repeat_num = 1);
} // namespace parallel
} // namespace mindspore

View File

@@ -3123,48 +3123,13 @@ bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
current_stage == split_stage_num - 1);
}
RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
auto parameters = root->parameters();
for (auto &parameter : parameters) {
auto param_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (IsFullySplitParameter(param_ptr)) {
MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
<< " is fully shard, thus cannot find common data parallel group for this rank";
return {};
}
}
AnfNodePtr ret = root->get_return();
MS_EXCEPTION_IF_NULL(ret);
std::vector<int64_t> common_group_list;
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
bool is_first_group = true;
for (auto &node : all_nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimMirror)) {
continue;
}
auto prim = GetCNodePrimitive(node);
if (!prim->HasAttr(GROUP)) {
MS_LOG(EXCEPTION) << "The mirror operator dose not have group attr : " << node->DebugString();
}
std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
if (is_first_group) {
common_group_list = group_list;
is_first_group = false;
} else {
std::vector<int64_t> new_comm_group_list;
std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(), group_list.end(),
std::back_inserter(new_comm_group_list));
common_group_list = new_comm_group_list;
}
}
MS_LOG(INFO) << "The common mirror group is:" << common_group_list;
return common_group_list;
}
static void HandleGroupInfo(const FuncGraphPtr &root) {
auto group_info = g_device_manager->group_info();
auto group_info_save_path = common::GetEnv("GROUP_INFO_FILE");
if (!group_info_save_path.empty()) {
ParallelContext::GetInstance()->set_group_ckpt_save_file(group_info_save_path);
}
if (StrategyCheckpoint::GetInstance().group_info_save_on()) {
RankList comm_group = FindCommonMirrorGroup(root);
if (StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info, comm_group) != SUCCESS) {
@@ -3173,6 +3138,25 @@ static void HandleGroupInfo(const FuncGraphPtr &root) {
}
}
static void HandleDataParallel() {
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
if (parallel_mode == DATA_PARALLEL) {
auto group_info_save_path = common::GetEnv("GROUP_INFO_FILE");
if (!group_info_save_path.empty()) {
std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info;
int64_t device_num = GetCommInfo().device_num;
RankList comm_group;
for (size_t i = 0; i < size_t(device_num); ++i) {
comm_group.push_back(i);
}
ParallelContext::GetInstance()->set_group_ckpt_save_file(group_info_save_path);
if (StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info, comm_group) != SUCCESS) {
MS_LOG(EXCEPTION) << "Save group info failed";
}
}
}
}
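In pure data-parallel mode no Mirror operators are inserted, so there is no group to intersect; the new HandleDataParallel simply records every rank as the common group whenever GROUP_INFO_FILE is set. A hedged Python sketch of that shortcut (the helper name is made up for illustration):

import os

def data_parallel_comm_group(device_num):
    # Every rank holds a full replica in data_parallel mode, so any rank's
    # checkpoint can restore any other rank.
    return list(range(device_num))

if os.getenv("GROUP_INFO_FILE"):          # mirrors the guard in the C++ code above
    print(data_parallel_comm_group(32))   # [0, 1, ..., 31]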
static void PipelinePostProcess(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
if (pipeline_stages > 1) {
@@ -3191,6 +3175,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
MS_EXCEPTION_IF_NULL(optimizer);
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
HandleDataParallel();
pipeline::ResourceBasePtr res = optimizer->resource();
MS_EXCEPTION_IF_NULL(res);
FuncGraphManagerPtr manager = res->manager();

View File

@@ -35,6 +35,7 @@
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/node_check.h"
#include "frontend/parallel/parameter_manager.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
@@ -149,6 +150,61 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
return shapes;
}
RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
auto parameters = root->parameters();
for (auto &parameter : parameters) {
auto param_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (!(param_ptr->has_default() && ParameterRequireGrad(param_ptr))) {
continue;
}
size_t allow_repeat_num = 1;
if (ParallelContext::GetInstance()->enable_parallel_optimizer() &&
(!param_ptr->param_info() || !param_ptr->param_info()->parallel_optimizer())) {
if (ParallelContext::GetInstance()->optimizer_weight_shard_size() == -1) {
MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
<< " is fully shard by optimizer parallel,"
" thus cannot find common data parallel group for this rank";
return {g_device_manager->global_rank()};
}
allow_repeat_num = size_t(ParallelContext::GetInstance()->optimizer_weight_shard_size());
}
if (IsFullySplitParameter(param_ptr, allow_repeat_num)) {
MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
<< " is fully shard, thus cannot find common data parallel group for this rank";
return {g_device_manager->global_rank()};
}
}
AnfNodePtr ret = root->get_return();
MS_EXCEPTION_IF_NULL(ret);
std::vector<int64_t> common_group_list;
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
bool is_first_group = true;
for (auto &node : all_nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimMirror) && !IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) &&
!IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep)) {
continue;
}
auto prim = GetCNodePrimitive(node);
if (!prim->HasAttr(GROUP)) {
MS_LOG(EXCEPTION) << "The mirror operator dose not have group attr : " << node->DebugString();
}
std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
if (is_first_group) {
common_group_list = group_list;
is_first_group = false;
} else {
std::vector<int64_t> new_comm_group_list;
std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(), group_list.end(),
std::back_inserter(new_comm_group_list));
common_group_list = new_comm_group_list;
}
}
MS_LOG(INFO) << "The common mirror group is:" << common_group_list;
return common_group_list;
}
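The relocated FindCommonMirrorGroup now skips parameters that have no gradient, returns only the local rank when a trainable parameter is fully sharded (including sharding done by the parallel optimizer when optimizer_weight_shard_size is -1), and otherwise intersects the rank lists of every Mirror, MirrorMiniStep and MirrorMicroStep operator. A hedged Python sketch of the intersection step (the example rank lists are illustrative):

def common_mirror_group(mirror_groups):
    # Intersect the sorted rank lists of all mirror operators; the ranks left
    # over hold identical replicas of every mirrored parameter.
    common = None
    for group in mirror_groups:
        common = group if common is None else sorted(set(common) & set(group))
    return common if common is not None else []

assert common_mirror_group([[0, 4, 8, 12], [0, 8]]) == [0, 8]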
std::string CreateInstanceName(const CNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
if (!IsValueNode<Primitive>(node->input(0))) {

View File

@@ -28,6 +28,7 @@ namespace parallel {
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
bool IsParallelCareNode(const CNodePtr &cnode);
Shapes GetNodeShape(const AnfNodePtr &node);
RankList FindCommonMirrorGroup(const FuncGraphPtr &root);
std::string CreateInstanceName(const CNodePtr &node, size_t index);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,

View File

@@ -1160,10 +1160,12 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
return merged_tensor
def ckpt_restore_group_info(group_info_file_name):
def restore_group_info_list(group_info_file_name):
"""
Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
that saves the group_info_file_name
that saves the group_info_file_name.
To save the group info file, please export GROUP_INFO_FILE environment variables like
"export GROUP_INFO_FILE=/data/group_info.pb".
Args:
group_info_file_name (str): Name of group information file.
@@ -1175,7 +1177,7 @@ def ckpt_restore_group_info(group_info_file_name):
TypeError: group_info_file_name is not str.
Examples:
>>> restore_list = ckpt_restore_group_info("./group_info.ckpt")
>>> restore_list = restore_group_info_list("./group_info.pb")
"""
if not isinstance(group_info_file_name, str):
raise TypeError(f"The group_info_file_name should be str, but got {type(group_info_file_name)}.")
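Taken together with the GROUP_INFO_FILE handling above, the renamed restore_group_info_list is the user-facing half of the feature. A hedged end-to-end usage sketch (the path and printed ranks are illustrative):

import os
from mindspore.train.serialization import restore_group_info_list

# 1. Export the save path before compiling/training so the group info file is written.
os.environ["GROUP_INFO_FILE"] = "/data/group_info.pb"

# ... compile and train the parallel network here ...

# 2. Afterwards, recover the ranks whose checkpoints are interchangeable with this rank's.
if os.path.isfile("/data/group_info.pb"):
    rank_list = restore_group_info_list("/data/group_info.pb")
    print(rank_list)  # e.g. [0, 8, 16, 24]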

View File

@@ -0,0 +1,149 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test group info """
import os
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore import context
from mindspore.train.serialization import restore_group_info_list
class Net3(nn.Cell):
"""Net definition"""
def __init__(self, strategy1, strategy2, strategy3):
super(Net3, self).__init__()
self.fc1 = P.MatMul().shard(strategy1)
self.fc2 = P.MatMul().shard(strategy2)
self.fc3 = P.MatMul().shard(strategy3)
self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2", parallel_optimizer=False)
self.p3 = Parameter(Tensor(np.ones([16, 16]).astype(np.float32)), name="weight3")
def construct(self, x, y):
x = self.fc1(x, self.p1)
x = self.fc2(x, self.p2)
z = x - y
z = self.fc3(z, self.p3)
return z
def auto_parallel_compile_net(strategy1=None, strategy2=None, strategy3=None):
context.set_context(mode=context.GRAPH_MODE)
inputs = Tensor(np.ones([32, 48]).astype(np.float32))
label = Tensor(np.zeros([32, 16]).astype(np.float32))
net = Net3(strategy1, strategy2, strategy3)
auto_parallel = context.get_auto_parallel_context("parallel_mode") in ["semi_auto_parallel", "auto_parallel"]
if auto_parallel:
net = _VirtualDatasetCell(net)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4)
train_network.set_auto_parallel()
train_network.set_train()
_cell_graph_executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=auto_parallel)
def test_mirror_group():
"""
Feature: save and load mirror group
Description: semi-auto, disable parallel optimizer.
Expectation: group info list match expectation value.
"""
os.environ['GROUP_INFO_FILE'] = "./test_mirror_group.pb"
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
device_num=32, enable_parallel_optimizer=False)
auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
group_info_list = restore_group_info_list("./test_mirror_group.pb")
assert group_info_list == [0, 4, 8, 12, 16, 20, 24, 28]
context.reset_auto_parallel_context()
del os.environ['GROUP_INFO_FILE']
def test_data_parallel_group():
"""
Feature: save and load mirror group
Description: data-parallel, disable parallel optimizer.
Expectation: group info list match expectation value.
"""
os.environ['GROUP_INFO_FILE'] = "./test_data_parallel_group.pb"
context.set_auto_parallel_context(parallel_mode="data_parallel",
device_num=32, enable_parallel_optimizer=False)
auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
group_info_list = restore_group_info_list("./test_data_parallel_group.pb")
assert group_info_list == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31]
context.reset_auto_parallel_context()
del os.environ['GROUP_INFO_FILE']
def test_mirror_group_parallel_optimizer():
"""
Feature: save and load mirror group
Description: semi-auto, enable parallel optimizer.
Expectation: group info list match expectation value.
"""
os.environ['GROUP_INFO_FILE'] = "./test_mirror_group_parallel_optimizer.pb"
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
device_num=32, enable_parallel_optimizer=True)
auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
group_info_list = restore_group_info_list("./test_mirror_group_parallel_optimizer.pb")
assert group_info_list == [0]
context.reset_auto_parallel_context()
del os.environ['GROUP_INFO_FILE']
def test_mirror_group_parallel_optimizer_not_full_shard():
"""
Feature: save and load mirror group
Description: semi-auto, enable parallel optimizer but not fully shard.
Expectation: group info list match expectation value.
"""
os.environ['GROUP_INFO_FILE'] = "./test_mirror_group_parallel_optimizer_not_full_shard.pb"
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
device_num=32, enable_parallel_optimizer=True, optimizer_weight_shard_size=2)
auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
group_info_list = restore_group_info_list("./test_mirror_group_parallel_optimizer_not_full_shard.pb")
assert group_info_list == [0, 8, 16, 24]
context.reset_auto_parallel_context()
del os.environ['GROUP_INFO_FILE']
def test_pipeline_split_stage0_mirror_group():
"""
Feature: save and load mirror group
Description: semi-auto, pipeline parallel.
Expectation: group info list match expectation value.
"""
import mindspore as ms
from mindspore import Model
from .test_pipeline_split import PipelineCell, PipelineSplit, DatasetLenet
os.environ['GROUP_INFO_FILE'] = "./test_pipeline_split_stage0_mirror_group.pb"
context.set_auto_parallel_context(device_num=64, global_rank=0, pipeline_stages=2)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
data = Tensor(np.ones([32, 64]), dtype=ms.float32)
label = Tensor(np.ones([64, 64]), dtype=ms.float32)
strategy1 = ((4, 1), (1, 8))
strategy2 = ((4, 1), (1, 1))
net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
params = net.network.cell.block[0].trainable_params()
dataset = DatasetLenet(data, label, 3)
optimizer = nn.Lamb(params, learning_rate=0.01)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
group_info_list = restore_group_info_list("./test_pipeline_split_stage0_mirror_group.pb")
assert group_info_list == [0, 8, 16, 24]