fix_opt_shard_param_init
parent c85fecbcb8
commit b89ec58e6c
@@ -396,10 +396,13 @@ void SliceParameterObj(const ParameterPtr &parameter, const TensorLayoutPtr &ten
   // create python layout obj
   const auto &device_arrangement = tensor_layout->device_arrangement().array();
   const auto &tensor_map = tensor_layout->tensor_map().array();
-  const auto &slice_shape = tensor_layout->slice_shape().array();
+  auto slice_shape = tensor_layout->slice_shape().array();
   int64_t field_size = tensor_layout->get_field_size();
   bool uniform_split = tensor_layout->uniform_split();
   std::string opt_shard_group = tensor_layout->opt_shard_group();
+  if (!opt_shard_group.empty()) {
+    slice_shape = tensor_layout->opt_shard_slice_shape();
+  }
   py::tuple layout =
     py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);
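With this change, when the parameter belongs to a non-empty optimizer-shard group, the slice shape packed into the Python layout tuple is the opt-shard slice shape rather than the plain layout slice shape. A rough illustration with made-up shapes (assuming, as a simplification, that optimizer sharding further divides the first dimension of the slice across the group):

    # Hypothetical numbers, for illustration only.
    full_shape = [1024, 1024]
    model_parallel_split = 4          # split of dim 0 from device_arrangement / tensor_map
    opt_shard_size = 2                # size of the optimizer-shard group

    slice_shape = [full_shape[0] // model_parallel_split, full_shape[1]]         # [256, 1024]
    opt_shard_slice_shape = [slice_shape[0] // opt_shard_size, slice_shape[1]]   # [128, 1024]

    # After this commit the layout tuple carries [128, 1024] (the opt-shard slice shape)
    # whenever opt_shard_group is non-empty, so Python initializes the final shard directly.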
@@ -712,7 +712,7 @@ class Parameter(Tensor_):
                 raise TypeError("The argument 'layout' should be tuple, but got {}.".format(type(layout)))
             if len(layout) < 6:
                 raise ValueError("The length of 'layout' must be larger than 5, but got {}.".format(len(layout)))
-            slice_index = int(_get_slice_index(layout[0], layout[1]))
+            slice_index = int(_get_slice_index(layout[0], layout[1], layout[5]))
             init_data_args += (slice_index, layout[2], layout[5])
         return init_data_args
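For reference, the tuple built by py::make_tuple above is (device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group), so layout[5] is the optimizer-shard group name that now also feeds the slice-index computation. A small sketch with invented values:

    # Positions match the py::make_tuple call in the first hunk; values are made up.
    layout = ([4, 2], [1, -1], [128, 1024], 0, True, "opt_shard_group_name")
    dev_mat, tensor_map, slice_shape = layout[0], layout[1], layout[2]
    opt_shard_group = layout[5]
    # _get_slice_index(dev_mat, tensor_map, opt_shard_group) is what the changed line now calls.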
@@ -20,7 +20,7 @@ import math
 import numbers
 import numpy as np
 
-from mindspore.communication.management import get_rank, get_group_size
+from mindspore.communication.management import get_group_size
 from mindspore.common._utils import is_shape_unknown, is_stub_tensor
 from mindspore.common.seed import get_seed
 from mindspore import context
@@ -2265,9 +2265,9 @@ class Tensor(Tensor_):
                 self._np_seed = np.random.get_state()[1][0]
                 self.need_set_seed = (slice_index is not None)
                 self._global_seed = global_seed
-                self._device_num = 1
+                self._seed_offset = 1
                 if self.need_set_seed:
-                    self._device_num = get_group_size()
+                    self._seed_offset = get_group_size() * 2
 
             def __enter__(self):
                 if self.need_set_seed:
@@ -2278,7 +2278,7 @@ class Tensor(Tensor_):
                     else:
                         np.random.seed(slice_index + Tensor.delta_seed)
                         self.init.seed = slice_index + Tensor.delta_seed
-                        Tensor.delta_seed += self._device_num
+                        Tensor.delta_seed += self._seed_offset
 
             def __exit__(self, ptype, value, trace):
                 if self.need_set_seed:
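Renaming _device_num to _seed_offset and doubling it matches the wider slice-index range introduced in the last hunk: with an optimizer-shard group, _get_slice_index can return values up to roughly twice the device count, so advancing Tensor.delta_seed by 2 * get_group_size() presumably keeps the seeds of consecutively initialized parameters from overlapping. A minimal stand-alone sketch of that bookkeeping (group size stubbed, not the real Tensor class):

    group_size = 8                    # stand-in for get_group_size()
    delta_seed = 0                    # stand-in for Tensor.delta_seed
    seed_offset = group_size * 2      # behaviour after this commit

    def seed_for_parameter(slice_index):
        """Return the seed used for one parameter on this rank, then advance delta_seed."""
        global delta_seed
        seed = slice_index + delta_seed
        delta_seed += seed_offset
        return seed

    # Slice indices stay in [0, 8) without optimizer shard and move into [8, 16)
    # with it, so an offset of 16 keeps successive parameters' seed ranges disjoint.
    print(seed_for_parameter(3))      # 3
    print(seed_for_parameter(11))     # 27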
@@ -2287,10 +2287,6 @@ class Tensor(Tensor_):
 
         with seed_context(self.init):
             self.init(data)
-        if opt_shard_group:
-            rank = get_rank(opt_shard_group)
-            size = get_group_size(opt_shard_group)
-            data = np.split(data, size)[rank]
         self.init = None
 
         # At embedding cache scenes. When size of tensor is out of range, we store data to persistent storage
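The post-initialization split is no longer needed: the layout tuple already carries the opt-shard slice shape (first hunk) and the slice index now differs per optimizer-shard rank (last hunk), so each rank appears to generate exactly its own shard instead of building the full slice and cutting it afterwards. A rough before/after sketch with made-up shapes:

    import numpy as np

    group_size, rank = 2, 1           # made-up opt-shard group of size 2, this rank is 1

    # Old behaviour (removed here): build the full slice, then keep only this rank's piece.
    full_slice = np.zeros((256, 1024), dtype=np.float32)
    old_shard = np.split(full_slice, group_size)[rank]          # shape (128, 1024)

    # New behaviour: the slice shape handed over from C++ is already (128, 1024),
    # so the data is initialized at its final size and no split is required.
    new_shard = np.zeros((128, 1024), dtype=np.float32)
    assert old_shard.shape == new_shard.shape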
@@ -175,20 +175,26 @@ def _chunk_tensor_by_strategy(np_tensor, strategy):
     return _chunk_tensor(np_tensor, strategy, len(strategy))
 
 
-def _get_slice_index(dev_mat, tensor_map):
+def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
     """
     Get the slice index for current slice.
 
     Args:
         dev_mat (list): The device matrix of devices.
         tensor_map (list): The split strategy of tensor.
+        opt_shard_group (str): The group of optimizer shard.
 
     Returns:
         Integer, the slice index for slice on this device.
     """
     rank = get_rank()
+    dev_num = get_group_size()
     tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
     tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
+    if opt_shard_group:
+        tensor_slice_index += dev_num
+        opt_rank = get_rank(opt_shard_group)
+        tensor_slice_index += opt_rank
     return tensor_slice_index
 
 
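A self-contained sketch of the new index arithmetic, with the rank queries and the tensor-slice computation stubbed out: parameters without an optimizer-shard group keep their old indices, while opt-shard parameters are shifted past the plain range and distinguished by their rank inside the group.

    # Stubbed sketch; dev_num, tensor_slice_index and opt_rank are made-up inputs.
    def _get_slice_index_sketch(tensor_slice_index, dev_num, opt_rank, opt_shard_group):
        if opt_shard_group:
            tensor_slice_index += dev_num     # move past the non-opt-shard index range
            tensor_slice_index += opt_rank    # distinguish ranks inside the opt-shard group
        return tensor_slice_index

    dev_num = 8
    print(_get_slice_index_sketch(3, dev_num, 0, ""))              # 3, unchanged without opt shard
    print(_get_slice_index_sketch(3, dev_num, 0, "opt_group"))     # 11
    print(_get_slice_index_sketch(3, dev_num, 1, "opt_group"))     # 12, differs per opt-shard rank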