forked from mindspore-Ecosystem/mindspore

optimizer weight shard mixed precision optimization and finetune

parent 8cb0a4c8e9
commit 88be613cdc

@@ -1678,8 +1678,8 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
     allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name, root);
     MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name;
   } else {
-    InsertNode(op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name, root);
-    allgather = cnode->input(res.second)->cast<CNodePtr>();
+    InsertNode(op, cnode, IntToSize(res.second), node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name, root);
+    allgather = cnode->input(IntToSize(res.second))->cast<CNodePtr>();
     MS_LOG(INFO) << "Parallel optimizer is applied before " << GetPrimName(cnode) << " for " << param_name;
   }
   // add fusion flag

@@ -30,7 +30,7 @@ Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix
 }
 
 // skip redistribution for reshape operator
-OperatorVector ConstructOperator::SkipRedisReshapeOP(Shape shape) {
+OperatorVector ConstructOperator::SkipRedisReshapeOP(const Shape &shape) {
   OperatorAttrs attrs;
   ValuePtr param_value = MakeValue(shape);
   Attr param = std::make_pair(SHAPE, param_value);

@@ -35,7 +35,7 @@ class ConstructOperator {
   ConstructOperator() : dev_size_(0) {}
   ~ConstructOperator() = default;
   Status Init(const RankList &dev_list, const Shape &dev_matrix_shape);
-  OperatorVector SkipRedisReshapeOP(Shape shape);
+  OperatorVector SkipRedisReshapeOP(const Shape &shape);
   Status ReshapeOP(Shape shape);
   Status StridedSliceOP(Args args);
   Status AllGatherOP(int64_t dev_dim);

@@ -421,7 +421,7 @@ Status TensorLayout::GenerateOptShardSliceShape() {
   Shape tensor_map = tensor_map_.array();
   Shape repeated_dev;
   for (size_t i = 0; i < dev_max.size(); i++) {
-    if (tensor_map_.GetIndexByValue(i) == MAP_NONE) {
+    if (tensor_map_.GetIndexByValue(static_cast<int64_t>(i)) == MAP_NONE) {
       repeated_dev.push_back(dev_max[dev_max.size() - 1 - i]);
       dev_max[dev_max.size() - 1 - i] = 1;
     }

@@ -440,7 +440,7 @@ Status TensorLayout::GenerateOptShardSliceShape() {
   if (tensor_map[0] == MAP_NONE) {
     split_num = repeated_num;
   } else {
-    split_num = dev_max[dev_max.size() - 1 - tensor_map[0]] * repeated_num;
+    split_num = dev_max[dev_max.size() - 1 - static_cast<size_t>(tensor_map[0])] * repeated_num;
   }
   if (tensor_shape_.array()[0] % split_num != 0) {
     MS_LOG(INFO) << "Tensor could not be shard on the first dimension.";
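
For intuition, a rough Python sketch (not part of the patch) of the split count GenerateOptShardSliceShape derives in the two hunks above; MAP_NONE = -1 is an assumed sentinel for "dimension not mapped onto the device matrix":

MAP_NONE = -1  # assumed sentinel: tensor dimension not model-parallel

def opt_shard_split_num(dev_max, tensor_map, repeated_num, first_dim):
    # Mirrors the C++ above: how many slices the first tensor dimension gets.
    if tensor_map[0] == MAP_NONE:
        split_num = repeated_num  # only the repeated devices shard it
    else:
        split_num = dev_max[len(dev_max) - 1 - tensor_map[0]] * repeated_num
    if first_dim % split_num != 0:
        return None  # first dimension cannot be sharded evenly
    return split_num

# e.g. dev_max=[2, 4], tensor_map=[1, -1], repeated_num=2:
# split_num = dev_max[0] * 2 = 4, so a first dimension of 12 splits into slices of 3.
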
@@ -63,8 +63,8 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
     if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
       return nullptr;
     } else {
-      (void)operator_vector.push_back(constructor.GetOperator());
-      (void)output_info_vector.push_back(std::make_pair(false, 0));
+      operator_vector.push_back(constructor.GetOperator());
+      output_info_vector.push_back(std::make_pair(false, 0));
     }
   }
   if (InferRedistribution(to_repeat, to_origin_, &operator_vector, &output_info_vector, is_cost_model) ==

@@ -107,7 +107,6 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
   MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
   MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
   // Step 2: Infer redistribution and insert operators
-  RedistributionOperatorInfer operator_infer(construct_op_flag_);
   OperatorVector operator_vector;
   OutPutInfoVector output_info_vector;
   if (InferRedistribution(from_layout, to_layout, &operator_vector, &output_info_vector, is_cost_model) !=

@@ -248,14 +247,14 @@ Status TensorRedistribution::ComputePermuteCost(double input_size, Shape attrs)
   forward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
   backward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
   comm_cost_ += COST_FACTOR * input_size * ALLTOALL_SCALE_FACTOR;
-  int32_t concat_dim = attrs[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
+  int64_t concat_dim = attrs[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
   if (concat_dim == 0) {
     // memory cost = all_gather
     computation_cost_ += input_size;
     memory_cost_ += input_size;
   } else {
     // memory cost = all_gather + split + concat
-    int32_t dev_num = attrs[TRANSFER_PERMUTE_DEV_NUM_INDEX];
+    int64_t dev_num = attrs[TRANSFER_PERMUTE_DEV_NUM_INDEX];
     computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
     memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
   }
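
Both casts here (and the matching one in ComputeConcatCost just below) widen int32_t to int64_t: attrs is a Shape, i.e. a vector of int64_t in this codebase, so the old declarations silently narrowed the values. A minimal Python mirror of the permute-cost bookkeeping, illustrative only (the scale values are assumed placeholders for the named C++ constants):

def compute_permute_cost(input_size, dev_num, concat_dim,
                         alltoall_scale=0.05, cost_factor=2.0):
    forward_comm = input_size * alltoall_scale
    backward_comm = input_size * alltoall_scale
    comm = cost_factor * input_size * alltoall_scale
    if concat_dim == 0:
        computation = memory = input_size             # all_gather only
    else:
        computation = input_size * (1 + 2 * dev_num)  # all_gather + split + concat
        memory = input_size * (2 * dev_num + 1)
    return forward_comm, backward_comm, comm, computation, memory
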
@@ -274,7 +273,7 @@ Status TensorRedistribution::ComputeConcatCost(double input_size, Shape attrs) {
   forward_comm_cost_ += input_size * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
   backward_comm_cost_ += input_size * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
   comm_cost_ += input_size * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
-  int32_t concat_dim = attrs[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
+  int64_t concat_dim = attrs[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
   if (concat_dim == 0) {
     // computation cost = all_gather
     computation_cost_ += input_size;

@@ -42,6 +42,7 @@ from mindspore._checkparam import check_input_data, Validator
 from mindspore.compression.export import quant_export
 from mindspore.parallel._tensor import _load_tensor
 from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
+from mindspore.communication.management import get_rank, get_group_size
 from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file
 
 

@@ -566,7 +567,7 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
         Tensor, the combined tensor which with the whole data value.
     """
     from mindspore.parallel._cell_wrapper import get_allgather_cell
-    from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
+    from mindspore.parallel._tensor import _reshape_param_data
     layout = net.parameter_layout_dict[param_name]
     if len(layout) < 6:
         logger.info("layout dict does not contain the key %s", param_name)

@@ -574,43 +575,38 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
 
     dev_mat = layout[0]
     tensor_map = layout[1]
-    field_size = layout[3]
     uniform_split = layout[4]
     opt_shard_group = layout[5]
 
     allgather_net = None
+    mp_weight = False
+    for dim in tensor_map:
+        if dim != -1:
+            mp_weight = True
+            break
     if param_name in net.parallel_parameter_merge_net_dict:
         allgather_net = net.parallel_parameter_merge_net_dict[param_name]
     else:
         logger.info("need to create allgather net for %s", param_name)
-
-    if integrated_save:
-        if uniform_split == 0:
-            raise RuntimeError("Integrated save checkpoint only support uniform split tensor now.")
-        # while any dim is not equal to -1, means param is split and needs to be merged
-        # pipeline parallel need to be supported here later
-        for dim in tensor_map:
-            if dim != -1:
-                if allgather_net is None:
-                    if opt_shard_group:
-                        allgather_net = get_allgather_cell(opt_shard_group, True)
-                    else:
-                        allgather_net = get_allgather_cell(opt_shard_group, False)
-                    net.parallel_parameter_merge_net_dict[param_name] = allgather_net
-                param_data = allgather_net(param_data)
-                if field_size:
-                    return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
-                return _reshape_param_data(param_data, dev_mat, tensor_map)
-        if opt_shard_group:
-            if allgather_net is None:
+        if integrated_save:
+            if uniform_split == 0:
+                raise RuntimeError("Integrated save checkpoint only support uniform split tensor now.")
+            # while any dim is not equal to -1, means param is split and needs to be merged
+            # pipeline parallel need to be supported here later
+            if mp_weight:
+                if opt_shard_group:
+                    allgather_net = get_allgather_cell(opt_shard_group, True)
+                else:
+                    allgather_net = get_allgather_cell(opt_shard_group, False)
+            elif opt_shard_group:
                 allgather_net = get_allgather_cell(opt_shard_group, False)
-                net.parallel_parameter_merge_net_dict[param_name] = allgather_net
-            param_data = allgather_net(param_data)
-    elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_integrated_save"):
-        if allgather_net is None:
+        elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
             allgather_net = get_allgather_cell(opt_shard_group, False)
-            net.parallel_parameter_merge_net_dict[param_name] = allgather_net
-        param_data = allgather_net(param_data)
+        net.parallel_parameter_merge_net_dict[param_name] = allgather_net
+    if allgather_net:
+        param_data = allgather_net(param_data)
+    if mp_weight and integrated_save:
+        param_data = _reshape_param_data(param_data, dev_mat, tensor_map)
     return param_data
 
 
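The restructuring above builds the allgather cell only on a cache miss, stores it in net.parallel_parameter_merge_net_dict, applies it whenever present, and performs the model-parallel reshape only for integrated save; it also renames the context key optimizer_weight_shard_integrated_save to optimizer_weight_shard_aggregated_save. A hedged usage sketch, assuming the setter accepts the same key as the getter used above:

from mindspore import context

# Assumption: when this flag is set, optimizer-sharded weights are gathered
# back into whole tensors before saving; otherwise each rank saves its slice.
context.set_auto_parallel_context(optimizer_weight_shard_aggregated_save=True)
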
@@ -1251,6 +1247,19 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
             param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
             _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
             split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
+            opt_shard_group = predict_strategy[param.name][5]
+            if opt_shard_group:
+                data = split_param.data.asnumpy()
+                rank = get_rank(opt_shard_group)
+                size = get_group_size(opt_shard_group)
+                try:
+                    data_slice = np.split(data, size)[rank]
+                except BaseException as e:
+                    logger.error("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
+                                 " and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
+                    raise RuntimeError(e.__str__())
+                split_param = Parameter(Tensor(data_slice), param.name,
+                                        split_param.requires_grad, split_param.layerwise_parallel)
         param_dict[param.name] = split_param
 
     if param_not_in_strategy:
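
The added loading path is the inverse of aggregated save: each rank keeps only its own slice of an optimizer-sharded parameter, taken as an even split along the first dimension. A toy standalone example of the np.split behavior (the try/except above exists because np.split raises when the dimension is not evenly divisible):

import numpy as np

data = np.arange(12.0).reshape(6, 2)     # merged parameter, first dim 6
size, rank = 3, 1                        # opt shard group size and this rank
data_slice = np.split(data, size)[rank]
print(data_slice)                        # rows 2-3: [[4. 5.] [6. 7.]]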