check layouts for shared parameter
parent d8d43a1368
commit bcd2ecc403
@@ -3232,7 +3232,24 @@ ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)
   return parameter_users_info;
 }
 
-Shape ParameterSliceShape(const std::pair<AnfNodePtr, int64_t> &param_info) {
+RankList GetGroupByTensorInfo(const TensorInfo &tensor_info) {
+  CheckGlobalDeviceManager();
+  int64_t rank = g_device_manager->global_rank();
+  RankList stage_device_list = g_device_manager->GetDeviceListInThisStage();
+  Shape dev_matrix_shape = tensor_info.tensor_layout().device_arrangement().array();
+  Shape tensor_map = tensor_info.tensor_layout().tensor_map().array();
+
+  DeviceMatrix dev_matrix(rank, stage_device_list, dev_matrix_shape);
+  RankList group_devices;
+  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Get devices by tensor map failed";
+  }
+
+  std::sort(group_devices.begin(), group_devices.end());
+  return group_devices;
+}
+
+ParameterSliceInfo GetParameterSliceInfo(const std::pair<AnfNodePtr, int64_t> &param_info) {
   auto user_cnode = param_info.first->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(user_cnode);
   auto user_input_index = param_info.second;
@@ -3245,10 +3262,14 @@ Shape ParameterSliceShape(const std::pair<AnfNodePtr, int64_t> &param_info) {
                       << ", but the index is " << user_input_index - 1;
   }
   TensorInfo tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];
+
+  ParameterSliceInfo parameter_slice_info;
+  parameter_slice_info.slice_shape = tensor_info.slice_shape();
+  parameter_slice_info.group_ranks = GetGroupByTensorInfo(tensor_info);
   MS_LOG(DEBUG) << "The op name is " << op_info->name() << ", the parameter index is " << user_input_index - 1
-                << ", the slice shape is " << ShapeToString(tensor_info.slice_shape()) << ", the origin shape is "
-                << ShapeToString(tensor_info.shape());
-  return tensor_info.slice_shape();
+                << ", the slice shape is " << tensor_info.slice_shape() << ", the origin shape is "
+                << tensor_info.shape() << ", the group rank list is " << parameter_slice_info.group_ranks;
+  return parameter_slice_info;
 }
 
 void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
@@ -3262,13 +3283,24 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
     auto parameter_name = parameter_users_info.first;
     MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
     auto first_user = users_set.pop();
-    Shape first_user_slice_shape = ParameterSliceShape(first_user);
+    ParameterSliceInfo parameter_slice_info = GetParameterSliceInfo(first_user);
+    Shape first_user_slice_shape = parameter_slice_info.slice_shape;
+    RankList first_user_group_list = parameter_slice_info.group_ranks;
 
     for (auto &user : users_set) {
-      Shape user_slice_shape = ParameterSliceShape(user);
+      ParameterSliceInfo user_slice_info = GetParameterSliceInfo(user);
+      Shape user_slice_shape = user_slice_info.slice_shape;
+      RankList user_group_list = user_slice_info.group_ranks;
       if (first_user_slice_shape != user_slice_shape) {
         MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
-                          << " has multiple users, but the split strategies are different";
+                          << " has multiple users, but the slice shapes are different";
       }
+
+      if (ParallelContext::GetInstance()->pipeline_stage_split_num() == 1 && first_user_group_list != user_group_list) {
+        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
+                          << " has multiple users, but the group rank list are different, "
+                          << "the group rank list for first user is " << first_user_group_list
+                          << ", and the group rank list for this user is " << user_group_list;
+      }
     }
   }
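
Note: the new GetGroupByTensorInfo above turns a parameter's tensor layout (device arrangement plus tensor map) into the sorted list of ranks that hold an identical slice of that parameter, and CheckParameterSplit now requires every user of a shared parameter to agree on both the slice shape and, when pipeline parallelism is off, that rank group. Below is a minimal Python sketch of the group computation, not MindSpore's implementation; the rank layout (row-major over the device matrix) and the tensor-map convention (entries are device-matrix dimension indices, -1 meaning "not split") are assumptions made for illustration.

from itertools import product

# group_by_tensor_map is a hypothetical name; it mirrors what DeviceMatrix /
# GetDevicesByTensorMap provide in the diff above, under the assumptions stated
# in the note.
def group_by_tensor_map(rank, dev_matrix, tensor_map):
    # Coordinates of `rank` inside the device matrix (row-major order assumed).
    coords = []
    for d in reversed(dev_matrix):
        coords.append(rank % d)
        rank //= d
    coords.reverse()

    # Device-matrix dimensions the parameter is actually split over.
    split_dims = {m for m in tensor_map if m >= 0}

    # Every rank that agrees with `coords` on all split dimensions holds the
    # same slice, whatever its coordinates on the remaining dimensions.
    group = []
    for cand in product(*[range(d) for d in dev_matrix]):
        if all(cand[i] == coords[i] for i in split_dims):
            r = 0
            for c, d in zip(cand, dev_matrix):
                r = r * d + c
            group.append(r)
    return sorted(group)

# 16 devices arranged as an [8, 2] device matrix, parameter split only over the
# dimension of size 2: rank 0 shares its slice with the 8 ranks {0, 2, ..., 14}.
print(group_by_tensor_map(0, [8, 2], [1, -1]))
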
@@ -46,6 +46,11 @@ struct LossNodeInfo {
   CNodePtr loss_node = nullptr;
 };
 
+struct ParameterSliceInfo {
+  Shape slice_shape;
+  RankList group_ranks;
+};
+
 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
 std::string CreateInstanceName(const CNodePtr &node, size_t index);
 void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node);

@@ -47,9 +47,22 @@ class Net2(Cell):
         return out
 
 
-_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
-_w = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
-_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+class Net3(Cell):
+    def __init__(self, mul_weight, strategy1=None, strategy2=None):
+        super().__init__()
+        self.mul = P.MatMul().shard(strategy1)
+        self.mul2 = P.MatMul().shard(strategy2)
+        self.mul_weight = Parameter(mul_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.mul(x, self.mul_weight)
+        out = self.mul2(out, self.mul_weight)
+        return out
+
+
+_x = Tensor(np.ones([16, 16]), dtype=ms.float32)
+_w = Tensor(np.ones([16, 16]), dtype=ms.float32)
+_b = Tensor(np.ones([16, 16]), dtype=ms.float32)
 
 
 def compile_net(net):
@@ -63,16 +76,16 @@ def compile_net(net):
 
 def test_parameter_same_split():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
-    strategy1 = ((16, 1, 1), (16, 1, 1))
-    strategy2 = ((16, 1, 1), (16, 1, 1))
+    strategy1 = ((16, 1), (16, 1))
+    strategy2 = ((16, 1), (16, 1))
     net = Net(_w, strategy1, strategy2)
     compile_net(net)
 
 
 def test_parameter_different_split():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
-    strategy1 = ((16, 1, 1), (16, 1, 1))
-    strategy2 = ((4, 4, 1), (4, 4, 1))
+    strategy1 = ((16, 1), (16, 1))
+    strategy2 = ((4, 4), (4, 4))
     net = Net(_w, strategy1, strategy2)
     with pytest.raises(RuntimeError):
         compile_net(net)
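
Note: as a quick sanity check on test_parameter_different_split (and the analogous test_input_different_split below), a shard strategy divides each dimension of a tensor by the matching strategy factor, so the shared [16, 16] weight ends up with different slice shapes under the two strategies. The helper below is hypothetical, not a MindSpore API.

def slice_shape(shape, strategy):
    # Each dimension is cut into `strategy[i]` equal pieces.
    assert all(dim % cut == 0 for dim, cut in zip(shape, strategy))
    return [dim // cut for dim, cut in zip(shape, strategy)]

print(slice_shape([16, 16], (16, 1)))  # [1, 16] under strategy1
print(slice_shape([16, 16], (4, 4)))   # [4, 4]  under strategy2
# Two users of one shared parameter with different slice shapes, so
# CheckParameterSplit raises "... but the slice shapes are different".
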
@@ -80,16 +93,25 @@ def test_parameter_different_split():
 
 def test_input_same_split():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
-    strategy1 = ((16, 1, 1), (16, 1, 1))
-    strategy2 = ((16, 1, 1), (16, 1, 1))
+    strategy1 = ((16, 1), (16, 1))
+    strategy2 = ((16, 1), (16, 1))
     net = Net(_w, strategy1, strategy2)
     compile_net(net)
 
 
 def test_input_different_split():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
-    strategy1 = ((16, 1, 1), (16, 1, 1))
-    strategy2 = ((4, 4, 1), (4, 4, 1))
+    strategy1 = ((16, 1), (16, 1))
+    strategy2 = ((4, 4), (4, 4))
     net = Net2(_w, strategy1, strategy2)
     with pytest.raises(RuntimeError):
         compile_net(net)
+
+
+def test_parameter_different_group():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((1, 2), (2, 1))
+    strategy2 = ((8, 2), (2, 1))
+    net = Net3(_w, strategy1, strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
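
Note: test_parameter_different_group exercises the case the slice-shape comparison alone would miss. With strategy1 = ((1, 2), (2, 1)) and strategy2 = ((8, 2), (2, 1)), both MatMul users in Net3 cut the shared [16, 16] weight two ways along its first dimension, so the slice shapes are identical; but strategy1 splits tensors over only 1 * 2 = 2 of the 16 devices while strategy2 uses 8 * 2 = 16, so the two operators build different device matrices and the ranks holding the same weight slice form different groups, which the new group-rank check is expected to reject. A short sketch of that arithmetic (plain Python, not MindSpore code; the interpretation of the strategies is an assumption):

w_shape = [16, 16]
w_cut = (2, 1)                                   # weight tuple in both strategies
print([d // c for d, c in zip(w_shape, w_cut)])  # [8, 16] slice for mul and mul2
print(1 * 2, 8 * 2)  # 2 vs 16 devices doing the actual splitting for the two ops
# Same slice shape, different device matrices -> different group rank lists,
# so CheckParameterSplit raises RuntimeError, as the test asserts.
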