check layouts for shared parameter

yangzhenzhang 2021-04-14 16:43:07 +08:00
parent d8d43a1368
commit bcd2ecc403
3 changed files with 77 additions and 18 deletions

View File

@@ -3232,7 +3232,24 @@ ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)
  return parameter_users_info;
}

Shape ParameterSliceShape(const std::pair<AnfNodePtr, int64_t> &param_info) {
RankList GetGroupByTensorInfo(const TensorInfo &tensor_info) {
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  RankList stage_device_list = g_device_manager->GetDeviceListInThisStage();
  Shape dev_matrix_shape = tensor_info.tensor_layout().device_arrangement().array();
  Shape tensor_map = tensor_info.tensor_layout().tensor_map().array();
  DeviceMatrix dev_matrix(rank, stage_device_list, dev_matrix_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Get devices by tensor map failed";
  }
  std::sort(group_devices.begin(), group_devices.end());
  return group_devices;
}

ParameterSliceInfo GetParameterSliceInfo(const std::pair<AnfNodePtr, int64_t> &param_info) {
  auto user_cnode = param_info.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(user_cnode);
  auto user_input_index = param_info.second;
@@ -3245,10 +3262,14 @@ Shape ParameterSliceShape(const std::pair<AnfNodePtr, int64_t> &param_info) {
                      << ", but the index is " << user_input_index - 1;
  }

  TensorInfo tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];
  ParameterSliceInfo parameter_slice_info;
  parameter_slice_info.slice_shape = tensor_info.slice_shape();
  parameter_slice_info.group_ranks = GetGroupByTensorInfo(tensor_info);
  MS_LOG(DEBUG) << "The op name is " << op_info->name() << ", the parameter index is " << user_input_index - 1
                << ", the slice shape is " << ShapeToString(tensor_info.slice_shape()) << ", the origin shape is "
                << ShapeToString(tensor_info.shape());
  return tensor_info.slice_shape();
                << ", the slice shape is " << tensor_info.slice_shape() << ", the origin shape is "
                << tensor_info.shape() << ", the group rank list is " << parameter_slice_info.group_ranks;
  return parameter_slice_info;
}
void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
@@ -3262,13 +3283,24 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
    auto parameter_name = parameter_users_info.first;
    MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
    auto first_user = users_set.pop();
    Shape first_user_slice_shape = ParameterSliceShape(first_user);
    ParameterSliceInfo parameter_slice_info = GetParameterSliceInfo(first_user);
    Shape first_user_slice_shape = parameter_slice_info.slice_shape;
    RankList first_user_group_list = parameter_slice_info.group_ranks;

    for (auto &user : users_set) {
      Shape user_slice_shape = ParameterSliceShape(user);
      ParameterSliceInfo user_slice_info = GetParameterSliceInfo(user);
      Shape user_slice_shape = user_slice_info.slice_shape;
      RankList user_group_list = user_slice_info.group_ranks;
      if (first_user_slice_shape != user_slice_shape) {
        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
                          << " has multiple users, but the split strategies are different";
                          << " has multiple users, but the slice shapes are different";
      }

      if (ParallelContext::GetInstance()->pipeline_stage_split_num() == 1 && first_user_group_list != user_group_list) {
        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
                          << " has multiple users, but the group rank lists are different, "
                          << "the group rank list for first user is " << first_user_group_list
                          << ", and the group rank list for this user is " << user_group_list;
      }
    }
  }
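The core of the new check is GetGroupByTensorInfo: for each user of a shared parameter it derives, from that operator's device arrangement and tensor map, the sorted list of ranks that hold the same slice of the parameter. The following is a minimal standalone sketch in Python (not the MindSpore implementation; the function group_by_tensor_map and its arguments are illustrative assumptions) of how such a group can be computed, analogous to what DeviceMatrix::GetDevicesByTensorMap returns for the current rank:

import numpy as np

def group_by_tensor_map(rank, dev_matrix_shape, tensor_map):
    """Ranks that hold the same tensor slice as `rank` (illustrative sketch).

    dev_matrix_shape: device arrangement, e.g. [8, 2] for 16 devices.
    tensor_map: for each tensor dimension, the device-matrix axis it is split
    along, counted from the rightmost axis (index 0); -1 means not split.
    """
    ndim = len(dev_matrix_shape)
    # Device-matrix axes that shard the tensor, as left-to-right indices.
    sharded_axes = {ndim - 1 - m for m in tensor_map if m != -1}
    coords = np.unravel_index(rank, dev_matrix_shape)
    group = []
    for r in range(int(np.prod(dev_matrix_shape))):
        r_coords = np.unravel_index(r, dev_matrix_shape)
        # r holds the same slice iff it agrees with `rank` on every sharded axis.
        if all(r_coords[a] == coords[a] for a in sharded_axes):
            group.append(r)
    return group  # enumerated in ascending order, so already sorted

# 16 devices; both layouts split the first tensor dimension two ways, so the
# slice shapes match, but the ranks that share rank 0's slice differ:
print(group_by_tensor_map(0, [8, 2], [0, -1]))  # [0, 2, 4, 6, 8, 10, 12, 14]
print(group_by_tensor_map(0, [2, 8], [1, -1]))  # [0, 1, 2, 3, 4, 5, 6, 7]
# Two users that place a shared parameter in these two layouts would now be
# rejected by CheckParameterSplit even though the slice shapes are identical.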

View File

@@ -46,6 +46,11 @@ struct LossNodeInfo {
  CNodePtr loss_node = nullptr;
};

struct ParameterSliceInfo {
  Shape slice_shape;
  RankList group_ranks;
};
std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
std::string CreateInstanceName(const CNodePtr &node, size_t index);
void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node);

View File

@@ -47,9 +47,22 @@ class Net2(Cell):
        return out


_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)


class Net3(Cell):
    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.MatMul().shard(strategy1)
        self.mul2 = P.MatMul().shard(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        out = self.mul(x, self.mul_weight)
        out = self.mul2(out, self.mul_weight)
        return out


_x = Tensor(np.ones([16, 16]), dtype=ms.float32)
_w = Tensor(np.ones([16, 16]), dtype=ms.float32)
_b = Tensor(np.ones([16, 16]), dtype=ms.float32)
def compile_net(net):
@@ -63,16 +76,16 @@ def compile_net(net):
def test_parameter_same_split():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1), (16, 1, 1))
    strategy2 = ((16, 1, 1), (16, 1, 1))
    strategy1 = ((16, 1), (16, 1))
    strategy2 = ((16, 1), (16, 1))
    net = Net(_w, strategy1, strategy2)
    compile_net(net)


def test_parameter_different_split():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1), (16, 1, 1))
    strategy2 = ((4, 4, 1), (4, 4, 1))
    strategy1 = ((16, 1), (16, 1))
    strategy2 = ((4, 4), (4, 4))
    net = Net(_w, strategy1, strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)
@@ -80,16 +93,25 @@ def test_parameter_different_split():
def test_input_same_split():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1), (16, 1, 1))
    strategy2 = ((16, 1, 1), (16, 1, 1))
    strategy1 = ((16, 1), (16, 1))
    strategy2 = ((16, 1), (16, 1))
    net = Net(_w, strategy1, strategy2)
    compile_net(net)


def test_input_different_split():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1), (16, 1, 1))
    strategy2 = ((4, 4, 1), (4, 4, 1))
    strategy1 = ((16, 1), (16, 1))
    strategy2 = ((4, 4), (4, 4))
    net = Net2(_w, strategy1, strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_parameter_different_group():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 2), (2, 1))
    strategy2 = ((8, 2), (2, 1))
    net = Net3(_w, strategy1, strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)
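In test_parameter_different_group the two MatMul users shard the shared weight w1 under different device arrangements on 16 devices, so the group rank lists returned by GetParameterSliceInfo differ between the two users and compilation is expected to raise RuntimeError. For contrast, a hypothetical companion case (not part of this commit; the name test_parameter_same_group is only illustrative) would give both users the same strategy, so the slice shapes and group rank lists match and compile_net should succeed:

def test_parameter_same_group():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 2), (2, 1))
    strategy2 = ((1, 2), (2, 1))
    net = Net3(_w, strategy1, strategy2)
    compile_net(net)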