!31164 Fix the global norm missing insert allreduce

Merge pull request !31164 from huangxinjing/fx_global_norm_error
This commit is contained in:
i-robot 2022-03-16 06:46:13 +00:00 committed by Gitee
commit c2212f88b4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 52 additions and 59 deletions

View File

@ -62,8 +62,6 @@ namespace parallel {
static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE};
static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL};
static const std::vector<std::pair<const std::string, int64_t>> REDUCE_SUM_MATCH_PATTERN = {
std::make_pair(MAKE_TUPLE, 1), std::make_pair(ADDN, 1), std::make_pair(SQRT, 1)};
// g_RefMap, for CNode B input i is a RefKey[Parameter C],
// it will be one item in map with key: C, and value: (B, i)
std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
@ -3179,9 +3177,22 @@ static void InsertAllReduceForNormValue(const AnfNodePtr &res_node) {
return;
}
auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
auto expand_dims_node = node_user_map.at(res_node).front().first;
auto sqrt_node = MatchPattern(expand_dims_node, node_user_map, REDUCE_SUM_MATCH_PATTERN);
if (!sqrt_node) return;
auto find_node = res_node;
uint32_t limits = 0;
while (!IsSomePrimitive(find_node->cast<CNodePtr>(), SQRT) && limits < MAX_BFS_DEPTH) {
auto users = node_user_map.at(find_node);
if (users.empty()) return;
find_node = users.front().first;
++limits;
}
if (!find_node || !IsSomePrimitive(find_node->cast<CNodePtr>(), SQRT)) {
return;
}
auto anf_node = find_node->cast<CNodePtr>();
if (anf_node->inputs().size() > 1 && IsSomePrimitive(anf_node->input(1)->cast<CNodePtr>(), ALL_REDUCE)) {
return;
}
auto sqrt_node = find_node;
auto cur_stage_rank_list = g_device_manager->GetDeviceListInThisStage();
Group cur_stage_device_list;
if (g_device_manager->CreateGroup(cur_stage_rank_list, &cur_stage_device_list) != SUCCESS) {
@ -3245,27 +3256,22 @@ AnfNodePtr FindExpanDimsWIthGradScale(const AnfNodePtr &node_ptr, const NodeUser
static void InsertDivAndAllReduceForNorm(const NodeUsersMap &node_user_map, const AnfNodePtr &parameter,
uint32_t dev_num) {
AnfNodePtr expand_dims_node = nullptr;
AnfNodePtr prefix_node = nullptr;
auto params_user_set = node_user_map.at(parameter);
for (auto &param_pair : params_user_set) {
expand_dims_node = nullptr;
auto cnode = param_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->in_forward_flag()) {
continue;
}
expand_dims_node = FindExpanDimsWIthGradScale(cnode, node_user_map, MAX_BFS_DEPTH);
if (!expand_dims_node) {
continue;
}
auto expand_dims_node = FindExpanDimsWIthGradScale(cnode, node_user_map, MAX_BFS_DEPTH);
if (!expand_dims_node) continue;
auto value = GetAttrsFromAnfNode(expand_dims_node, GRAD_SCALE);
if (!value || !GetValue<bool>(value)) {
continue;
}
if (!value || !GetValue<bool>(value)) continue;
if (dev_num > 0) {
InsertRealDivOpToNodeInput(expand_dims_node->cast<CNodePtr>(), dev_num, PARALLEL_GLOBALNORM_DIV);
MS_LOG(INFO) << "Insert the realdiv with " << dev_num << " for the parameter " << parameter->DebugString()
MS_LOG(INFO) << "Insert the realdiv with " << dev_num << " for the parameter " << parameter->fullname_with_scope()
<< "succeed!";
}
// If already inserted allreduce, the pattern will not be matched and thus no allreduce will be inserted.
InsertAllReduceForNormValue(expand_dims_node);
}
@ -3302,22 +3308,15 @@ static void HandlGlobalNormScale(const FuncGraphPtr &root, const std::vector<Anf
auto parameters = root->parameters();
auto node_user_map = manager->node_users();
MS_LOG(INFO) << "Start to process the global norm";
for (auto &parameter : parameters) {
int64_t dev_num = 0;
if (!ParameterRequireGrad(parameter)) continue;
auto mirror_node = GetMirrorOp(node_user_map, parameter);
if (!mirror_node) continue;
auto device_num_ptr = GetAttrsFromAnfNode(mirror_node, DEV_NUM);
if (!device_num_ptr) {
MS_LOG(ERROR) << "The mirror operator is excepted to have device number attribute, but found none. This "
"will cause the global norm calculation with wrong precision.";
continue;
if (device_num_ptr && device_num_ptr->isa<Int64Imm>()) {
dev_num = GetValue<int64_t>(device_num_ptr);
}
if (!device_num_ptr->isa<Int64Imm>()) {
MS_LOG(ERROR) << "The type of device number attribute of mirror operator is not int64.";
continue;
}
auto dev_num = device_num_ptr->cast<Int64ImmPtr>()->value();
if (dev_num == 0) continue;
InsertDivAndAllReduceForNorm(node_user_map, parameter, dev_num);
}
}

View File

@ -521,34 +521,5 @@ std::shared_ptr<Value> GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> &node,
}
return nullptr;
}
AnfNodePtr MatchPattern(const AnfNodePtr &node, const NodeUsersMap &user_map,
const std::vector<std::pair<const std::string, int64_t>> &match_pattern) {
AnfNodePtr start_node = node;
bool find = false;
for (uint32_t i = 0; i < match_pattern.size(); ++i) {
find = false;
if (!IsSomePrimitive(start_node->cast<CNodePtr>(), {match_pattern[i].first})) {
break;
} else if (i == match_pattern.size() - 1) {
find = true;
break;
}
auto next_node_users = user_map.at(start_node);
for (auto &next_node : next_node_users) {
if (i + 1 < match_pattern.size() &&
IsSomePrimitive(next_node.first->cast<CNodePtr>(), {match_pattern[i + 1].first}) &&
next_node.second == match_pattern[i + 1].second) {
start_node = next_node.first;
break;
}
}
}
if (!find) {
start_node = nullptr;
}
return start_node;
}
} // namespace parallel
} // namespace mindspore

View File

@ -42,8 +42,6 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
std::string CreateInstanceName(const CNodePtr &node, size_t index);
TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info);
AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager);
AnfNodePtr MatchPattern(const AnfNodePtr &node, const NodeUsersMap &user_map,
const std::vector<std::pair<const std::string, int64_t>> &match_pattern);
// for specific scenarios
RankList FindCommonMirrorGroup(const FuncGraphPtr &root);

View File

@ -30,6 +30,21 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import context
class OneParameterNet(nn.Cell):
"""Net definition"""
def __init__(self, param_type, strategy1, strategy2):
super(OneParameterNet, self).__init__()
self.fc1 = P.MatMul().shard(strategy1)
self.p1 = Parameter(Tensor(np.ones([48, 16]).astype(param_type)), name="weight1")
self.sub = P.Sub().shard(strategy2)
def construct(self, x, y):
x = P.Cast()(x, ms.float16)
p1 = P.Cast()(self.p1, ms.float16)
x = self.fc1(x, p1)
return self.sub(x, 0)
class Net(nn.Cell):
"""Net definition"""
def __init__(self, param_type, strategy1, strategy2):
@ -137,6 +152,16 @@ class TestGlobalNormInserted:
appear_count += 1
assert appear_count == target_count
def test_nonpipeline_global_norm_one_parameter(self):
"""
Feature: Parallel ClipByGlobalNorm
Description: Test the global norm using one parameter, there should be only one allreduce
Expectation:When there is no PARALLEL_GLOBALNORM_IN_STAGES inserted
"""
auto_parallel_compile_net("semi_auto_parallel", 8, OneParameterNet, ((1, 8), (8, 1)), ((8, 1), ()),
interleaved_batch=1, param_type=np.float32)
self.run_count_check(target_count=1, pattern=r"PARALLEL_GLOBALNORM_IN_STAGES")
def test_nonpipeline_global_norm(self):
"""
Feature: Parallel ClipByGlobalNorm