fix_opt_shard_param_init
This commit is contained in:
parent
c85fecbcb8
commit
b89ec58e6c
|
@ -396,10 +396,13 @@ void SliceParameterObj(const ParameterPtr ¶meter, const TensorLayoutPtr &ten
|
|||
// create python layout obj
|
||||
const auto &device_arrangement = tensor_layout->device_arrangement().array();
|
||||
const auto &tensor_map = tensor_layout->tensor_map().array();
|
||||
const auto &slice_shape = tensor_layout->slice_shape().array();
|
||||
auto slice_shape = tensor_layout->slice_shape().array();
|
||||
int64_t field_size = tensor_layout->get_field_size();
|
||||
bool uniform_split = tensor_layout->uniform_split();
|
||||
std::string opt_shard_group = tensor_layout->opt_shard_group();
|
||||
if (!opt_shard_group.empty()) {
|
||||
slice_shape = tensor_layout->opt_shard_slice_shape();
|
||||
}
|
||||
py::tuple layout =
|
||||
py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);
|
||||
|
||||
|
|
|
@ -712,7 +712,7 @@ class Parameter(Tensor_):
|
|||
raise TypeError("The argument 'layout' should be tuple, but got {}.".format(type(layout)))
|
||||
if len(layout) < 6:
|
||||
raise ValueError("The length of 'layout' must be larger than 5, but got {}.".format(len(layout)))
|
||||
slice_index = int(_get_slice_index(layout[0], layout[1]))
|
||||
slice_index = int(_get_slice_index(layout[0], layout[1], layout[5]))
|
||||
init_data_args += (slice_index, layout[2], layout[5])
|
||||
return init_data_args
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import math
|
|||
import numbers
|
||||
import numpy as np
|
||||
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.common._utils import is_shape_unknown, is_stub_tensor
|
||||
from mindspore.common.seed import get_seed
|
||||
from mindspore import context
|
||||
|
@ -2265,9 +2265,9 @@ class Tensor(Tensor_):
|
|||
self._np_seed = np.random.get_state()[1][0]
|
||||
self.need_set_seed = (slice_index is not None)
|
||||
self._global_seed = global_seed
|
||||
self._device_num = 1
|
||||
self._seed_offset = 1
|
||||
if self.need_set_seed:
|
||||
self._device_num = get_group_size()
|
||||
self._seed_offset = get_group_size() * 2
|
||||
|
||||
def __enter__(self):
|
||||
if self.need_set_seed:
|
||||
|
@ -2278,7 +2278,7 @@ class Tensor(Tensor_):
|
|||
else:
|
||||
np.random.seed(slice_index + Tensor.delta_seed)
|
||||
self.init.seed = slice_index + Tensor.delta_seed
|
||||
Tensor.delta_seed += self._device_num
|
||||
Tensor.delta_seed += self._seed_offset
|
||||
|
||||
def __exit__(self, ptype, value, trace):
|
||||
if self.need_set_seed:
|
||||
|
@ -2287,10 +2287,6 @@ class Tensor(Tensor_):
|
|||
|
||||
with seed_context(self.init):
|
||||
self.init(data)
|
||||
if opt_shard_group:
|
||||
rank = get_rank(opt_shard_group)
|
||||
size = get_group_size(opt_shard_group)
|
||||
data = np.split(data, size)[rank]
|
||||
self.init = None
|
||||
|
||||
# At embedding cache scenes. When size of tensor is out of range, we store data to persistent storage
|
||||
|
|
|
@ -175,20 +175,26 @@ def _chunk_tensor_by_strategy(np_tensor, strategy):
|
|||
return _chunk_tensor(np_tensor, strategy, len(strategy))
|
||||
|
||||
|
||||
def _get_slice_index(dev_mat, tensor_map):
|
||||
def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
|
||||
"""
|
||||
Get the slice index for current slice.
|
||||
|
||||
Args:
|
||||
dev_mat (list): The device matrix of devices.
|
||||
tensor_map (list): The split strategy of tensor.
|
||||
opt_shard_group(string): The group of optimizer shard
|
||||
|
||||
Returns:
|
||||
Integer, the slice index for slice on this device.
|
||||
"""
|
||||
rank = get_rank()
|
||||
dev_num = get_group_size()
|
||||
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
|
||||
tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
|
||||
if opt_shard_group:
|
||||
tensor_slice_index += dev_num
|
||||
opt_rank = get_rank(opt_shard_group)
|
||||
tensor_slice_index += opt_rank
|
||||
return tensor_slice_index
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue