forked from mindspore-Ecosystem/mindspore
!2286 enable optimizer parallel with broadcast
Merge pull request !2286 from gziyan/optimizer_parallel
This commit is contained in:
commit
ca1d2436b8
|
@ -62,6 +62,7 @@ void ParallelContext::Reset() {
|
|||
enable_all_reduce_fusion_ = false;
|
||||
strategy_ckpt_load_file_ = "";
|
||||
strategy_ckpt_save_file_ = "";
|
||||
enable_parallel_optimizer_ = false;
|
||||
}
|
||||
|
||||
void ParallelContext::set_device_num(int32_t device_num) {
|
||||
|
|
|
@ -100,6 +100,11 @@ class ParallelContext {
|
|||
void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
|
||||
std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
|
||||
|
||||
void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
|
||||
enable_parallel_optimizer_ = enable_parallel_optimizer;
|
||||
}
|
||||
bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
|
||||
|
||||
void Reset();
|
||||
|
||||
private:
|
||||
|
@ -123,6 +128,7 @@ class ParallelContext {
|
|||
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
|
||||
std::string strategy_ckpt_load_file_;
|
||||
std::string strategy_ckpt_save_file_;
|
||||
bool enable_parallel_optimizer_;
|
||||
};
|
||||
|
||||
void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
|
||||
|
|
|
@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
|
||||
.def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
|
||||
.def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
|
||||
.def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
|
||||
"Set enable/disable parallel optimizer.")
|
||||
.def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
|
||||
"Get enable/disable parallel optimizer.")
|
||||
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
|
||||
|
||||
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
|
||||
|
|
|
@ -29,8 +29,9 @@ from .optimizer import Optimizer
|
|||
_adam_opt = C.MultitypeFuncGraph("adam_opt")
|
||||
|
||||
|
||||
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
|
||||
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
|
||||
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Bool", "Bool")
|
||||
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter):
|
||||
"""
|
||||
Update parameters.
|
||||
|
||||
|
@ -44,10 +45,13 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
|
|||
m (Tensor): m value of parameters.
|
||||
v (Tensor): v value of parameters.
|
||||
gradient (Tensor): Gradient of parameters.
|
||||
decay_flag (bool): Applies weight decay or not.
|
||||
optim_filter (bool): Applies parameter update or not.
|
||||
|
||||
Returns:
|
||||
Tensor, the new value of v after updating.
|
||||
"""
|
||||
if optim_filter:
|
||||
op_mul = P.Mul()
|
||||
op_square = P.Square()
|
||||
op_sqrt = P.Sqrt()
|
||||
|
@ -60,11 +64,13 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
|
|||
v_fp32 = op_cast(v, mstype.float32)
|
||||
gradient_fp32 = op_cast(gradient, mstype.float32)
|
||||
|
||||
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
|
||||
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
|
||||
- beta1, gradient_fp32)
|
||||
|
||||
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
|
||||
- beta2, op_square(gradient_fp32))
|
||||
|
||||
|
||||
update = next_m / (eps + op_sqrt(next_v))
|
||||
if decay_flag:
|
||||
update = op_mul(weight_decay_tensor, param_fp32) + update
|
||||
|
@ -72,10 +78,11 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
|
|||
update_with_lr = op_mul(lr, update)
|
||||
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
|
||||
|
||||
next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param))))
|
||||
next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m))))
|
||||
next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v))))
|
||||
return next_v
|
||||
next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
|
||||
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
|
||||
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
|
||||
return next_param
|
||||
return gradient
|
||||
|
||||
|
||||
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
|
||||
|
@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer):
|
|||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
|
||||
tuple[bool], all elements are True.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
|
@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer):
|
|||
|
||||
def construct(self, gradients):
|
||||
lr = self.get_lr()
|
||||
updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay_tensor),
|
||||
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
|
||||
|
||||
return updated_velocity
|
||||
self.params, self.moments1, self.moments2, gradients,
|
||||
self.decay_flag, self.optim_filter)
|
||||
if self.use_parallel:
|
||||
optim_result = self.broadcast_params(optim_result)
|
||||
return optim_result
|
||||
|
||||
|
||||
class AdamWeightDecayDynamicLR(Optimizer):
|
||||
|
@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
|
|||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
|
||||
tuple[bool], all elements are True.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
|
@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer):
|
|||
warmup_lr = self.start_learning_rate * warmup_percent
|
||||
is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
|
||||
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
|
||||
updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay_tensor),
|
||||
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
|
||||
|
||||
self.params, self.moments1, self.moments2, gradients,
|
||||
self.decay_flag, self.optim_filter)
|
||||
if self.use_parallel:
|
||||
optim_result = self.broadcast_params(optim_result)
|
||||
added_global_step = self.global_step + self.one
|
||||
F.control_depend(lr, added_global_step)
|
||||
self.global_step = added_global_step
|
||||
|
||||
return updated_velocity
|
||||
return optim_result
|
||||
|
|
|
@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32)
|
|||
|
||||
_lamb_opt = C.MultitypeFuncGraph("lamb_opt")
|
||||
|
||||
|
||||
@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "Tensor", "Bool")
|
||||
@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Bool", "Bool")
|
||||
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v,
|
||||
gradient, decay_flag):
|
||||
gradient, decay_flag, optim_filter):
|
||||
"""
|
||||
Update parameters.
|
||||
|
||||
|
@ -52,10 +51,12 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
|
|||
v (Tensor): v value of parameters.
|
||||
gradient (Tensor): Gradient of parameters.
|
||||
decay_flag (bool): Specifies whether param update with weight decay.
|
||||
optim_filter(bool): Applies parameter update or not.
|
||||
|
||||
Returns:
|
||||
Tensor, the new value of v after updating.
|
||||
"""
|
||||
if optim_filter:
|
||||
op_mul = P.Mul()
|
||||
op_sqrt = P.Sqrt()
|
||||
op_rsqrt = P.Rsqrt()
|
||||
|
@ -75,11 +76,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
|
|||
v_fp32 = op_cast(v, mstype.float32)
|
||||
gradient_fp32 = op_cast(gradient, mstype.float32)
|
||||
|
||||
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one,
|
||||
mstype.float32) - beta1, gradient_fp32)
|
||||
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
|
||||
|
||||
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one,
|
||||
mstype.float32) - beta2, op_square(gradient_fp32))
|
||||
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
|
||||
|
||||
next_mm = next_m / (op_cast(num_one, mstype.float32)
|
||||
- op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
|
||||
|
@ -88,8 +87,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
|
|||
w_norm = op_norm(param_fp32)
|
||||
g_norm = op_norm(gradient_fp32)
|
||||
|
||||
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(
|
||||
next_vv + eps)) + weight_decay_tensor * param_fp32)
|
||||
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
|
||||
zeros = F.zeros_like(w_norm)
|
||||
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
|
||||
trust_ratio = op_select(
|
||||
|
@ -107,11 +105,12 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
|
|||
|
||||
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
|
||||
|
||||
next_v = F.depend(next_v, F.assign(param, next_param))
|
||||
next_v = F.depend(next_v, F.assign(m, next_m))
|
||||
next_v = F.depend(next_v, F.assign(v, next_v))
|
||||
next_param = F.depend(next_param, F.assign(param, next_param))
|
||||
next_param = F.depend(next_param, F.assign(m, next_m))
|
||||
next_param = F.depend(next_param, F.assign(v, next_v))
|
||||
|
||||
return next_v
|
||||
return next_param
|
||||
return gradient
|
||||
|
||||
|
||||
lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel")
|
||||
|
@ -238,7 +237,7 @@ class Lamb(Optimizer):
|
|||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
|
||||
tuple[bool], all elements are True.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
|
@ -311,18 +310,21 @@ class Lamb(Optimizer):
|
|||
self.warmup_steps, self.global_step), mstype.float32)
|
||||
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
|
||||
if self.enable_graph_kernel:
|
||||
updated_velocity = self.hyper_map(F.partial(lamb_opt_graph_kernel,
|
||||
optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel,
|
||||
self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay_tensor, self.global_step),
|
||||
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
|
||||
else:
|
||||
updated_velocity = self.hyper_map(F.partial(_lamb_opt,
|
||||
optim_result = self.hyper_map(F.partial(_lamb_opt,
|
||||
self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay_tensor, self.global_step),
|
||||
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
|
||||
self.params, self.moments1, self.moments2, gradients,
|
||||
self.decay_flag, self.optim_filter)
|
||||
if self.use_parallel:
|
||||
optim_result = self.broadcast_params(optim_result)
|
||||
|
||||
added_global_step = self.global_step + self.one
|
||||
F.control_depend(lr, added_global_step)
|
||||
self.global_step = added_global_step
|
||||
|
||||
return updated_velocity
|
||||
return optim_result
|
||||
|
|
|
@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P
|
|||
from mindspore.nn.cell import Cell
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore import log as logger
|
||||
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
|
||||
__all__ = ['Optimizer']
|
||||
|
||||
|
@ -155,6 +158,27 @@ class Optimizer(Cell):
|
|||
self.param_length = len(self.parameters)
|
||||
self.map_ = C.Map()
|
||||
|
||||
use_parallel = auto_parallel_context().get_enable_parallel_optimizer()
|
||||
self.use_parallel = use_parallel
|
||||
if use_parallel:
|
||||
if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]:
|
||||
raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
|
||||
if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL,
|
||||
ParallelMode.AUTO_PARALLEL]:
|
||||
raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
|
||||
(_get_parallel_mode()))
|
||||
self.dev_num = _get_device_num()
|
||||
if self.dev_num > self.param_length:
|
||||
raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
|
||||
" less than the number of devices {}".format(self.param_length, self.dev_num))
|
||||
self.param_rank = self._get_parameter_group_id()
|
||||
self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
|
||||
self.param_names = []
|
||||
for param in self.parameters:
|
||||
self.param_names.append(param.name)
|
||||
else:
|
||||
self.optim_filter = (True,) * self.param_length
|
||||
|
||||
def decay_weight(self, gradients):
|
||||
"""
|
||||
Weight decay.
|
||||
|
@ -401,6 +425,51 @@ class Optimizer(Cell):
|
|||
lr = self.learning_rate
|
||||
return lr
|
||||
|
||||
def _get_parameter_group_id(self):
|
||||
"""
|
||||
Get the parameter partition group id, which is less than the number of devices.
|
||||
|
||||
Returns:
|
||||
tuple, the group id tuple of parameters.
|
||||
"""
|
||||
rank_list = ()
|
||||
count = 0
|
||||
for _ in range(self.param_length):
|
||||
rank_list = rank_list + (count,)
|
||||
count = count + 1
|
||||
if count == self.dev_num:
|
||||
count = 0
|
||||
return rank_list
|
||||
|
||||
def broadcast_params(self, optim_result):
|
||||
"""
|
||||
Apply Broadcast operations in the sequential order of parameter groups.
|
||||
|
||||
Returns:
|
||||
bool, the status flag.
|
||||
"""
|
||||
param_group = []
|
||||
key_group = []
|
||||
for _ in range(self.dev_num):
|
||||
param_group.append(F.make_tuple())
|
||||
key_group.append(F.make_tuple())
|
||||
for i in range(self.param_length):
|
||||
param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
|
||||
key = P.MakeRefKey(self.param_names[i])()
|
||||
key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
|
||||
new_param_group = []
|
||||
for root in range(self.dev_num):
|
||||
ops = P.Broadcast(root)
|
||||
next_params = ops(param_group[root])
|
||||
new_param_group.append(next_params)
|
||||
for i in range(F.tuple_len(next_params)):
|
||||
F.assign(key_group[root][i], next_params[i])
|
||||
status = True
|
||||
for i in range(self.dev_num - 1):
|
||||
status = F.control_depend(new_param_group[i][0], new_param_group[i+1])
|
||||
|
||||
return status
|
||||
|
||||
def construct(self, *hyper_params):
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value):
|
|||
return F.list_setitem(data, number_index, value)
|
||||
|
||||
|
||||
@setitem.register("List", "Number", "Tuple")
|
||||
def _list_setitem_with_Tuple(data, number_index, value):
|
||||
"""
|
||||
Assigns value to list.
|
||||
|
||||
Inputs:
|
||||
data (list): Data of type lis.
|
||||
number_index (Number): Index of data.
|
||||
value (list): Value given.
|
||||
|
||||
Outputs:
|
||||
list, type is same as the element type of data.
|
||||
"""
|
||||
return F.list_setitem(data, number_index, value)
|
||||
|
||||
|
||||
@setitem.register("Dictionary", "String", "Tensor")
|
||||
def _dict_setitem_with_tensor(data, key, value):
|
||||
"""
|
||||
|
|
|
@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer):
|
|||
self.op = op
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('fusion', 0)
|
||||
self.add_prim_attr('index', 0)
|
||||
|
||||
def vm_impl(self, x):
|
||||
"""Implement by vm mode."""
|
||||
|
|
|
@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer):
|
|||
return variable
|
||||
|
||||
def infer_dtype(self, variable, value):
|
||||
args = {"variable": variable, "value": value}
|
||||
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
# Add a type validation later when we don't have to assign a value to RefKey.
|
||||
return variable
|
||||
|
||||
|
||||
|
|
|
@ -400,6 +400,23 @@ class _AutoParallelContext:
|
|||
self.check_context_handle()
|
||||
return self._context_handle.get_global_rank_is_set()
|
||||
|
||||
def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
|
||||
"""
|
||||
Set enable/disable parallel optimizer.
|
||||
|
||||
Args:
|
||||
set_enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
if not isinstance(enable_parallel_optimizer, bool):
|
||||
raise TypeError('enable_parallel_optimizer is invalid type')
|
||||
self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
|
||||
|
||||
def get_enable_parallel_optimizer(self):
|
||||
"""Get parallel optimizer flag."""
|
||||
self.check_context_handle()
|
||||
return self._context_handle.get_enable_parallel_optimizer()
|
||||
|
||||
def reset(self):
|
||||
"""Reset all settings."""
|
||||
self.check_context_handle()
|
||||
|
@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = {
|
|||
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
|
||||
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
|
||||
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
|
||||
"full_batch": auto_parallel_context().set_full_batch}
|
||||
"full_batch": auto_parallel_context().set_full_batch,
|
||||
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer}
|
||||
|
||||
|
||||
_get_auto_parallel_context_func_map = {
|
||||
|
@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = {
|
|||
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
|
||||
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
|
||||
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
|
||||
"full_batch": auto_parallel_context().get_full_batch}
|
||||
"full_batch": auto_parallel_context().get_full_batch,
|
||||
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
|
||||
|
||||
|
||||
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
|
||||
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
|
||||
parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
||||
strategy_ckpt_save_file=str, full_batch=bool)
|
||||
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
|
||||
|
||||
def _set_auto_parallel_context(**kwargs):
|
||||
"""
|
||||
Set auto parallel context.
|
||||
|
@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs):
|
|||
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
|
||||
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
|
||||
full_batch (bool): Whether to load the whole batch on each device. Default: False.
|
||||
enable_parallel_optimizer (bool): Enable using optimizer segmentation or noe. Default: False.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not attribute in auto parallel context.
|
||||
|
@ -535,5 +556,6 @@ def _reset_auto_parallel_context():
|
|||
- parameter_broadcast: False.
|
||||
- strategy_ckpt_load_file: ""
|
||||
- strategy_ckpt_save_file: ""
|
||||
- enable_parallel_optimizer: False
|
||||
"""
|
||||
auto_parallel_context().reset()
|
||||
|
|
|
@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
|
|||
} else if (name == "instance_name") {
|
||||
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
|
||||
ASSERT_EQ(converted_ret->ToString(), "test");
|
||||
} else if (name == "index") {
|
||||
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
|
||||
ASSERT_EQ(converted_ret->ToString(), "0");
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Test failed";
|
||||
}
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
test assign sub
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
|
@ -36,27 +35,6 @@ class AssignW(nn.Cell):
|
|||
return x
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.b = Parameter(initializer('ones', [5]), name='b')
|
||||
self.assign = AssignW()
|
||||
|
||||
def construct(self, value):
|
||||
return self.assign(self.b, value)
|
||||
|
||||
|
||||
def test_assign_through_cell():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
net.to_float(ms.float16)
|
||||
net.add_flags_recursive(fp16=False)
|
||||
input_data = Tensor(np.ones([5]).astype(np.float32))
|
||||
net(input_data)
|
||||
with pytest.raises(TypeError):
|
||||
net(None)
|
||||
|
||||
|
||||
class AssignOp(nn.Cell):
|
||||
def __init__(self):
|
||||
super(AssignOp, self).__init__()
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test adam """
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""Net definition"""
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.fc1 = nn.Dense(128, 768, activation='relu')
|
||||
self.fc2 = nn.Dense(128, 768, activation='relu')
|
||||
self.fc3 = nn.Dense(128, 768, activation='relu')
|
||||
self.fc4 = nn.Dense(768, 768, activation='relu')
|
||||
self.relu4 = nn.ReLU()
|
||||
self.relu5 = nn.ReLU()
|
||||
self.transpose = P.Transpose()
|
||||
self.matmul1 = P.MatMul()
|
||||
self.matmul2 = P.MatMul()
|
||||
|
||||
def construct(self, x):
|
||||
q = self.fc1(x)
|
||||
k = self.fc2(x)
|
||||
v = self.fc3(x)
|
||||
k = self.transpose(k, (1, 0))
|
||||
c = self.relu4(self.matmul1(q, k))
|
||||
s = self.relu5(self.matmul2(c, v))
|
||||
s = self.fc4(s)
|
||||
return s
|
||||
|
||||
|
||||
def test_AdamWeightDecayDynamicLR():
|
||||
""" test_AdamWeightDecayDynamicLR """
|
||||
auto_parallel_context().set_enable_parallel_optimizer(True)
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
|
||||
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
||||
label = Tensor(np.zeros([32, 768]).astype(np.float32))
|
||||
net = Net()
|
||||
net.set_train()
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
optimizer = AdamWeightDecayDynamicLR(net.trainable_params(), decay_steps=20, learning_rate=0.1)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
_executor.compile(train_network, inputs, label)
|
||||
|
||||
|
||||
def test_AdamWeightDecay():
|
||||
""" test_AdamWeightDecayDynamicLR """
|
||||
auto_parallel_context().set_enable_parallel_optimizer(True)
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
|
||||
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
||||
label = Tensor(np.zeros([32, 768]).astype(np.float32))
|
||||
net = Net()
|
||||
net.set_train()
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
_executor.compile(train_network, inputs, label)
|
||||
|
||||
|
||||
def test_lamb_compile():
|
||||
""" test_Lamb_compile """
|
||||
auto_parallel_context().set_enable_parallel_optimizer(True)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2)
|
||||
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
||||
label = Tensor(np.zeros([32, 768]).astype(np.float32))
|
||||
net = Net()
|
||||
net.set_train()
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
optimizer = Lamb(net.trainable_params(), decay_steps=10)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
_executor.compile(train_network, inputs, label)
|
||||
|
||||
|
||||
def test_edge_case():
|
||||
""" test_edge_case """
|
||||
auto_parallel_context().set_enable_parallel_optimizer(True)
|
||||
net = Net()
|
||||
with pytest.raises(RuntimeError):
|
||||
context.set_auto_parallel_context(parallel_mode="stand_alone")
|
||||
Lamb(net.trainable_params(), decay_steps=10)
|
||||
with pytest.raises(RuntimeError):
|
||||
Adam(net.trainable_params(), learning_rate=0.1)
|
||||
with pytest.raises(RuntimeError):
|
||||
context.set_auto_parallel_context(device_num=16)
|
||||
Lamb(net.trainable_params(), decay_steps=10)
|
|
@ -81,6 +81,10 @@ def test_set_auto_parallel_context():
|
|||
with pytest.raises(ValueError):
|
||||
set_algo_parameters(tensor_slice_align_size=1025)
|
||||
|
||||
auto_parallel_context().set_enable_parallel_optimizer(True)
|
||||
assert auto_parallel_context().get_enable_parallel_optimizer() is True
|
||||
assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
|
||||
|
||||
|
||||
def test_reset_auto_parallel_context():
|
||||
context.reset_auto_parallel_context()
|
||||
|
|
Loading…
Reference in New Issue