optimizer weight shard mixed-precision optimization and finetune

Ziyan 2021-06-01 17:03:19 +08:00
parent 8cb0a4c8e9
commit 88be613cdc
6 changed files with 47 additions and 39 deletions
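Note on the user-facing side of this change: serialization.py below now reads the auto-parallel context key optimizer_weight_shard_aggregated_save instead of optimizer_weight_shard_integrated_save. A minimal sketch of how that key would be set from a training script, assuming the usual set_auto_parallel_context API (everything outside this diff is illustrative, not part of the commit):

from mindspore import context

# Hedged sketch: shard optimizer weights across the data-parallel group and ask
# for the sharded weights to be gathered back (aggregated) when saving a checkpoint.
# The key name matches the one read by _get_merged_param_data in this diff.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  enable_parallel_optimizer=True,
                                  optimizer_weight_shard_aggregated_save=True)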

View File

@@ -1678,8 +1678,8 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name, root);
MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name;
} else {
InsertNode(op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name, root);
allgather = cnode->input(res.second)->cast<CNodePtr>();
InsertNode(op, cnode, IntToSize(res.second), node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name, root);
allgather = cnode->input(IntToSize(res.second))->cast<CNodePtr>();
MS_LOG(INFO) << "Parallel optimizer is applied before " << GetPrimName(cnode) << " for " << param_name;
}
// add fusion flag

View File

@@ -30,7 +30,7 @@ Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix
}
// skip redistribution for reshape operator
OperatorVector ConstructOperator::SkipRedisReshapeOP(Shape shape) {
OperatorVector ConstructOperator::SkipRedisReshapeOP(const Shape &shape) {
OperatorAttrs attrs;
ValuePtr param_value = MakeValue(shape);
Attr param = std::make_pair(SHAPE, param_value);

View File

@@ -35,7 +35,7 @@ class ConstructOperator {
ConstructOperator() : dev_size_(0) {}
~ConstructOperator() = default;
Status Init(const RankList &dev_list, const Shape &dev_matrix_shape);
OperatorVector SkipRedisReshapeOP(Shape shape);
OperatorVector SkipRedisReshapeOP(const Shape &shape);
Status ReshapeOP(Shape shape);
Status StridedSliceOP(Args args);
Status AllGatherOP(int64_t dev_dim);

View File

@@ -421,7 +421,7 @@ Status TensorLayout::GenerateOptShardSliceShape() {
Shape tensor_map = tensor_map_.array();
Shape repeated_dev;
for (size_t i = 0; i < dev_max.size(); i++) {
if (tensor_map_.GetIndexByValue(i) == MAP_NONE) {
if (tensor_map_.GetIndexByValue(static_cast<int64_t>(i)) == MAP_NONE) {
repeated_dev.push_back(dev_max[dev_max.size() - 1 - i]);
dev_max[dev_max.size() - 1 - i] = 1;
}
@@ -440,7 +440,7 @@ Status TensorLayout::GenerateOptShardSliceShape() {
if (tensor_map[0] == MAP_NONE) {
split_num = repeated_num;
} else {
split_num = dev_max[dev_max.size() - 1 - tensor_map[0]] * repeated_num;
split_num = dev_max[dev_max.size() - 1 - static_cast<size_t>(tensor_map[0])] * repeated_num;
}
if (tensor_shape_.array()[0] % split_num != 0) {
MS_LOG(INFO) << "Tensor could not be shard on the first dimension.";

View File

@@ -63,8 +63,8 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
return nullptr;
} else {
(void)operator_vector.push_back(constructor.GetOperator());
(void)output_info_vector.push_back(std::make_pair(false, 0));
operator_vector.push_back(constructor.GetOperator());
output_info_vector.push_back(std::make_pair(false, 0));
}
}
if (InferRedistribution(to_repeat, to_origin_, &operator_vector, &output_info_vector, is_cost_model) ==
@@ -107,7 +107,6 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
// Step 2: Infer redistribution and insert operators
RedistributionOperatorInfer operator_infer(construct_op_flag_);
OperatorVector operator_vector;
OutPutInfoVector output_info_vector;
if (InferRedistribution(from_layout, to_layout, &operator_vector, &output_info_vector, is_cost_model) !=
@@ -248,14 +247,14 @@ Status TensorRedistribution::ComputePermuteCost(double input_size, Shape attrs)
forward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
backward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
comm_cost_ += COST_FACTOR * input_size * ALLTOALL_SCALE_FACTOR;
int32_t concat_dim = attrs[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
int64_t concat_dim = attrs[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
if (concat_dim == 0) {
// memory cost = all_gather
computation_cost_ += input_size;
memory_cost_ += input_size;
} else {
// memory cost = all_gather + split + concat
int32_t dev_num = attrs[TRANSFER_PERMUTE_DEV_NUM_INDEX];
int64_t dev_num = attrs[TRANSFER_PERMUTE_DEV_NUM_INDEX];
computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
}
@@ -274,7 +273,7 @@ Status TensorRedistribution::ComputeConcatCost(double input_size, Shape attrs) {
forward_comm_cost_ += input_size * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
backward_comm_cost_ += input_size * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
comm_cost_ += input_size * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
int32_t concat_dim = attrs[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
int64_t concat_dim = attrs[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
if (concat_dim == 0) {
// computation cost = all_gather
computation_cost_ += input_size;
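The int32_t to int64_t changes in ComputePermuteCost and ComputeConcatCost match the element type of Shape, so values read from attrs are no longer narrowed. The cost bookkeeping itself is untouched; a small worked example of the permute-cost branch shown above (the scale factors here are placeholders, not the real constants):

# Hedged arithmetic check of the permute-cost branch; constants are made up.
ALLTOALL_SCALE_FACTOR = 0.1
COST_FACTOR = 2.0
input_size, dev_num, concat_dim = 1000.0, 8, 1

comm_cost = COST_FACTOR * input_size * ALLTOALL_SCALE_FACTOR                     # 200.0
if concat_dim == 0:
    computation_cost = input_size                                                # all_gather only
else:
    computation_cost = input_size + input_size * dev_num + input_size * dev_num  # 17000.0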

View File

@@ -42,6 +42,7 @@ from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export
from mindspore.parallel._tensor import _load_tensor
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
from mindspore.communication.management import get_rank, get_group_size
from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file
@@ -566,7 +567,7 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
Tensor, the combined tensor which with the whole data value.
"""
from mindspore.parallel._cell_wrapper import get_allgather_cell
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
from mindspore.parallel._tensor import _reshape_param_data
layout = net.parameter_layout_dict[param_name]
if len(layout) < 6:
logger.info("layout dict does not contain the key %s", param_name)
@@ -574,43 +575,38 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
dev_mat = layout[0]
tensor_map = layout[1]
field_size = layout[3]
uniform_split = layout[4]
opt_shard_group = layout[5]
allgather_net = None
mp_weight = False
for dim in tensor_map:
if dim != -1:
mp_weight = True
break
if param_name in net.parallel_parameter_merge_net_dict:
allgather_net = net.parallel_parameter_merge_net_dict[param_name]
else:
logger.info("need to create allgather net for %s", param_name)
if integrated_save:
if uniform_split == 0:
raise RuntimeError("Integrated save checkpoint only support uniform split tensor now.")
# while any dim is not equal to -1, means param is split and needs to be merged
# pipeline parallel need to be supported here later
for dim in tensor_map:
if dim != -1:
if allgather_net is None:
if opt_shard_group:
allgather_net = get_allgather_cell(opt_shard_group, True)
else:
allgather_net = get_allgather_cell(opt_shard_group, False)
net.parallel_parameter_merge_net_dict[param_name] = allgather_net
param_data = allgather_net(param_data)
if field_size:
return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
return _reshape_param_data(param_data, dev_mat, tensor_map)
if opt_shard_group:
if allgather_net is None:
if integrated_save:
if uniform_split == 0:
raise RuntimeError("Integrated save checkpoint only support uniform split tensor now.")
# while any dim is not equal to -1, means param is split and needs to be merged
# pipeline parallel need to be supported here later
if mp_weight:
if opt_shard_group:
allgather_net = get_allgather_cell(opt_shard_group, True)
else:
allgather_net = get_allgather_cell(opt_shard_group, False)
elif opt_shard_group:
allgather_net = get_allgather_cell(opt_shard_group, False)
net.parallel_parameter_merge_net_dict[param_name] = allgather_net
param_data = allgather_net(param_data)
elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_integrated_save"):
if allgather_net is None:
elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
allgather_net = get_allgather_cell(opt_shard_group, False)
net.parallel_parameter_merge_net_dict[param_name] = allgather_net
net.parallel_parameter_merge_net_dict[param_name] = allgather_net
if allgather_net:
param_data = allgather_net(param_data)
if mp_weight and integrated_save:
param_data = _reshape_param_data(param_data, dev_mat, tensor_map)
return param_data
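The rewritten _get_merged_param_data above now decides the merge path from two flags: mp_weight (any tensor_map dim != -1, i.e. the weight is model-parallel split) and opt_shard_group (the weight is optimizer-sharded). The AllGather cell is built once per parameter and cached in parallel_parameter_merge_net_dict; reshaping happens only for integrated save of model-parallel weights. A hedged plain-Python summary of the branches (illustrative, not the real API):

def save_path(integrated_save, mp_weight, opt_shard_group, aggregated_save):
    """Describe which merge step the code above would take for one parameter."""
    if integrated_save:
        if mp_weight:
            return "allgather slices, then reshape to the full tensor"
        if opt_shard_group:
            return "allgather over the optimizer shard group only"
        return "keep local data"
    if opt_shard_group and aggregated_save:
        return "allgather over the optimizer shard group only"
    return "keep local (sliced) data"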
@@ -1251,6 +1247,19 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
_param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
opt_shard_group = predict_strategy[param.name][5]
if opt_shard_group:
data = split_param.data.asnumpy()
rank = get_rank(opt_shard_group)
size = get_group_size(opt_shard_group)
try:
data_slice = np.split(data, size)[rank]
except BaseException as e:
logger.error("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
" and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
raise RuntimeError(e.__str__())
split_param = Parameter(Tensor(data_slice), param.name,
split_param.requires_grad, split_param.layerwise_parallel)
param_dict[param.name] = split_param
if param_not_in_strategy:
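The load path added above is the inverse of aggregated save: after _merge_and_split rebuilds the full tensor from the training strategy, each rank keeps only its optimizer-shard slice, which is a plain even split along axis 0. A standalone illustration of that slicing (shapes and group size are made up):

import numpy as np

# Hedged sketch of the np.split slicing above: the merged parameter is split
# evenly along axis 0 across the optimizer shard group and this rank keeps its piece.
merged = np.arange(16 * 4, dtype=np.float32).reshape(16, 4)   # full parameter
group_size, rank = 4, 1                                        # from get_group_size()/get_rank()

data_slice = np.split(merged, group_size)[rank]                # raises ValueError if 16 % 4 != 0
assert data_slice.shape == (4, 4)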