!31164 Fix the global norm missing insert allreduce
Merge pull request !31164 from huangxinjing/fx_global_norm_error
commit c2212f88b4

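Summary of the change: InsertAllReduceForNormValue used to locate the Sqrt of the global norm by matching the fixed consumer chain MAKE_TUPLE -> ADDN -> SQRT (REDUCE_SUM_MATCH_PATTERN) and silently returned when the graph did not fit, so no AllReduce was inserted. It now follows the users of the norm-value node for up to MAX_BFS_DEPTH hops until it reaches a Sqrt, and returns early only when that Sqrt already takes an AllReduce as input. On the caller side, HandlGlobalNormScale no longer skips a parameter whose mirror operator or DEV_NUM attribute is missing: dev_num defaults to 0, the RealDiv is inserted only for dev_num > 0, and the AllReduce insertion is always attempted. The now-unused MatchPattern helper is removed, and a single-parameter test case is added.
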
@@ -62,8 +62,6 @@ namespace parallel {
 static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
 static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE};
 static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL};
-static const std::vector<std::pair<const std::string, int64_t>> REDUCE_SUM_MATCH_PATTERN = {
-  std::make_pair(MAKE_TUPLE, 1), std::make_pair(ADDN, 1), std::make_pair(SQRT, 1)};
 // g_RefMap, for CNode B input i is a RefKey[Parameter C],
 // it will be one item in map with key: C, and value: (B, i)
 std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;

@@ -3179,9 +3177,22 @@ static void InsertAllReduceForNormValue(const AnfNodePtr &res_node) {
     return;
   }
   auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
-  auto expand_dims_node = node_user_map.at(res_node).front().first;
-  auto sqrt_node = MatchPattern(expand_dims_node, node_user_map, REDUCE_SUM_MATCH_PATTERN);
-  if (!sqrt_node) return;
+  auto find_node = res_node;
+  uint32_t limits = 0;
+  while (!IsSomePrimitive(find_node->cast<CNodePtr>(), SQRT) && limits < MAX_BFS_DEPTH) {
+    auto users = node_user_map.at(find_node);
+    if (users.empty()) return;
+    find_node = users.front().first;
+    ++limits;
+  }
+  if (!find_node || !IsSomePrimitive(find_node->cast<CNodePtr>(), SQRT)) {
+    return;
+  }
+  auto anf_node = find_node->cast<CNodePtr>();
+  if (anf_node->inputs().size() > 1 && IsSomePrimitive(anf_node->input(1)->cast<CNodePtr>(), ALL_REDUCE)) {
+    return;
+  }
+  auto sqrt_node = find_node;
   auto cur_stage_rank_list = g_device_manager->GetDeviceListInThisStage();
   Group cur_stage_device_list;
   if (g_device_manager->CreateGroup(cur_stage_rank_list, &cur_stage_device_list) != SUCCESS) {

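The new search drops the exact MAKE_TUPLE -> ADDN -> SQRT pattern in favour of following the first user of each node until a Sqrt is reached or MAX_BFS_DEPTH hops are spent, then bails out if that Sqrt already reads from an AllReduce. Below is a minimal standalone sketch of this bounded walk; the Node type and FindUserByOp are simplified stand-ins, not the MindSpore AnfNode/NodeUsersMap API.

    // build: g++ -std=c++17 sketch.cc
    #include <cstdint>
    #include <string>
    #include <vector>

    // Simplified stand-ins for AnfNode and the node-user map.
    struct Node {
      std::string op;             // primitive name, e.g. "Sqrt"
      std::vector<Node *> users;  // nodes that consume this node's output
    };

    // Follow the first user of each node until `op` is found or the hop budget
    // runs out; returns nullptr if the chain dead-ends or never reaches `op`.
    Node *FindUserByOp(Node *start, const std::string &op, uint32_t max_hops) {
      Node *cur = start;
      for (uint32_t hops = 0; cur != nullptr && cur->op != op && hops < max_hops; ++hops) {
        if (cur->users.empty()) return nullptr;
        cur = cur->users.front();
      }
      return (cur != nullptr && cur->op == op) ? cur : nullptr;
    }

    int main() {
      // ExpandDims -> MakeTuple -> AddN -> Sqrt, the shape of a global-norm graph.
      Node sqrt{"Sqrt", {}};
      Node addn{"AddN", {&sqrt}};
      Node make_tuple{"MakeTuple", {&addn}};
      Node expand_dims{"ExpandDims", {&make_tuple}};
      return FindUserByOp(&expand_dims, "Sqrt", 10) == &sqrt ? 0 : 1;
    }

Unlike the removed MatchPattern, this walk does not require each hop to land at a fixed input index, so it still finds the Sqrt when the norm graph deviates from the canonical three-op chain.
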
@@ -3245,27 +3256,22 @@ AnfNodePtr FindExpanDimsWIthGradScale(const AnfNodePtr &node_ptr, const NodeUser
 
 static void InsertDivAndAllReduceForNorm(const NodeUsersMap &node_user_map, const AnfNodePtr &parameter,
                                          uint32_t dev_num) {
-  AnfNodePtr expand_dims_node = nullptr;
-  AnfNodePtr prefix_node = nullptr;
   auto params_user_set = node_user_map.at(parameter);
   for (auto &param_pair : params_user_set) {
-    expand_dims_node = nullptr;
     auto cnode = param_pair.first->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if (cnode->in_forward_flag()) {
      continue;
     }
-    expand_dims_node = FindExpanDimsWIthGradScale(cnode, node_user_map, MAX_BFS_DEPTH);
-    if (!expand_dims_node) {
-      continue;
-    }
+    auto expand_dims_node = FindExpanDimsWIthGradScale(cnode, node_user_map, MAX_BFS_DEPTH);
+    if (!expand_dims_node) continue;
     auto value = GetAttrsFromAnfNode(expand_dims_node, GRAD_SCALE);
-    if (!value || !GetValue<bool>(value)) {
-      continue;
-    }
-    InsertRealDivOpToNodeInput(expand_dims_node->cast<CNodePtr>(), dev_num, PARALLEL_GLOBALNORM_DIV);
-    MS_LOG(INFO) << "Insert the realdiv with " << dev_num << " for the parameter " << parameter->DebugString()
-                 << "succeed!";
+    if (!value || !GetValue<bool>(value)) continue;
+    if (dev_num > 0) {
+      InsertRealDivOpToNodeInput(expand_dims_node->cast<CNodePtr>(), dev_num, PARALLEL_GLOBALNORM_DIV);
+      MS_LOG(INFO) << "Insert the realdiv with " << dev_num << " for the parameter " << parameter->fullname_with_scope()
+                   << "succeed!";
+    }
     // If already inserted allreduce, the pattern will not be matched and thus no allreduce will be inserted.
     InsertAllReduceForNormValue(expand_dims_node);
   }

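Two details of this hunk matter for the fix: the RealDiv insertion is now gated on dev_num > 0 while InsertAllReduceForNormValue still runs for every matched ExpandDims node, so an unknown device number no longer suppresses the AllReduce; and the log line switches from parameter->DebugString() to parameter->fullname_with_scope(), which prints a stable scoped parameter name. The caller-side half of this change is in the next hunk.
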
@@ -3302,22 +3308,15 @@ static void HandlGlobalNormScale(const FuncGraphPtr &root, const std::vector<Anf
   auto parameters = root->parameters();
   auto node_user_map = manager->node_users();
   MS_LOG(INFO) << "Start to process the global norm";
 
   for (auto &parameter : parameters) {
+    int64_t dev_num = 0;
     if (!ParameterRequireGrad(parameter)) continue;
     auto mirror_node = GetMirrorOp(node_user_map, parameter);
-    if (!mirror_node) continue;
     auto device_num_ptr = GetAttrsFromAnfNode(mirror_node, DEV_NUM);
-    if (!device_num_ptr) {
-      MS_LOG(ERROR) << "The mirror operator is excepted to have device number attribute, but found none. This "
-                       "will cause the global norm calculation with wrong precision.";
-      continue;
-    }
-    if (!device_num_ptr->isa<Int64Imm>()) {
-      MS_LOG(ERROR) << "The type of device number attribute of mirror operator is not int64.";
-      continue;
-    }
-    auto dev_num = device_num_ptr->cast<Int64ImmPtr>()->value();
-    if (dev_num == 0) continue;
+    if (device_num_ptr && device_num_ptr->isa<Int64Imm>()) {
+      dev_num = GetValue<int64_t>(device_num_ptr);
+    }
     InsertDivAndAllReduceForNorm(node_user_map, parameter, dev_num);
   }
 }

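The essence of the fix is visible here: the old code treated a missing mirror operator, a missing DEV_NUM attribute, or a non-Int64 attribute as reasons to skip the parameter entirely, which also skipped the AllReduce insertion; the new code defaults dev_num to 0 and always calls InsertDivAndAllReduceForNorm, where only the RealDiv is gated on dev_num > 0. A hedged sketch of that control-flow change, with std::optional standing in for the ValuePtr attribute lookup (names here are illustrative, not the MindSpore API):

    // build: g++ -std=c++17 sketch.cc
    #include <cstdint>
    #include <iostream>
    #include <optional>

    // Stand-in for InsertDivAndAllReduceForNorm: the AllReduce part must run
    // even when dev_num is unknown (0); only the RealDiv needs dev_num > 0.
    void InsertForNorm(int64_t dev_num) {
      if (dev_num > 0) std::cout << "insert RealDiv by " << dev_num << "\n";
      std::cout << "insert AllReduce\n";
    }

    void OldStyle(std::optional<int64_t> attr) {
      if (!attr) return;  // the bug: a missing attribute skips the AllReduce too
      InsertForNorm(*attr);
    }

    void NewStyle(std::optional<int64_t> attr) {
      int64_t dev_num = 0;          // default when the mirror op has no DEV_NUM
      if (attr) dev_num = *attr;
      InsertForNorm(dev_num);       // the AllReduce is always attempted
    }

    int main() {
      OldStyle(std::nullopt);  // prints nothing
      NewStyle(std::nullopt);  // prints "insert AllReduce"
    }
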
@@ -521,34 +521,5 @@ std::shared_ptr<Value> GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> &node,
   }
   return nullptr;
 }
 
-AnfNodePtr MatchPattern(const AnfNodePtr &node, const NodeUsersMap &user_map,
-                        const std::vector<std::pair<const std::string, int64_t>> &match_pattern) {
-  AnfNodePtr start_node = node;
-  bool find = false;
-  for (uint32_t i = 0; i < match_pattern.size(); ++i) {
-    find = false;
-    if (!IsSomePrimitive(start_node->cast<CNodePtr>(), {match_pattern[i].first})) {
-      break;
-    } else if (i == match_pattern.size() - 1) {
-      find = true;
-      break;
-    }
-
-    auto next_node_users = user_map.at(start_node);
-    for (auto &next_node : next_node_users) {
-      if (i + 1 < match_pattern.size() &&
-          IsSomePrimitive(next_node.first->cast<CNodePtr>(), {match_pattern[i + 1].first}) &&
-          next_node.second == match_pattern[i + 1].second) {
-        start_node = next_node.first;
-        break;
-      }
-    }
-  }
-  if (!find) {
-    start_node = nullptr;
-  }
-  return start_node;
-}
-
 } // namespace parallel
 } // namespace mindspore

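With the bounded walk in place, MatchPattern and REDUCE_SUM_MATCH_PATTERN have no remaining callers, so the helper definition is deleted here and its declaration is dropped from the header in the next hunk.
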
@@ -42,8 +42,6 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
 std::string CreateInstanceName(const CNodePtr &node, size_t index);
 TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info);
 AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager);
-AnfNodePtr MatchPattern(const AnfNodePtr &node, const NodeUsersMap &user_map,
-                        const std::vector<std::pair<const std::string, int64_t>> &match_pattern);
 
 // for specific scenarios
 RankList FindCommonMirrorGroup(const FuncGraphPtr &root);

@@ -30,6 +30,21 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore import context
 
 
+class OneParameterNet(nn.Cell):
+    """Net definition"""
+    def __init__(self, param_type, strategy1, strategy2):
+        super(OneParameterNet, self).__init__()
+        self.fc1 = P.MatMul().shard(strategy1)
+        self.p1 = Parameter(Tensor(np.ones([48, 16]).astype(param_type)), name="weight1")
+        self.sub = P.Sub().shard(strategy2)
+
+    def construct(self, x, y):
+        x = P.Cast()(x, ms.float16)
+        p1 = P.Cast()(self.p1, ms.float16)
+        x = self.fc1(x, p1)
+        return self.sub(x, 0)
+
+
 class Net(nn.Cell):
     """Net definition"""
     def __init__(self, param_type, strategy1, strategy2):

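OneParameterNet holds a single trainable Parameter behind a MatMul sharded ((1, 8), (8, 1)); the test added below compiles it under semi_auto_parallel on 8 devices and asserts that exactly one PARALLEL_GLOBALNORM_IN_STAGES AllReduce appears, which presumably covers the single-parameter shape of the norm graph that the old exact pattern failed to match.
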
@@ -137,6 +152,16 @@ class TestGlobalNormInserted:
                 appear_count += 1
         assert appear_count == target_count
 
+    def test_nonpipeline_global_norm_one_parameter(self):
+        """
+        Feature: Parallel ClipByGlobalNorm
+        Description: Test the global norm using one parameter, there should be only one allreduce
+        Expectation: When there is no PARALLEL_GLOBALNORM_IN_STAGES inserted
+        """
+        auto_parallel_compile_net("semi_auto_parallel", 8, OneParameterNet, ((1, 8), (8, 1)), ((8, 1), ()),
+                                  interleaved_batch=1, param_type=np.float32)
+        self.run_count_check(target_count=1, pattern=r"PARALLEL_GLOBALNORM_IN_STAGES")
+
     def test_nonpipeline_global_norm(self):
         """
         Feature: Parallel ClipByGlobalNorm