forked from mindspore-Ecosystem/mindspore
!31164 Fix the global norm missing insert allreduce
Merge pull request !31164 from huangxinjing/fx_global_norm_error
This commit is contained in:
commit
c2212f88b4
|
@ -62,8 +62,6 @@ namespace parallel {
|
|||
static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
|
||||
static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE};
|
||||
static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL};
|
||||
static const std::vector<std::pair<const std::string, int64_t>> REDUCE_SUM_MATCH_PATTERN = {
|
||||
std::make_pair(MAKE_TUPLE, 1), std::make_pair(ADDN, 1), std::make_pair(SQRT, 1)};
|
||||
// g_RefMap, for CNode B input i is a RefKey[Parameter C],
|
||||
// it will be one item in map with key: C, and value: (B, i)
|
||||
std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
|
||||
|
@ -3179,9 +3177,22 @@ static void InsertAllReduceForNormValue(const AnfNodePtr &res_node) {
|
|||
return;
|
||||
}
|
||||
auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
|
||||
auto expand_dims_node = node_user_map.at(res_node).front().first;
|
||||
auto sqrt_node = MatchPattern(expand_dims_node, node_user_map, REDUCE_SUM_MATCH_PATTERN);
|
||||
if (!sqrt_node) return;
|
||||
auto find_node = res_node;
|
||||
uint32_t limits = 0;
|
||||
while (!IsSomePrimitive(find_node->cast<CNodePtr>(), SQRT) && limits < MAX_BFS_DEPTH) {
|
||||
auto users = node_user_map.at(find_node);
|
||||
if (users.empty()) return;
|
||||
find_node = users.front().first;
|
||||
++limits;
|
||||
}
|
||||
if (!find_node || !IsSomePrimitive(find_node->cast<CNodePtr>(), SQRT)) {
|
||||
return;
|
||||
}
|
||||
auto anf_node = find_node->cast<CNodePtr>();
|
||||
if (anf_node->inputs().size() > 1 && IsSomePrimitive(anf_node->input(1)->cast<CNodePtr>(), ALL_REDUCE)) {
|
||||
return;
|
||||
}
|
||||
auto sqrt_node = find_node;
|
||||
auto cur_stage_rank_list = g_device_manager->GetDeviceListInThisStage();
|
||||
Group cur_stage_device_list;
|
||||
if (g_device_manager->CreateGroup(cur_stage_rank_list, &cur_stage_device_list) != SUCCESS) {
|
||||
|
@ -3245,27 +3256,22 @@ AnfNodePtr FindExpanDimsWIthGradScale(const AnfNodePtr &node_ptr, const NodeUser
|
|||
|
||||
static void InsertDivAndAllReduceForNorm(const NodeUsersMap &node_user_map, const AnfNodePtr ¶meter,
|
||||
uint32_t dev_num) {
|
||||
AnfNodePtr expand_dims_node = nullptr;
|
||||
AnfNodePtr prefix_node = nullptr;
|
||||
auto params_user_set = node_user_map.at(parameter);
|
||||
for (auto ¶m_pair : params_user_set) {
|
||||
expand_dims_node = nullptr;
|
||||
auto cnode = param_pair.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->in_forward_flag()) {
|
||||
continue;
|
||||
}
|
||||
expand_dims_node = FindExpanDimsWIthGradScale(cnode, node_user_map, MAX_BFS_DEPTH);
|
||||
if (!expand_dims_node) {
|
||||
continue;
|
||||
}
|
||||
auto expand_dims_node = FindExpanDimsWIthGradScale(cnode, node_user_map, MAX_BFS_DEPTH);
|
||||
if (!expand_dims_node) continue;
|
||||
auto value = GetAttrsFromAnfNode(expand_dims_node, GRAD_SCALE);
|
||||
if (!value || !GetValue<bool>(value)) {
|
||||
continue;
|
||||
}
|
||||
if (!value || !GetValue<bool>(value)) continue;
|
||||
if (dev_num > 0) {
|
||||
InsertRealDivOpToNodeInput(expand_dims_node->cast<CNodePtr>(), dev_num, PARALLEL_GLOBALNORM_DIV);
|
||||
MS_LOG(INFO) << "Insert the realdiv with " << dev_num << " for the parameter " << parameter->DebugString()
|
||||
MS_LOG(INFO) << "Insert the realdiv with " << dev_num << " for the parameter " << parameter->fullname_with_scope()
|
||||
<< "succeed!";
|
||||
}
|
||||
// If already inserted allreduce, the pattern will not be matched and thus no allreduce will be inserted.
|
||||
InsertAllReduceForNormValue(expand_dims_node);
|
||||
}
|
||||
|
@ -3302,22 +3308,15 @@ static void HandlGlobalNormScale(const FuncGraphPtr &root, const std::vector<Anf
|
|||
auto parameters = root->parameters();
|
||||
auto node_user_map = manager->node_users();
|
||||
MS_LOG(INFO) << "Start to process the global norm";
|
||||
|
||||
for (auto ¶meter : parameters) {
|
||||
int64_t dev_num = 0;
|
||||
if (!ParameterRequireGrad(parameter)) continue;
|
||||
auto mirror_node = GetMirrorOp(node_user_map, parameter);
|
||||
if (!mirror_node) continue;
|
||||
auto device_num_ptr = GetAttrsFromAnfNode(mirror_node, DEV_NUM);
|
||||
if (!device_num_ptr) {
|
||||
MS_LOG(ERROR) << "The mirror operator is excepted to have device number attribute, but found none. This "
|
||||
"will cause the global norm calculation with wrong precision.";
|
||||
continue;
|
||||
if (device_num_ptr && device_num_ptr->isa<Int64Imm>()) {
|
||||
dev_num = GetValue<int64_t>(device_num_ptr);
|
||||
}
|
||||
if (!device_num_ptr->isa<Int64Imm>()) {
|
||||
MS_LOG(ERROR) << "The type of device number attribute of mirror operator is not int64.";
|
||||
continue;
|
||||
}
|
||||
auto dev_num = device_num_ptr->cast<Int64ImmPtr>()->value();
|
||||
if (dev_num == 0) continue;
|
||||
InsertDivAndAllReduceForNorm(node_user_map, parameter, dev_num);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -521,34 +521,5 @@ std::shared_ptr<Value> GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> &node,
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr MatchPattern(const AnfNodePtr &node, const NodeUsersMap &user_map,
|
||||
const std::vector<std::pair<const std::string, int64_t>> &match_pattern) {
|
||||
AnfNodePtr start_node = node;
|
||||
bool find = false;
|
||||
for (uint32_t i = 0; i < match_pattern.size(); ++i) {
|
||||
find = false;
|
||||
if (!IsSomePrimitive(start_node->cast<CNodePtr>(), {match_pattern[i].first})) {
|
||||
break;
|
||||
} else if (i == match_pattern.size() - 1) {
|
||||
find = true;
|
||||
break;
|
||||
}
|
||||
|
||||
auto next_node_users = user_map.at(start_node);
|
||||
for (auto &next_node : next_node_users) {
|
||||
if (i + 1 < match_pattern.size() &&
|
||||
IsSomePrimitive(next_node.first->cast<CNodePtr>(), {match_pattern[i + 1].first}) &&
|
||||
next_node.second == match_pattern[i + 1].second) {
|
||||
start_node = next_node.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!find) {
|
||||
start_node = nullptr;
|
||||
}
|
||||
return start_node;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,8 +42,6 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
|
|||
std::string CreateInstanceName(const CNodePtr &node, size_t index);
|
||||
TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> ¶m_info);
|
||||
AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager);
|
||||
AnfNodePtr MatchPattern(const AnfNodePtr &node, const NodeUsersMap &user_map,
|
||||
const std::vector<std::pair<const std::string, int64_t>> &match_pattern);
|
||||
|
||||
// for specific scenarios
|
||||
RankList FindCommonMirrorGroup(const FuncGraphPtr &root);
|
||||
|
|
|
@ -30,6 +30,21 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import composite as C
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class OneParameterNet(nn.Cell):
|
||||
"""Net definition"""
|
||||
def __init__(self, param_type, strategy1, strategy2):
|
||||
super(OneParameterNet, self).__init__()
|
||||
self.fc1 = P.MatMul().shard(strategy1)
|
||||
self.p1 = Parameter(Tensor(np.ones([48, 16]).astype(param_type)), name="weight1")
|
||||
self.sub = P.Sub().shard(strategy2)
|
||||
|
||||
def construct(self, x, y):
|
||||
x = P.Cast()(x, ms.float16)
|
||||
p1 = P.Cast()(self.p1, ms.float16)
|
||||
x = self.fc1(x, p1)
|
||||
return self.sub(x, 0)
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""Net definition"""
|
||||
def __init__(self, param_type, strategy1, strategy2):
|
||||
|
@ -137,6 +152,16 @@ class TestGlobalNormInserted:
|
|||
appear_count += 1
|
||||
assert appear_count == target_count
|
||||
|
||||
def test_nonpipeline_global_norm_one_parameter(self):
|
||||
"""
|
||||
Feature: Parallel ClipByGlobalNorm
|
||||
Description: Test the global norm using one parameter, there should be only one allreduce
|
||||
Expectation:When there is no PARALLEL_GLOBALNORM_IN_STAGES inserted
|
||||
"""
|
||||
auto_parallel_compile_net("semi_auto_parallel", 8, OneParameterNet, ((1, 8), (8, 1)), ((8, 1), ()),
|
||||
interleaved_batch=1, param_type=np.float32)
|
||||
self.run_count_check(target_count=1, pattern=r"PARALLEL_GLOBALNORM_IN_STAGES")
|
||||
|
||||
def test_nonpipeline_global_norm(self):
|
||||
"""
|
||||
Feature: Parallel ClipByGlobalNorm
|
||||
|
|
Loading…
Reference in New Issue