forked from mindspore-Ecosystem/mindspore
fix fault recovery in optimizer shard
parent e6ee9b2f19
commit 01dc4bbdf9
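For orientation, the group info file written during training (via the GROUP_INFO_FILE environment variable handled below) can later be read back with restore_group_info_list to find peer ranks whose checkpoints are interchangeable with the local rank's. A minimal recovery-side sketch, assuming an illustrative per-rank checkpoint layout and the RANK_ID environment variable, neither of which is defined by this commit:

    import os
    from mindspore.train.serialization import restore_group_info_list

    # Ranks in the restored list hold parameter contents identical to the local
    # rank, so any of their checkpoints can be used to restore a failed rank.
    group_list = restore_group_info_list("/data/group_info.pb")

    local_rank = int(os.getenv("RANK_ID", "0"))  # assumed rank env var, illustrative
    peers = [r for r in group_list if r != local_rank]
    # Hypothetical per-rank checkpoint layout, purely for illustration.
    candidates = [f"/data/ckpt/rank_{r}/latest.ckpt" for r in peers]
    print("checkpoints usable to recover rank", local_rank, ":", candidates)
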
@@ -259,6 +259,7 @@ RankList DeviceManager::FindRankListByHashName(const std::string &hash_name) {
   }
   RankList rank_list;
   std::string rank_str = "";
+  rank_list_name = rank_list_name + "-";
   for (size_t i = 0; i < rank_list_name.size(); i++) {
     if (rank_list_name[i] == '-') {
       int64_t rank_id = std::stoi(rank_str);

@@ -374,7 +374,7 @@ void HandleNoUsedParameter(const FuncGraphPtr &root) {
   }
 }

-bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
+bool IsFullySplitParameter(const ParameterPtr &param_ptr, size_t allow_repeat_num) {
   auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
   if (tensor_layout == nullptr) {
     return false;

@@ -391,7 +391,7 @@ bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
     return false;
   }

-  if (group_devices.size() == 1) {
+  if (group_devices.size() <= allow_repeat_num) {
     MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
     return true;
   }

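Restated outside the C++ context, the relaxed check above amounts to the following paraphrase (Python, not project code): a parameter still counts as fully split when the group of devices holding an identical copy of it is no larger than the allowed repeat count, which defaults to 1 and is set to the optimizer weight shard size when optimizer parallelism keeps small replica groups.

    def is_fully_split(group_devices, allow_repeat_num=1):
        # group_devices: ranks that hold an identical slice of the parameter.
        # With allow_repeat_num == 1 this reduces to the old "group size == 1" check.
        return len(group_devices) <= allow_repeat_num
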
@@ -45,7 +45,7 @@ void HandleFullySplitParameters(const FuncGraphPtr &root);
 void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root);
 void HandleAdaFactorOpt(const FuncGraphPtr &root);
 bool ParameterIsCloned(const AnfNodePtr &parameter_node);
-bool IsFullySplitParameter(const ParameterPtr &param_ptr);
+bool IsFullySplitParameter(const ParameterPtr &param_ptr, size_t allow_repeat_num = 1);
 }  // namespace parallel
 }  // namespace mindspore

@@ -3123,48 +3123,13 @@ bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
           current_stage == split_stage_num - 1);
 }

-RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
-  auto parameters = root->parameters();
-  for (auto &parameter : parameters) {
-    auto param_ptr = parameter->cast<ParameterPtr>();
-    MS_EXCEPTION_IF_NULL(param_ptr);
-    if (IsFullySplitParameter(param_ptr)) {
-      MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
-                      << " is fully shard, thus cannot find common data parallel group for this rank";
-      return {};
-    }
-  }
-  AnfNodePtr ret = root->get_return();
-  MS_EXCEPTION_IF_NULL(ret);
-  std::vector<int64_t> common_group_list;
-  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
-  bool is_first_group = true;
-  for (auto &node : all_nodes) {
-    if (!IsPrimitiveCNode(node, prim::kPrimMirror)) {
-      continue;
-    }
-    auto prim = GetCNodePrimitive(node);
-    if (!prim->HasAttr(GROUP)) {
-      MS_LOG(EXCEPTION) << "The mirror operator dose not have group attr : " << node->DebugString();
-    }
-    std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
-    std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
-    if (is_first_group) {
-      common_group_list = group_list;
-      is_first_group = false;
-    } else {
-      std::vector<int64_t> new_comm_group_list;
-      std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(), group_list.end(),
-                            std::back_inserter(new_comm_group_list));
-      common_group_list = new_comm_group_list;
-    }
-  }
-  MS_LOG(INFO) << "The common mirror group is:" << common_group_list;
-  return common_group_list;
-}
-
 static void HandleGroupInfo(const FuncGraphPtr &root) {
   auto group_info = g_device_manager->group_info();
+  auto group_info_save_path = common::GetEnv("GROUP_INFO_FILE");
+  if (!group_info_save_path.empty()) {
+    ParallelContext::GetInstance()->set_group_ckpt_save_file(group_info_save_path);
+  }
+
   if (StrategyCheckpoint::GetInstance().group_info_save_on()) {
     RankList comm_group = FindCommonMirrorGroup(root);
     if (StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info, comm_group) != SUCCESS) {

@@ -3173,6 +3138,25 @@ static void HandleGroupInfo(const FuncGraphPtr &root) {
   }
 }

+static void HandleDataParallel() {
+  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
+  if (parallel_mode == DATA_PARALLEL) {
+    auto group_info_save_path = common::GetEnv("GROUP_INFO_FILE");
+    if (!group_info_save_path.empty()) {
+      std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info;
+      int64_t device_num = GetCommInfo().device_num;
+      RankList comm_group;
+      for (size_t i = 0; i < size_t(device_num); ++i) {
+        comm_group.push_back(i);
+      }
+      ParallelContext::GetInstance()->set_group_ckpt_save_file(group_info_save_path);
+      if (StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info, comm_group) != SUCCESS) {
+        MS_LOG(EXCEPTION) << "Save group info failed";
+      }
+    }
+  }
+}
+
 static void PipelinePostProcess(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
   auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
   if (pipeline_stages > 1) {

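In the pure data-parallel branch above, every rank holds a full replica of the model, so the group written to GROUP_INFO_FILE is simply all devices. In Python terms (illustrative only, with an example device count matching the tests below):

    device_num = 32                       # example value
    comm_group = list(range(device_num))  # [0, 1, ..., 31]
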
@@ -3191,6 +3175,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
   MS_EXCEPTION_IF_NULL(optimizer);
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
+  HandleDataParallel();
   pipeline::ResourceBasePtr res = optimizer->resource();
   MS_EXCEPTION_IF_NULL(res);
   FuncGraphManagerPtr manager = res->manager();

@@ -35,6 +35,7 @@
 #include "frontend/parallel/graph_util/graph_info.h"
 #include "frontend/parallel/graph_util/node_info.h"
 #include "frontend/parallel/node_check.h"
+#include "frontend/parallel/parameter_manager.h"
 #include "ir/param_info.h"
 #include "ir/tensor.h"
 #include "utils/trace_base.h"

@@ -149,6 +150,61 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
   return shapes;
 }

+RankList FindCommonMirrorGroup(const FuncGraphPtr &root) {
+  auto parameters = root->parameters();
+  for (auto &parameter : parameters) {
+    auto param_ptr = parameter->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param_ptr);
+    if (!(param_ptr->has_default() && ParameterRequireGrad(param_ptr))) {
+      continue;
+    }
+    size_t allow_repeat_num = 1;
+    if (ParallelContext::GetInstance()->enable_parallel_optimizer() &&
+        (!param_ptr->param_info() || !param_ptr->param_info()->parallel_optimizer())) {
+      if (ParallelContext::GetInstance()->optimizer_weight_shard_size() == -1) {
+        MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
+                        << " is fully shard by optimizer parallel,"
+                           " thus cannot find common data parallel group for this rank";
+        return {g_device_manager->global_rank()};
+      }
+      allow_repeat_num = size_t(ParallelContext::GetInstance()->optimizer_weight_shard_size());
+    }
+    if (IsFullySplitParameter(param_ptr, allow_repeat_num)) {
+      MS_LOG(WARNING) << "The parameter :" << param_ptr->fullname_with_scope()
+                      << " is fully shard, thus cannot find common data parallel group for this rank";
+      return {g_device_manager->global_rank()};
+    }
+  }
+  AnfNodePtr ret = root->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<int64_t> common_group_list;
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  bool is_first_group = true;
+  for (auto &node : all_nodes) {
+    if (!IsPrimitiveCNode(node, prim::kPrimMirror) && !IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) &&
+        !IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep)) {
+      continue;
+    }
+    auto prim = GetCNodePrimitive(node);
+    if (!prim->HasAttr(GROUP)) {
+      MS_LOG(EXCEPTION) << "The mirror operator dose not have group attr : " << node->DebugString();
+    }
+    std::string group_name = GetValue<std::string>(prim->GetAttr(GROUP));
+    std::vector<int64_t> group_list = g_device_manager->FindRankListByHashName(group_name);
+    if (is_first_group) {
+      common_group_list = group_list;
+      is_first_group = false;
+    } else {
+      std::vector<int64_t> new_comm_group_list;
+      std::set_intersection(common_group_list.begin(), common_group_list.end(), group_list.begin(), group_list.end(),
+                            std::back_inserter(new_comm_group_list));
+      common_group_list = new_comm_group_list;
+    }
+  }
+  MS_LOG(INFO) << "The common mirror group is:" << common_group_list;
+  return common_group_list;
+}
+
 std::string CreateInstanceName(const CNodePtr &node, size_t index) {
   MS_EXCEPTION_IF_NULL(node);
   if (!IsValueNode<Primitive>(node->input(0))) {

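The relocated FindCommonMirrorGroup keeps the same core idea: intersect the rank lists of all mirror operators, now including the mini-step and micro-step variants. A compact paraphrase in Python, with made-up example groups:

    def common_mirror_group(mirror_groups):
        # Intersect all mirror-group rank lists; ranks in the result hold the
        # same copy of every mirrored parameter.
        common = None
        for group in mirror_groups:
            common = list(group) if common is None else sorted(set(common) & set(group))
        return common or []

    # Example: one weight mirrored across every 4th rank, another across every 8th.
    print(common_mirror_group([[0, 4, 8, 12, 16, 20, 24, 28], [0, 8, 16, 24]]))
    # -> [0, 8, 16, 24]
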
@@ -28,6 +28,7 @@ namespace parallel {
 bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
 bool IsParallelCareNode(const CNodePtr &cnode);
 Shapes GetNodeShape(const AnfNodePtr &node);
+RankList FindCommonMirrorGroup(const FuncGraphPtr &root);
 std::string CreateInstanceName(const CNodePtr &node, size_t index);
 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
 std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,

@@ -1160,10 +1160,12 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
     return merged_tensor


-def ckpt_restore_group_info(group_info_file_name):
+def restore_group_info_list(group_info_file_name):
     """
     Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
-    that saves the group_info_file_name
+    that saves the group_info_file_name.
+    To save the group info file, please export GROUP_INFO_FILE environment variables like
+    "export GROUP_INFO_FILE=/data/group_info.pb".

     Args:
         group_info_file_name (str): Name of group information file.

@@ -1175,7 +1177,7 @@ def ckpt_restore_group_info(group_info_file_name):
         TypeError: group_info_file_name is not str.

     Examples:
-        >>> restore_list = ckpt_restore_group_info("./group_info.ckpt")
+        >>> restore_list = restore_group_info_list("./group_info.pb")
     """
     if not isinstance(group_info_file_name, str):
         raise TypeError(f"The group_info_file_name should be str, but got {type(group_info_file_name)}.")

@@ -0,0 +1,149 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test group info """
+import os
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor, Parameter
+from mindspore.common.api import _cell_graph_executor
+from mindspore.nn import TrainOneStepCell
+from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
+from mindspore.nn.optim import Momentum
+from mindspore.ops import operations as P
+from mindspore import context
+from mindspore.train.serialization import restore_group_info_list
+
+
+class Net3(nn.Cell):
+    """Net definition"""
+    def __init__(self, strategy1, strategy2, strategy3):
+        super(Net3, self).__init__()
+        self.fc1 = P.MatMul().shard(strategy1)
+        self.fc2 = P.MatMul().shard(strategy2)
+        self.fc3 = P.MatMul().shard(strategy3)
+        self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
+        self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2", parallel_optimizer=False)
+        self.p3 = Parameter(Tensor(np.ones([16, 16]).astype(np.float32)), name="weight3")
+
+    def construct(self, x, y):
+        x = self.fc1(x, self.p1)
+        x = self.fc2(x, self.p2)
+        z = x - y
+        z = self.fc3(z, self.p3)
+        return z
+
+
+def auto_parallel_compile_net(strategy1=None, strategy2=None, strategy3=None):
+    context.set_context(mode=context.GRAPH_MODE)
+    inputs = Tensor(np.ones([32, 48]).astype(np.float32))
+    label = Tensor(np.zeros([32, 16]).astype(np.float32))
+    net = Net3(strategy1, strategy2, strategy3)
+    auto_parallel = context.get_auto_parallel_context("parallel_mode") in ["semi_auto_parallel", "auto_parallel"]
+    if auto_parallel:
+        net = _VirtualDatasetCell(net)
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4)
+    train_network.set_auto_parallel()
+    train_network.set_train()
+    _cell_graph_executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=auto_parallel)
+
+
+def test_mirror_group():
+    """
+    Feature: save and load mirror group
+    Description: semi-auto, disable parallel optimizer.
+    Expectation: group info list match expectation value.
+    """
+    os.environ['GROUP_INFO_FILE'] = "./test_mirror_group.pb"
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
+                                      device_num=32, enable_parallel_optimizer=False)
+    auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
+    group_info_list = restore_group_info_list("./test_mirror_group.pb")
+    assert group_info_list == [0, 4, 8, 12, 16, 20, 24, 28]
+    context.reset_auto_parallel_context()
+    del os.environ['GROUP_INFO_FILE']
+
+def test_data_parallel_group():
+    """
+    Feature: save and load mirror group
+    Description: data-parallel, disable parallel optimizer.
+    Expectation: group info list match expectation value.
+    """
+    os.environ['GROUP_INFO_FILE'] = "./test_data_parallel_group.pb"
+    context.set_auto_parallel_context(parallel_mode="data_parallel",
+                                      device_num=32, enable_parallel_optimizer=False)
+    auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
+    group_info_list = restore_group_info_list("./test_data_parallel_group.pb")
+    assert group_info_list == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+                               24, 25, 26, 27, 28, 29, 30, 31]
+    context.reset_auto_parallel_context()
+    del os.environ['GROUP_INFO_FILE']
+
+def test_mirror_group_parallel_optimizer():
+    """
+    Feature: save and load mirror group
+    Description: semi-auto, enable parallel optimizer.
+    Expectation: group info list match expectation value.
+    """
+    os.environ['GROUP_INFO_FILE'] = "./test_mirror_group_parallel_optimizer.pb"
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
+                                      device_num=32, enable_parallel_optimizer=True)
+    auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
+    group_info_list = restore_group_info_list("./test_mirror_group_parallel_optimizer.pb")
+    assert group_info_list == [0]
+    context.reset_auto_parallel_context()
+    del os.environ['GROUP_INFO_FILE']
+
+def test_mirror_group_parallel_optimizer_not_full_shard():
+    """
+    Feature: save and load mirror group
+    Description: semi-auto, enable parallel optimizer but not fully shard.
+    Expectation: group info list match expectation value.
+    """
+    os.environ['GROUP_INFO_FILE'] = "./test_mirror_group_parallel_optimizer_not_full_shard.pb"
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
+                                      device_num=32, enable_parallel_optimizer=True, optimizer_weight_shard_size=2)
+    auto_parallel_compile_net(((8, 1), (1, 4)), ((32, 1), (1, 1)), ((8, 4), (4, 1)))
+    group_info_list = restore_group_info_list("./test_mirror_group_parallel_optimizer_not_full_shard.pb")
+    assert group_info_list == [0, 8, 16, 24]
+    context.reset_auto_parallel_context()
+    del os.environ['GROUP_INFO_FILE']
+
+def test_pipeline_split_stage0_mirror_group():
+    """
+    Feature: save and load mirror group
+    Description: semi-auto, pipeline parallel.
+    Expectation: group info list match expectation value.
+    """
+    import mindspore as ms
+    from mindspore import Model
+    from .test_pipeline_split import PipelineCell, PipelineSplit, DatasetLenet
+    os.environ['GROUP_INFO_FILE'] = "./test_pipeline_split_stage0_mirror_group.pb"
+    context.set_auto_parallel_context(device_num=64, global_rank=0, pipeline_stages=2)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    strategy1 = ((4, 1), (1, 8))
+    strategy2 = ((4, 1), (1, 1))
+    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
+    params = net.network.cell.block[0].trainable_params()
+    dataset = DatasetLenet(data, label, 3)
+    optimizer = nn.Lamb(params, learning_rate=0.01)
+    model = Model(net, optimizer=optimizer)
+    model.train(2, dataset, dataset_sink_mode=False)
+    group_info_list = restore_group_info_list("./test_pipeline_split_stage0_mirror_group.pb")
+    assert group_info_list == [0, 8, 16, 24]