fix_opt_shard_param_init

This commit is contained in:
lichen 2023-02-21 15:19:00 +08:00
parent c85fecbcb8
commit b89ec58e6c
4 changed files with 16 additions and 11 deletions

View File

@ -396,10 +396,13 @@ void SliceParameterObj(const ParameterPtr &parameter, const TensorLayoutPtr &ten
// create python layout obj
const auto &device_arrangement = tensor_layout->device_arrangement().array();
const auto &tensor_map = tensor_layout->tensor_map().array();
const auto &slice_shape = tensor_layout->slice_shape().array();
auto slice_shape = tensor_layout->slice_shape().array();
int64_t field_size = tensor_layout->get_field_size();
bool uniform_split = tensor_layout->uniform_split();
std::string opt_shard_group = tensor_layout->opt_shard_group();
if (!opt_shard_group.empty()) {
slice_shape = tensor_layout->opt_shard_slice_shape();
}
py::tuple layout =
py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);

View File

@ -712,7 +712,7 @@ class Parameter(Tensor_):
raise TypeError("The argument 'layout' should be tuple, but got {}.".format(type(layout)))
if len(layout) < 6:
raise ValueError("The length of 'layout' must be larger than 5, but got {}.".format(len(layout)))
slice_index = int(_get_slice_index(layout[0], layout[1]))
slice_index = int(_get_slice_index(layout[0], layout[1], layout[5]))
init_data_args += (slice_index, layout[2], layout[5])
return init_data_args

View File

@ -20,7 +20,7 @@ import math
import numbers
import numpy as np
from mindspore.communication.management import get_rank, get_group_size
from mindspore.communication.management import get_group_size
from mindspore.common._utils import is_shape_unknown, is_stub_tensor
from mindspore.common.seed import get_seed
from mindspore import context
@ -2265,9 +2265,9 @@ class Tensor(Tensor_):
self._np_seed = np.random.get_state()[1][0]
self.need_set_seed = (slice_index is not None)
self._global_seed = global_seed
self._device_num = 1
self._seed_offset = 1
if self.need_set_seed:
self._device_num = get_group_size()
self._seed_offset = get_group_size() * 2
def __enter__(self):
if self.need_set_seed:
@ -2278,7 +2278,7 @@ class Tensor(Tensor_):
else:
np.random.seed(slice_index + Tensor.delta_seed)
self.init.seed = slice_index + Tensor.delta_seed
Tensor.delta_seed += self._device_num
Tensor.delta_seed += self._seed_offset
def __exit__(self, ptype, value, trace):
if self.need_set_seed:
@ -2287,10 +2287,6 @@ class Tensor(Tensor_):
with seed_context(self.init):
self.init(data)
if opt_shard_group:
rank = get_rank(opt_shard_group)
size = get_group_size(opt_shard_group)
data = np.split(data, size)[rank]
self.init = None
# At embedding cache scenes. When size of tensor is out of range, we store data to persistent storage

View File

@ -175,20 +175,26 @@ def _chunk_tensor_by_strategy(np_tensor, strategy):
return _chunk_tensor(np_tensor, strategy, len(strategy))
def _get_slice_index(dev_mat, tensor_map):
def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
"""
Get the slice index for current slice.
Args:
dev_mat (list): The device matrix of devices.
tensor_map (list): The split strategy of tensor.
opt_shard_group(string): The group of optimizer shard
Returns:
Integer, the slice index for slice on this device.
"""
rank = get_rank()
dev_num = get_group_size()
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
if opt_shard_group:
tensor_slice_index += dev_num
opt_rank = get_rank(opt_shard_group)
tensor_slice_index += opt_rank
return tensor_slice_index