diff --git a/mindspore/ccsrc/parallel/context.cc b/mindspore/ccsrc/parallel/context.cc index 8957dc842c4..062d814aa04 100644 --- a/mindspore/ccsrc/parallel/context.cc +++ b/mindspore/ccsrc/parallel/context.cc @@ -62,6 +62,7 @@ void ParallelContext::Reset() { enable_all_reduce_fusion_ = false; strategy_ckpt_load_file_ = ""; strategy_ckpt_save_file_ = ""; + enable_parallel_optimizer_ = false; } void ParallelContext::set_device_num(int32_t device_num) { diff --git a/mindspore/ccsrc/parallel/context.h b/mindspore/ccsrc/parallel/context.h index efa528d1793..6a503ca7eda 100644 --- a/mindspore/ccsrc/parallel/context.h +++ b/mindspore/ccsrc/parallel/context.h @@ -100,6 +100,11 @@ class ParallelContext { void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file); std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; } + void set_enable_parallel_optimizer(bool enable_parallel_optimizer) { + enable_parallel_optimizer_ = enable_parallel_optimizer; + } + bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; } + void Reset(); private: @@ -123,6 +128,7 @@ class ParallelContext { std::map> all_reduce_fusion_split_sizes_; std::string strategy_ckpt_load_file_; std::string strategy_ckpt_save_file_; + bool enable_parallel_optimizer_; }; void ParallelParameterContextInit(const FuncGraphPtr &func_graph); diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 7025447a29c..dc309808d9a 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) { .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.") .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.") + .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer, + "Set enable/disable parallel optimizer.") + .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer, + "Get enable/disable parallel optimizer.") .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); (void)py::class_>(m, "CostModelContext") diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index ba56af62193..b3d4bc8a0ed 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -29,8 +29,9 @@ from .optimizer import Optimizer _adam_opt = C.MultitypeFuncGraph("adam_opt") -@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") -def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): +@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Bool", "Bool") +def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter): """ Update parameters. @@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad m (Tensor): m value of parameters. v (Tensor): v value of parameters. gradient (Tensor): Gradient of parameters. + decay_flag (bool): Applies weight decay or not. + optim_filter (bool): Applies parameter update or not. Returns: Tensor, the new value of v after updating. 
""" - op_mul = P.Mul() - op_square = P.Square() - op_sqrt = P.Sqrt() - op_cast = P.Cast() - op_reshape = P.Reshape() - op_shape = P.Shape() + if optim_filter: + op_mul = P.Mul() + op_square = P.Square() + op_sqrt = P.Sqrt() + op_cast = P.Cast() + op_reshape = P.Reshape() + op_shape = P.Shape() - param_fp32 = op_cast(param, mstype.float32) - m_fp32 = op_cast(m, mstype.float32) - v_fp32 = op_cast(v, mstype.float32) - gradient_fp32 = op_cast(gradient, mstype.float32) + param_fp32 = op_cast(param, mstype.float32) + m_fp32 = op_cast(m, mstype.float32) + v_fp32 = op_cast(v, mstype.float32) + gradient_fp32 = op_cast(gradient, mstype.float32) - next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) + next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) + - beta1, gradient_fp32) - next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - - beta2, op_square(gradient_fp32)) + next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) + - beta2, op_square(gradient_fp32)) - update = next_m / (eps + op_sqrt(next_v)) - if decay_flag: - update = op_mul(weight_decay_tensor, param_fp32) + update - update_with_lr = op_mul(lr, update) - next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) + update = next_m / (eps + op_sqrt(next_v)) + if decay_flag: + update = op_mul(weight_decay_tensor, param_fp32) + update - next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param)))) - next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m)))) - next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v)))) - return next_v + update_with_lr = op_mul(lr, update) + next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) + + next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) + next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m)))) + next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) + return next_param + return gradient def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): @@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer): - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. Outputs: - tuple[Parameter], the updated velocity value, the shape is the same as `params`. + tuple[bool], all elements are True. Examples: >>> net = Net() @@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer): def construct(self, gradients): lr = self.get_lr() - updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor), - self.params, self.moments1, self.moments2, gradients, self.decay_flag) - - return updated_velocity + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, + self.weight_decay_tensor), + self.params, self.moments1, self.moments2, gradients, + self.decay_flag, self.optim_filter) + if self.use_parallel: + optim_result = self.broadcast_params(optim_result) + return optim_result class AdamWeightDecayDynamicLR(Optimizer): @@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer): - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. Outputs: - tuple[Parameter], the updated velocity value, the shape is the same as `params`. + tuple[bool], all elements are True. 
Examples: >>> net = Net() @@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer): warmup_lr = self.start_learning_rate * warmup_percent is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32) lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr - updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor), - self.params, self.moments1, self.moments2, gradients, self.decay_flag) - + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, + self.weight_decay_tensor), + self.params, self.moments1, self.moments2, gradients, + self.decay_flag, self.optim_filter) + if self.use_parallel: + optim_result = self.broadcast_params(optim_result) added_global_step = self.global_step + self.one F.control_depend(lr, added_global_step) self.global_step = added_global_step - return updated_velocity + return optim_result diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index 832b35d66f1..93c7edbce84 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32) _lamb_opt = C.MultitypeFuncGraph("lamb_opt") - -@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "Tensor", "Bool") +@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Bool", "Bool") def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v, - gradient, decay_flag): + gradient, decay_flag, optim_filter): """ Update parameters. @@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para v (Tensor): v value of parameters. gradient (Tensor): Gradient of parameters. decay_flag (bool): Specifies whether param update with weight decay. + optim_filter(bool): Applies parameter update or not. Returns: Tensor, the new value of v after updating. 
""" - op_mul = P.Mul() - op_sqrt = P.Sqrt() - op_rsqrt = P.Rsqrt() - op_square = P.Square() - op_cast = P.Cast() - op_reshape = P.Reshape() - op_shape = P.Shape() - op_pow = P.Pow() - op_norm = layer.Norm() - op_select = P.Select() - op_greater = P.Greater() - op_fill = P.Fill() - op_dtype = P.DType() + if optim_filter: + op_mul = P.Mul() + op_sqrt = P.Sqrt() + op_rsqrt = P.Rsqrt() + op_square = P.Square() + op_cast = P.Cast() + op_reshape = P.Reshape() + op_shape = P.Shape() + op_pow = P.Pow() + op_norm = layer.Norm() + op_select = P.Select() + op_greater = P.Greater() + op_fill = P.Fill() + op_dtype = P.DType() - param_fp32 = op_cast(param, mstype.float32) - m_fp32 = op_cast(m, mstype.float32) - v_fp32 = op_cast(v, mstype.float32) - gradient_fp32 = op_cast(gradient, mstype.float32) + param_fp32 = op_cast(param, mstype.float32) + m_fp32 = op_cast(m, mstype.float32) + v_fp32 = op_cast(v, mstype.float32) + gradient_fp32 = op_cast(gradient, mstype.float32) - next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, - mstype.float32) - beta1, gradient_fp32) + next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32) - next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, - mstype.float32) - beta2, op_square(gradient_fp32)) + next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32)) - next_mm = next_m / (op_cast(num_one, mstype.float32) - - op_pow(beta1, op_cast(global_step + num_one, mstype.float32))) - next_vv = next_v / (op_cast(num_one, mstype.float32) - - op_pow(beta2, op_cast(global_step + num_one, mstype.float32))) - w_norm = op_norm(param_fp32) - g_norm = op_norm(gradient_fp32) + next_mm = next_m / (op_cast(num_one, mstype.float32) + - op_pow(beta1, op_cast(global_step + num_one, mstype.float32))) + next_vv = next_v / (op_cast(num_one, mstype.float32) - + op_pow(beta2, op_cast(global_step + num_one, mstype.float32))) + w_norm = op_norm(param_fp32) + g_norm = op_norm(gradient_fp32) - g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt( - next_vv + eps)) + weight_decay_tensor * param_fp32) - zeros = F.zeros_like(w_norm) - ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) - trust_ratio = op_select( - op_greater(w_norm, zeros), - op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones), - ones) - tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0) - trust_ratio = C.clip_by_value(trust_ratio, zeros, tens) - update = next_mm / (op_sqrt(next_vv) + eps) + g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32) + zeros = F.zeros_like(w_norm) + ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) + trust_ratio = op_select( + op_greater(w_norm, zeros), + op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones), + ones) + tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0) + trust_ratio = C.clip_by_value(trust_ratio, zeros, tens) + update = next_mm / (op_sqrt(next_vv) + eps) - if decay_flag: - update = update + op_mul(weight_decay_tensor, param_fp32) + if decay_flag: + update = update + op_mul(weight_decay_tensor, param_fp32) - update_with_lr = op_mul(op_mul(trust_ratio, lr), update) + update_with_lr = op_mul(op_mul(trust_ratio, lr), update) - next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) + next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) - next_v = F.depend(next_v, F.assign(param, next_param)) - next_v = F.depend(next_v, F.assign(m, next_m)) - next_v = 
F.depend(next_v, F.assign(v, next_v)) + next_param = F.depend(next_param, F.assign(param, next_param)) + next_param = F.depend(next_param, F.assign(m, next_m)) + next_param = F.depend(next_param, F.assign(v, next_v)) - return next_v + return next_param + return gradient lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel") @@ -238,7 +237,7 @@ class Lamb(Optimizer): - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. Outputs: - tuple[Parameter], the updated velocity value, the shape is the same as `params`. + tuple[bool], all elements are True. Examples: >>> net = Net() @@ -311,18 +310,21 @@ class Lamb(Optimizer): self.warmup_steps, self.global_step), mstype.float32) lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr if self.enable_graph_kernel: - updated_velocity = self.hyper_map(F.partial(lamb_opt_graph_kernel, - self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor, self.global_step), - self.params, self.moments1, self.moments2, gradients, self.decay_flag) + optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, + self.beta1, self.beta2, self.eps, lr, + self.weight_decay_tensor, self.global_step), + self.params, self.moments1, self.moments2, gradients, self.decay_flag) else: - updated_velocity = self.hyper_map(F.partial(_lamb_opt, - self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor, self.global_step), - self.params, self.moments1, self.moments2, gradients, self.decay_flag) + optim_result = self.hyper_map(F.partial(_lamb_opt, + self.beta1, self.beta2, self.eps, lr, + self.weight_decay_tensor, self.global_step), + self.params, self.moments1, self.moments2, gradients, + self.decay_flag, self.optim_filter) + if self.use_parallel: + optim_result = self.broadcast_params(optim_result) added_global_step = self.global_step + self.one F.control_depend(lr, added_global_step) self.global_step = added_global_step - return updated_velocity + return optim_result diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 40642226adc..cdda93ab498 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P from mindspore.nn.cell import Cell from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.initializer import initializer +from mindspore.common.tensor import Tensor import mindspore.common.dtype as mstype from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel -from mindspore.common.tensor import Tensor from mindspore import log as logger +from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.train.parallel_utils import ParallelMode __all__ = ['Optimizer'] @@ -155,6 +158,27 @@ class Optimizer(Cell): self.param_length = len(self.parameters) self.map_ = C.Map() + use_parallel = auto_parallel_context().get_enable_parallel_optimizer() + self.use_parallel = use_parallel + if use_parallel: + if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]: + raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name)) + if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL, + ParallelMode.AUTO_PARALLEL]: + raise RuntimeError("Optimizer segmentation does not support parallel mode 
{}".format + (_get_parallel_mode())) + self.dev_num = _get_device_num() + if self.dev_num > self.param_length: + raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is" + " less than the number of devices {}".format(self.param_length, self.dev_num)) + self.param_rank = self._get_parameter_group_id() + self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank)) + self.param_names = [] + for param in self.parameters: + self.param_names.append(param.name) + else: + self.optim_filter = (True,) * self.param_length + def decay_weight(self, gradients): """ Weight decay. @@ -401,6 +425,51 @@ class Optimizer(Cell): lr = self.learning_rate return lr + def _get_parameter_group_id(self): + """ + Get the parameter partition group id, which is less than the number of devices. + + Returns: + tuple, the group id tuple of parameters. + """ + rank_list = () + count = 0 + for _ in range(self.param_length): + rank_list = rank_list + (count,) + count = count + 1 + if count == self.dev_num: + count = 0 + return rank_list + + def broadcast_params(self, optim_result): + """ + Apply Broadcast operations in the sequential order of parameter groups. + + Returns: + bool, the status flag. + """ + param_group = [] + key_group = [] + for _ in range(self.dev_num): + param_group.append(F.make_tuple()) + key_group.append(F.make_tuple()) + for i in range(self.param_length): + param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],) + key = P.MakeRefKey(self.param_names[i])() + key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,) + new_param_group = [] + for root in range(self.dev_num): + ops = P.Broadcast(root) + next_params = ops(param_group[root]) + new_param_group.append(next_params) + for i in range(F.tuple_len(next_params)): + F.assign(key_group[root][i], next_params[i]) + status = True + for i in range(self.dev_num - 1): + status = F.control_depend(new_param_group[i][0], new_param_group[i+1]) + + return status + def construct(self, *hyper_params): raise NotImplementedError diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py index 38cf0141f09..30c943b69f9 100644 --- a/mindspore/ops/composite/multitype_ops/setitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value): return F.list_setitem(data, number_index, value) +@setitem.register("List", "Number", "Tuple") +def _list_setitem_with_Tuple(data, number_index, value): + """ + Assigns value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (list): Value given. + + Outputs: + list, type is same as the element type of data. 
+ """ + return F.list_setitem(data, number_index, value) + + @setitem.register("Dictionary", "String", "Tensor") def _dict_setitem_with_tensor(data, key, value): """ diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 1b212d161a0..f8b47a28c3f 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer): self.op = op self.add_prim_attr('group', _get_group(group)) self.add_prim_attr('fusion', 0) + self.add_prim_attr('index', 0) def vm_impl(self, x): """Implement by vm mode.""" diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 74c6080ab41..b6b938d800b 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer): return variable def infer_dtype(self, variable, value): - args = {"variable": variable, "value": value} - validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) + # Add a type validation later when we don't have to assign a value to RefKey. return variable diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 74250f12e5a..93fe2338557 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -400,6 +400,23 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_global_rank_is_set() + def set_enable_parallel_optimizer(self, enable_parallel_optimizer): + """ + Set enable/disable parallel optimizer. + + Args: + set_enable_parallel_optimizer (bool): Enable/disable parallel optimizer. + """ + self.check_context_handle() + if not isinstance(enable_parallel_optimizer, bool): + raise TypeError('enable_parallel_optimizer is invalid type') + self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer) + + def get_enable_parallel_optimizer(self): + """Get parallel optimizer flag.""" + self.check_context_handle() + return self._context_handle.get_enable_parallel_optimizer() + def reset(self): """Reset all settings.""" self.check_context_handle() @@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = { "parameter_broadcast": auto_parallel_context().set_parameter_broadcast, "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, - "full_batch": auto_parallel_context().set_full_batch} + "full_batch": auto_parallel_context().set_full_batch, + "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer} _get_auto_parallel_context_func_map = { @@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = { "parameter_broadcast": auto_parallel_context().get_parameter_broadcast, "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file, "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, - "full_batch": auto_parallel_context().get_full_batch} + "full_batch": auto_parallel_context().get_full_batch, + "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer} @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, - strategy_ckpt_save_file=str, full_batch=bool) + strategy_ckpt_save_file=str, 
full_batch=bool, enable_parallel_optimizer=bool) + def _set_auto_parallel_context(**kwargs): """ Set auto parallel context. @@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs): strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' full_batch (bool): Whether to load the whole batch on each device. Default: False. + enable_parallel_optimizer (bool): Whether to enable optimizer segmentation. Default: False. Raises: ValueError: If input key is not attribute in auto parallel context. @@ -535,5 +556,6 @@ def _reset_auto_parallel_context(): - parameter_broadcast: False. - strategy_ckpt_load_file: "" - strategy_ckpt_save_file: "" + - enable_parallel_optimizer: False """ auto_parallel_context().reset() diff --git a/tests/ut/cpp/parallel/step_parallel_test.cc b/tests/ut/cpp/parallel/step_parallel_test.cc index b526195958d..d8f8681a349 100644 --- a/tests/ut/cpp/parallel/step_parallel_test.cc +++ b/tests/ut/cpp/parallel/step_parallel_test.cc @@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) { } else if (name == "instance_name") { parse::ConvertData(py::cast(item.second), &converted_ret); ASSERT_EQ(converted_ret->ToString(), "test"); + } else if (name == "index") { + parse::ConvertData(py::cast(item.second), &converted_ret); + ASSERT_EQ(converted_ret->ToString(), "0"); } else { MS_LOG(EXCEPTION) << "Test failed"; } diff --git a/tests/ut/python/ops/test_signature.py b/tests/ut/python/ops/test_signature.py index 6de00330af0..10b84254f4e 100644 --- a/tests/ut/python/ops/test_signature.py +++ b/tests/ut/python/ops/test_signature.py @@ -16,7 +16,6 @@ test assign sub """ import numpy as np -import pytest import mindspore.context as context import mindspore.nn as nn @@ -36,27 +35,6 @@ class AssignW(nn.Cell): return x -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.b = Parameter(initializer('ones', [5]), name='b') - self.assign = AssignW() - - def construct(self, value): - return self.assign(self.b, value) - - -def test_assign_through_cell(): - context.set_context(mode=context.GRAPH_MODE) - net = Net() - net.to_float(ms.float16) - net.add_flags_recursive(fp16=False) - input_data = Tensor(np.ones([5]).astype(np.float32)) - net(input_data) - with pytest.raises(TypeError): - net(None) - - class AssignOp(nn.Cell): def __init__(self): super(AssignOp, self).__init__() diff --git a/tests/ut/python/parallel/test_parallel_optimizer.py b/tests/ut/python/parallel/test_parallel_optimizer.py new file mode 100644 index 00000000000..6663e34871e --- /dev/null +++ b/tests/ut/python/parallel/test_parallel_optimizer.py @@ -0,0 +1,114 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ +""" test adam """ +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import _executor +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb +from mindspore.ops import operations as P +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore import context + + +class Net(nn.Cell): + """Net definition""" + + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Dense(128, 768, activation='relu') + self.fc2 = nn.Dense(128, 768, activation='relu') + self.fc3 = nn.Dense(128, 768, activation='relu') + self.fc4 = nn.Dense(768, 768, activation='relu') + self.relu4 = nn.ReLU() + self.relu5 = nn.ReLU() + self.transpose = P.Transpose() + self.matmul1 = P.MatMul() + self.matmul2 = P.MatMul() + + def construct(self, x): + q = self.fc1(x) + k = self.fc2(x) + v = self.fc3(x) + k = self.transpose(k, (1, 0)) + c = self.relu4(self.matmul1(q, k)) + s = self.relu5(self.matmul2(c, v)) + s = self.fc4(s) + return s + + +def test_AdamWeightDecayDynamicLR(): + """ test_AdamWeightDecayDynamicLR """ + auto_parallel_context().set_enable_parallel_optimizer(True) + context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2) + inputs = Tensor(np.ones([32, 128]).astype(np.float32)) + label = Tensor(np.zeros([32, 768]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + optimizer = AdamWeightDecayDynamicLR(net.trainable_params(), decay_steps=20, learning_rate=0.1) + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + + +def test_AdamWeightDecay(): + """ test_AdamWeightDecay """ + auto_parallel_context().set_enable_parallel_optimizer(True) + context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2) + inputs = Tensor(np.ones([32, 128]).astype(np.float32)) + label = Tensor(np.zeros([32, 768]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1) + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + + +def test_lamb_compile(): + """ test_Lamb_compile """ + auto_parallel_context().set_enable_parallel_optimizer(True) + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2) + inputs = Tensor(np.ones([32, 128]).astype(np.float32)) + label = Tensor(np.zeros([32, 768]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + optimizer = Lamb(net.trainable_params(), decay_steps=10) + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + + +def test_edge_case(): + """ test_edge_case """ + auto_parallel_context().set_enable_parallel_optimizer(True) + net = Net() + with pytest.raises(RuntimeError): + context.set_auto_parallel_context(parallel_mode="stand_alone") + Lamb(net.trainable_params(), decay_steps=10) + with pytest.raises(RuntimeError): + Adam(net.trainable_params(), learning_rate=0.1) + with pytest.raises(RuntimeError): + context.set_auto_parallel_context(device_num=16) +
Lamb(net.trainable_params(), decay_steps=10) diff --git a/tests/ut/python/parallel/test_set_auto_parallel_context.py b/tests/ut/python/parallel/test_set_auto_parallel_context.py index 4343f34d788..c476b0cebc3 100644 --- a/tests/ut/python/parallel/test_set_auto_parallel_context.py +++ b/tests/ut/python/parallel/test_set_auto_parallel_context.py @@ -81,6 +81,10 @@ def test_set_auto_parallel_context(): with pytest.raises(ValueError): set_algo_parameters(tensor_slice_align_size=1025) + auto_parallel_context().set_enable_parallel_optimizer(True) + assert auto_parallel_context().get_enable_parallel_optimizer() is True + assert not auto_parallel_context().get_all_reduce_fusion_split_indices() + def test_reset_auto_parallel_context(): context.reset_auto_parallel_context()
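
A minimal usage sketch of the new enable_parallel_optimizer switch, mirroring the unit tests added in tests/ut/python/parallel/test_parallel_optimizer.py above. The toy network, shapes, and device_num below are illustrative assumptions; a real multi-device run would also need the communication backend initialized, whereas the unit tests only compile the graph.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Lamb
from mindspore.parallel._auto_parallel_context import auto_parallel_context


class ToyNet(nn.Cell):
    """Illustrative network; any Cell with at least device_num parameters works."""
    def __init__(self):
        super(ToyNet, self).__init__()
        self.fc1 = nn.Dense(128, 768, activation='relu')
        self.fc2 = nn.Dense(768, 768, activation='relu')

    def construct(self, x):
        return self.fc2(self.fc1(x))


# Enable optimizer segmentation, then pick a parallel mode it supports.
auto_parallel_context().set_enable_parallel_optimizer(True)
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)

net = ToyNet()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
# Parameters are assigned to the 2 devices round-robin; each rank updates only
# its own share and broadcast_params synchronizes the results afterwards.
optimizer = Lamb(net.trainable_params(), decay_steps=10)
train_network = TrainOneStepCell(WithLossCell(net, loss), optimizer)
_executor.compile(train_network,
                  Tensor(np.ones([32, 128]).astype(np.float32)),
                  Tensor(np.zeros([32, 768]).astype(np.float32)))

Only Lamb, AdamWeightDecay and AdamWeightDecayDynamicLR accept the flag, and only under data_parallel, hybrid_parallel or auto_parallel mode; any other combination raises RuntimeError in the Optimizer base class.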
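
The parameter partition behind _get_parameter_group_id and optim_filter in optimizer.py is a plain round-robin assignment. The sketch below reproduces it in ordinary Python with illustrative sizes (five parameters, two devices, local rank 0) to show which slots a given rank actually updates.

def get_parameter_group_id(param_length, dev_num):
    """Assign parameter i to rank i % dev_num, exactly as the Optimizer method does."""
    rank_list = ()
    count = 0
    for _ in range(param_length):
        rank_list = rank_list + (count,)
        count = count + 1
        if count == dev_num:
            count = 0
    return rank_list


param_rank = get_parameter_group_id(param_length=5, dev_num=2)
print(param_rank)                                  # (0, 1, 0, 1, 0)
optim_filter = tuple(r == 0 for r in param_rank)   # filter as seen by global rank 0
print(optim_filter)                                # (True, False, True, False, True)
# Slots that are False skip the Adam/Lamb update and simply return the gradient;
# the owning rank's updated values arrive through the Broadcast in broadcast_params.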